package scala.tools.nsc
package transform
import symtab._
import Flags._
import util.TreeSet
import scala.collection.mutable.{ LinkedHashMap, ListBuffer }
/** The `lambdalift` phase: lifts local classes and methods into their
 *  enclosing class and makes their free variables explicit.
 *
 *  Overview (see `LambdaLifter` for the per-unit driver):
 *   - Every local (`isLocal`) method or class is moved to its old owner's
 *     `enclClass` (for a class: that class's `toInterface`), given a fresh
 *     name, and its tree is re-inserted into the template of the new owner.
 *   - Each local term that a definition uses but that is owned by a
 *     different logically enclosing member gets a proxy in that definition:
 *     an extra trailing value parameter for a method, or a synthetic
 *     private/local `PARAMACCESSOR` field for a class. Calls of lifted methods
 *     and of local-class constructors pass the matching proxies as extra
 *     arguments.
 *   - A free `var` (or lazy val) is marked `CAPTURED` and its info becomes a
 *     `Ref` class, so the lifted code and the original scope share one cell;
 *     uses of it are rewritten to select `.elem`.
 *   - References to local classes inside types get an explicit prefix via
 *     the `lifted` type map.
 */
abstract class LambdaLift extends InfoTransform {
import global._
import definitions._
val phaseName: String = "lambdalift"
/** Type map used both as this phase's info transform (`transformInfo`) and
 *  on tree types (`LambdaLifter.transform`).
 *
 *  - A prefix-less `TypeRef` without type arguments that refers to a
 *    non-package class (presumably a local class) is re-created with the
 *    `this`-type of `sym.owner.enclClass`, itself mapped recursively, as
 *    its prefix.
 *  - For a `ClassInfoType` only the parents are mapped. The `decls` scope and
 *    the class symbol are kept, and `tp` itself is returned when no parent
 *    changed.
 *  - Every other type is mapped structurally with `mapOver`.
 */
private val lifted = new TypeMap {
def apply(tp: Type): Type = tp match {
case TypeRef(NoPrefix, sym, Nil) if sym.isClass && !sym.isPackageClass =>
typeRef(apply(sym.owner.enclClass.thisType), sym, Nil)
case ClassInfoType(parents, decls, clazz) =>
val parents1 = parents mapConserve this
// mapConserve returns the very same list when no element changed; reuse tp then.
if (parents1 eq parents) tp
else ClassInfoType(parents1, decls, clazz)
case _ =>
mapOver(tp)
}
}
/** Info transform of this phase: every symbol's info is rewritten with
 *  `lifted`. The symbol `sym` itself is not consulted.
 */
def transformInfo(sym: Symbol, tp: Type): Type = lifted(tp)
/** Creates the tree transformer that performs lambda lifting on `unit`. */
protected def newTransformer(unit: CompilationUnit): Transformer =
new LambdaLifter(unit)
/** Performs lambda lifting on one compilation unit.
 *
 *  `transformUnit` first runs the analysis `computeFreeVars` (free-variable
 *  fixpoint, renaming, proxy creation). It then transforms the tree at
 *  `phase.next`:
 *   - each node is rewritten by `postTransform`;
 *   - `transformStats` appends lifted definitions to the templates of their
 *     new owner classes.
 *
 *  `outerPath`/`outerValue` are used by `memberRef` to reach fields of
 *  enclosing classes. They are not defined here and presumably come from
 *  the `OuterPathTransformer` superclass.
 */
class LambdaLifter(unit: CompilationUnit) extends explicitOuter.OuterPathTransformer(unit) {
// Enclosure (method or class) -> local symbols used in it but owned by an outer enclosure.
private val free = new LinkedHashMap[Symbol, SymSet]
// Enclosure -> proxy symbols for its free variables, built in the iteration order of
// free(enclosure). addFreeParams (parameters) and addFreeArgs (arguments) must agree on it.
private val proxies = new LinkedHashMap[Symbol, List[Symbol]]
// Caller -> local methods / local-class constructors it references. Each primary
// constructor additionally maps to its own class (see freeVarTraverser).
private val called = new LinkedHashMap[Symbol, SymSet]
// Local classes, local methods and free variables that computeFreeVars gives fresh names.
private val renamable = newSymSet
// Set by markFree whenever it records a new free variable; drives the fixpoint loop.
private var changedFreeVars: Boolean = _
// Class -> trees lifted into it, most recently lifted first (transformStats restores
// lifting order with reverseMap). Entries are removed once the trees are emitted.
private val liftedDefs = new LinkedHashMap[Symbol, List[Tree]]
// Symbol sets ordered by `isLess`. Presumably this makes iteration order, and hence
// the order of added parameters and arguments, deterministic.
private type SymSet = TreeSet[Symbol]
private def newSymSet = new TreeSet[Symbol](_ isLess _)
/** Returns the set stored for `sym` in `f`. If there is none, a fresh
 *  empty set is stored first.
 */
private def symSet(f: LinkedHashMap[Symbol, SymSet], sym: Symbol): SymSet =
f.getOrElseUpdate(sym, newSymSet)
/** True if `sym`'s owner has the same logically enclosing member as the
 *  current owner. In that case `sym` is referenced directly rather than
 *  through a proxy (see `proxy` and the `Ident` case of `postTransform`).
 */
private def isSameOwnerEnclosure(sym: Symbol) =
sym.owner.logicallyEnclosingMember == currentOwner.logicallyEnclosingMember
/** Records that `sym` is used inside `enclosure` (a method or class).
 *
 *  It also records `sym` in every enclosure between `enclosure` and the
 *  member that owns `sym`. The recursion on the next outer enclosure runs
 *  first.
 *
 *  For each enclosure in which `sym` is newly recorded as free, it also:
 *   - adds `sym` to `renamable`;
 *   - if `sym` is a parameter of its owner's method type, re-sets the owner's
 *     info at the pickler phase to a `cloneInfo` copy. This is presumably
 *     meant to keep the later renaming out of the pickled signature; confirm.
 *   - sets `changedFreeVars`;
 *   - for a `var` or lazy val not yet `CAPTURED`: sets `CAPTURED` and, from
 *     `phase.next` on, changes its info to a `Ref` class. The variant is
 *     chosen by `isValueClass` of its type symbol, and the volatile variant
 *     is used when annotated with `VolatileAttr`.
 *
 *  @return true when enclosures nested inside `enclosure` must also carry
 *          `sym`. That is the case when `enclosure` is `sym`'s defining
 *          member, or when it is a method, which receives `sym` as a
 *          parameter.
 *
 *          Returns false in three cases. In the last two, `sym` is not
 *          recorded in `enclosure`.
 *           - `enclosure` is a class. Nested code then reaches the class's
 *             proxy field through `proxyRef`/`memberRef`.
 *           - The walk reached a package class.
 *           - An outer enclosure returned false.
 */
private def markFree(sym: Symbol, enclosure: Symbol): Boolean = {
if (settings.debug.value)
log("mark free: " + sym + " of " + sym.owner + " marked free in " + enclosure)
if (enclosure == sym.owner.logicallyEnclosingMember) true
else if (enclosure.isPackageClass || !markFree(sym, enclosure.skipConstructor.owner.logicallyEnclosingMember)) false
else {
val ss = symSet(free, enclosure)
if (!ss(sym)) {
ss addEntry sym
renamable addEntry sym
// If sym is a parameter of its owner, give the owner a cloned info as of the pickler phase.
atPhase(currentRun.picklerPhase) {
if (sym.isParameter && sym.owner.info.paramss.exists(_ contains sym))
sym.owner.setInfo(sym.owner.info.cloneInfo(sym.owner))
}
changedFreeVars = true
if (settings.debug.value) log("" + sym + " is free in " + enclosure);
// Mutable captured state is boxed in a Ref so all users share one cell (done once per symbol).
if ((sym.isVariable || (sym.isValue && sym.isLazy)) && !sym.hasFlag(CAPTURED)) {
sym setFlag CAPTURED
val symClass = sym.tpe.typeSymbol
atPhase(phase.next) {
sym updateInfo (
if (sym.hasAnnotation(VolatileAttr))
if (isValueClass(symClass)) volatileRefClass(symClass).tpe else VolatileObjectRefClass.tpe
else
if (isValueClass(symClass)) refClass(symClass).tpe else ObjectRefClass.tpe
)
}
}
}
// Methods pass sym on to their nested enclosures; a class ends the chain (field access).
!enclosure.isClass
}
}
/** Records that `owner` references the method or constructor `sym`. The
 *  `computeFreeVars` fixpoint then propagates `sym`'s free variables to
 *  `owner`.
 */
private def markCalled(sym: Symbol, owner: Symbol) {
if (settings.debug.value)
log("mark called: " + sym + " of " + sym.owner + " is called by " + owner)
symSet(called, owner) addEntry sym
}
/** First pass of the analysis. It walks the unit and:
 *   - gives every class an empty `liftedDefs` entry, which `transformStats`
 *     expects to find;
 *   - adds local classes and local methods to `renamable`, and flags local
 *     methods `PRIVATE | LOCAL | FINAL`;
 *   - records each primary constructor as "calling" its own class, so the
 *     fixpoint makes the class's free variables free in the constructor,
 *     which then gets them as extra parameters;
 *   - calls `markFree` for each `Ident` of a local non-method term;
 *   - calls `markCalled` for each `Ident` of a local method and each
 *     `Select` of a constructor of a local class.
 *
 *  If traversal throws, the offending tree is printed and the exception is
 *  rethrown unchanged.
 */
private val freeVarTraverser = new Traverser {
override def traverse(tree: Tree) {
try {
val sym = tree.symbol;
tree match {
// Every class, local or not, gets a liftedDefs entry (consumed in transformStats).
case ClassDef(_, _, _, _) =>
liftedDefs(tree.symbol) = Nil
if (sym.isLocal) renamable addEntry sym
case DefDef(_, _, _, _, _, _) =>
if (sym.isLocal) {
renamable addEntry sym
sym setFlag (PRIVATE | LOCAL | FINAL)
} else if (sym.isPrimaryConstructor) {
symSet(called, sym) addEntry sym.owner
}
// Only local symbols can be free; a symbol-less Ident must be the wildcard `_`.
case Ident(name) =>
if (sym == NoSymbol) {
assert(name == nme.WILDCARD)
} else if (sym.isLocal) {
val owner = currentOwner.logicallyEnclosingMember
if (sym.isTerm && !sym.isMethod) markFree(sym, owner)
else if (sym.isMethod) markCalled(sym, owner)
}
// Instantiating a local class "calls" its constructor, which needs the class's free vars.
case Select(_, _) =>
if (sym.isConstructor && sym.owner.isLocal)
markCalled(sym, currentOwner.logicallyEnclosingMember)
case _ =>
}
super.traverse(tree)
} catch {
// Debugging aid: report which tree failed, then rethrow the original exception.
case ex: Throwable =>
Console.println("exception when traversing " + tree)
throw ex
}
}
}
/** Runs the whole analysis for the unit:
 *   1. `freeVarTraverser` records direct free variables and calls.
 *   2. Fixpoint: every free variable of a callee is marked free in each of
 *      its callers, repeated until `markFree` records nothing new.
 *   3. Every symbol in `renamable` is renamed, using `unit.freshTypeName` or
 *      `unit.freshTermName`, from the base `name + "$"`. For an anonymous
 *      function owned by a method, the base also gets `owner.name + "$"`.
 *   4. At `phase.next`, a proxy symbol is created for every free variable of
 *      every enclosure, with the variable's (already renamed) name and its
 *      info. For methods it is a synthetic `PARAM`. For classes it is a
 *      synthetic `PARAMACCESSOR | PRIVATE | LOCAL` member, also entered into
 *      the class's decls.
 */
private def computeFreeVars() {
freeVarTraverser.traverse(unit.body)
// Propagate callees' free variables to their callers until nothing changes.
do {
changedFreeVars = false
for (caller <- called.keys ; callee <- called(caller) ; fvs <- free get callee ; fv <- fvs)
markFree(fv, caller)
} while (changedFreeVars)
// Give lifted definitions and free variables fresh names (proxies below copy these names).
for (sym <- renamable) {
val originalName = sym.name
val base = sym.name + "$" + (
if (sym.isAnonymousFunction && sym.owner.isMethod)
sym.owner.name + "$"
else ""
)
sym.name =
if (sym.name.isTypeName) unit.freshTypeName(base)
else unit.freshTermName(base)
if (settings.debug.value)
log("renaming in %s: %s => %s".format(sym.owner.fullLocationString, originalName, sym.name))
}
// One proxy per free variable per enclosure, in the free set's iteration order.
atPhase(phase.next) {
for ((owner, freeValues) <- free.toList) {
if (settings.debug.value)
log("free var proxy: %s, %s".format(owner.fullLocationString, freeValues.toList.mkString(", ")))
proxies(owner) =
for (fv <- freeValues.toList) yield {
val proxy = owner.newValue(owner.pos, fv.name)
.setFlag(if (owner.isClass) PARAMACCESSOR | PRIVATE | LOCAL else PARAM)
.setFlag(SYNTHETIC)
.setInfo(fv.info);
if (owner.isClass) owner.info.decls enter proxy;
proxy
}
}
}
}
/** The symbol that code in `currentOwner` must use to refer to the free
 *  variable `sym`. This is `sym` itself when `isSameOwnerEnclosure(sym)`
 *  holds. Otherwise it is the first proxy with `sym`'s name found by
 *  walking outward from `currentOwner`.
 *
 *  The lookup is by name. It presumably relies on the fresh names assigned
 *  in `computeFreeVars` to keep names unique.
 *
 *  NOTE(review): there is no not-found case. If no enclosure on the owner
 *  chain has a matching proxy, the recursion runs off the end of the chain.
 */
private def proxy(sym: Symbol) = {
def searchIn(searchee: Symbol): Symbol = {
if (settings.debug.value)
log("searching for " + sym + "(" + sym.owner + ") in " + searchee + " " + searchee.logicallyEnclosingMember)
val ps = (proxies get searchee.logicallyEnclosingMember).toList.flatten filter (_.name == sym.name)
if (ps.isEmpty) searchIn(searchee.skipConstructor.owner)
else ps.head
}
if (settings.debug.value)
log("proxy " + sym + " in " + sym.owner + " from " + currentOwner.ownerChain.mkString(" -> ") +
" " + sym.owner.logicallyEnclosingMember)
if (isSameOwnerEnclosure(sym)) sym
else searchIn(currentOwner)
}
/** A `Select` of the member `sym`, typed `sym.tpe`. Its qualifier reaches
 *  `sym.owner.enclClass` from the current class:
 *   - `this`, if that is the current class;
 *   - otherwise a static qualifier for a static owner;
 *   - otherwise an outer path.
 *
 *  In the non-`this` cases, `LOCAL | PRIVATE` are cleared on `sym`, which
 *  mutates the symbol, so that it can be selected from another class.
 */
private def memberRef(sym: Symbol) = {
val clazz = sym.owner.enclClass
val qual = if (clazz == currentClass) gen.mkAttributedThis(clazz)
else {
sym resetFlag(LOCAL | PRIVATE)
if (clazz.isStaticOwner) gen.mkAttributedQualifier(clazz.thisType)
else outerPath(outerValue, currentClass.outerClass, clazz)
}
Select(qual, sym) setType sym.tpe
}
/** A tree referring to `sym` through `proxy(sym)`. It is an attributed
 *  `Ident` when the proxy `isLocal`, and a member selection (`memberRef`)
 *  of the proxy otherwise, e.g. a class's proxy field.
 */
private def proxyRef(sym: Symbol) = {
val psym = proxy(sym)
if (psym.isLocal) gen.mkAttributedIdent(psym) else memberRef(psym)
}
/** Returns `args` followed by one `proxyRef` (positioned at `pos`) per free
 *  variable of `sym`. The order matches the proxies that `addFreeParams`
 *  appends to `sym`'s parameters. Returns `args` unchanged when `sym` has
 *  no free variables.
 */
private def addFreeArgs(pos: Position, sym: Symbol, args: List[Tree]) = {
free get sym match {
case Some(fvs) => args ++ (fvs.toList map (fv => atPos(pos)(proxyRef(fv))))
case _ => args
}
}
/** If `sym` has proxies, adds them to its definition `tree`:
 *   - `DefDef`: the proxies' `ValDef`s are appended to the first parameter
 *     list, and `sym`'s info becomes a `MethodType` whose params are the old
 *     params plus fresh clones of the proxies (flagged `PARAM`), mapped by
 *     `lifted`.
 *     NOTE(review): only `vparamss.head` is kept. Further parameter lists
 *     are dropped, and an empty `vparamss` would fail. This presumes an
 *     uncurried method; confirm.
 *   - `ClassDef`: the proxies' `ValDef`s are appended to the template body.
 *
 *  Returns `tree` unchanged when `sym` has no proxies. Any other tree kind
 *  with proxies throws a `MatchError`.
 */
private def addFreeParams(tree: Tree, sym: Symbol): Tree = proxies.get(sym) match {
case Some(ps) =>
val freeParams = ps map (p => ValDef(p) setPos tree.pos setType NoType)
tree match {
case DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
val addParams = cloneSymbols(ps).map(_.setFlag(PARAM))
sym.updateInfo(
lifted(MethodType(sym.info.params ::: addParams, sym.info.resultType)))
treeCopy.DefDef(tree, mods, name, tparams, List(vparamss.head ++ freeParams), tpt, rhs)
case ClassDef(mods, name, tparams, impl @ Template(parents, self, body)) =>
treeCopy.ClassDef(tree, mods, name, tparams,
treeCopy.Template(impl, parents, self, body ::: freeParams))
}
case None =>
tree
}
/** Moves the local definition `tree` (method or class) into its enclosing
 *  class:
 *   - the symbol's owner becomes the old owner's `enclClass`, or for a class
 *     that class's `toInterface`;
 *   - methods are flagged `LIFTED`, and additionally `STATIC` when their old
 *     owner was an auxiliary constructor;
 *   - the symbol is entered into the new owner's decls (`enterUnique`);
 *   - `tree` is queued in `liftedDefs` for re-insertion by `transformStats`.
 *
 *  @return `EmptyTree`, which replaces the definition at its original site.
 */
private def liftDef(tree: Tree): Tree = {
val sym = tree.symbol
val oldOwner = sym.owner
if (sym.owner.isAuxiliaryConstructor && sym.isMethod)
sym setFlag STATIC
sym.owner = sym.owner.enclClass
if (sym.isClass) sym.owner = sym.owner.toInterface
if (sym.isMethod) sym setFlag LIFTED
liftedDefs(sym.owner) ::= tree
sym.owner.info.decls enterUnique sym
if (settings.debug.value) log("lifted: " + sym + " from " + oldOwner + " to " + sym.owner)
EmptyTree
}
/** Rewrites one node after its children have been transformed:
 *   - `ClassDef` / `DefDef`: add proxies (`addFreeParams`), then lift the
 *     definition if it is local.
 *   - `ValDef` of a captured variable: retype to `sym.tpe` (its `Ref` type)
 *     and initialise with `new <sym.tpe>(arg)`. `arg` is the original rhs
 *     or, when there is none, a zero value for the constructor's single
 *     parameter type.
 *   - `return { stats; value }` becomes `{ stats; return value }`.
 *   - `Apply`: append proxy arguments for the callee's free variables.
 *   - `Assign` whose lhs is `qual.asInstanceOf[...]()`: the cast is dropped
 *     so that the assignment targets `qual`.
 *   - `Ident`: method references become member selections, and local terms
 *     from another enclosure go through their proxy. Captured variables are
 *     read through `.elem`, cast back to the original type when the type
 *     symbols differ.
 *   - `Block`: lazy vals and `MODULEVAR` vals are moved to the front of the
 *     statements, keeping their relative order.
 */
private def postTransform(tree: Tree): Tree = {
val sym = tree.symbol
tree match {
case ClassDef(_, _, _, _) =>
val tree1 = addFreeParams(tree, sym)
if (sym.isLocal) liftDef(tree1) else tree1
case DefDef(_, _, _, _, _, _) =>
val tree1 = addFreeParams(tree, sym)
if (sym.isLocal) liftDef(tree1) else tree1
// Captured vars are stored in a Ref box initialised from the original rhs.
case ValDef(mods, name, tpt, rhs) =>
if (sym.isCapturedVariable) {
val tpt1 = TypeTree(sym.tpe) setPos tpt.pos
// No rhs: pass a zero value. The constructor is looked up via sym.primaryConstructor,
// presumably the Ref class's constructor since sym's info is the Ref type by now.
val constructorArg = rhs match {
case EmptyTree =>
sym.primaryConstructor.info.paramTypes match {
case List(tp) => gen.mkZero(tp)
case _ =>
log("Couldn't determine how to properly construct " + sym)
rhs
}
case arg => arg
}
treeCopy.ValDef(tree, mods, name, tpt1, typer.typedPos(rhs.pos) {
Apply(Select(New(TypeTree(sym.tpe)), nme.CONSTRUCTOR), List(constructorArg))
})
} else tree
// Hoist the block's statements out of the return: only the final value is returned.
case Return(Block(stats, value)) =>
Block(stats, treeCopy.Return(tree, value)) setType tree.tpe setPos tree.pos
case Return(expr) =>
assert(sym == currentMethod, sym)
tree
// Calls to lifted methods / local-class constructors get their free variables as extra args.
case Apply(fn, args) =>
treeCopy.Apply(tree, fn, addFreeArgs(tree.pos, sym, args))
// lhs `qual.asInstanceOf[T]()`, presumably from the cast added to `.elem` reads below:
// assign to `qual` itself.
case Assign(Apply(TypeApply(sel @ Select(qual, _), _), List()), rhs) =>
assert(sel.symbol == Object_asInstanceOf)
treeCopy.Assign(tree, qual, rhs)
// First pick how to reach the symbol (member select / proxy / as is),
// then unwrap captured variables through `.elem`.
case Ident(name) =>
val tree1 =
if (sym != NoSymbol && sym.isTerm && !sym.isLabel)
if (sym.isMethod)
atPos(tree.pos)(memberRef(sym))
else if (sym.isLocal && !isSameOwnerEnclosure(sym))
atPos(tree.pos)(proxyRef(sym))
else tree
else tree
if (sym.isCapturedVariable)
atPos(tree.pos) {
val tp = tree.tpe
val elemTree = typer typed Select(tree1 setType sym.tpe, nme.elem)
if (elemTree.tpe.typeSymbol != tp.typeSymbol) gen.mkAttributedCast(elemTree, tp) else elemTree
}
else tree1
// NOTE(review): the reason for moving lazy/MODULEVAR vals first is not visible here.
case Block(stats, expr0) =>
val (lzyVals, rest) = stats.partition {
case stat@ValDef(_, _, _, _) if stat.symbol.isLazy => true
case stat@ValDef(_, _, _, _) if stat.symbol.hasFlag(MODULEVAR) => true
case _ => false
}
treeCopy.Block(tree, lzyVals:::rest, expr0)
case _ =>
tree
}
}
/** Transforms the children first (`super.transform`), gives the result the
 *  ORIGINAL tree's type mapped by `lifted`, then applies `postTransform`.
 */
override def transform(tree: Tree): Tree =
postTransform(super.transform(tree) setType lifted(tree.tpe))
/** Transforms `stats`, then post-processes each resulting statement:
 *   - `ClassDef`: appends the definitions lifted into this class to its
 *     template body, in lifting order and each itself processed by
 *     `addLifted`. The class's `liftedDefs` entry is then removed. A missing
 *     entry is logged and treated as empty.
 *   - non-constructor `DefDef` whose body is a `Block` with no statements:
 *     the body is replaced by that block's result expression.
 */
override def transformStats(stats: List[Tree], exprOwner: Symbol): List[Tree] = {
def addLifted(stat: Tree): Tree = stat match {
case ClassDef(mods, name, tparams, impl @ Template(parents, self, body)) =>
// NB: this local `lifted` (list of trees) shadows the `lifted` type map.
val lifted = liftedDefs get stat.symbol match {
case Some(xs) => xs reverseMap addLifted
case _ => log("unexpectedly no lifted defs for " + stat.symbol) ; Nil
}
val result = treeCopy.ClassDef(
stat, mods, name, tparams, treeCopy.Template(impl, parents, self, body ::: lifted))
liftedDefs -= stat.symbol
result
case DefDef(mods, name, tp, vp, tpt, Block(Nil, expr)) if !stat.symbol.isConstructor =>
treeCopy.DefDef(stat, mods, name, tp, vp, tpt, expr)
case _ =>
stat
}
super.transformStats(stats, exprOwner) map addLifted
}
/** Analyses the unit (`computeFreeVars`), then transforms it at
 *  `phase.next`. Finally it asserts that every queued lifted definition was
 *  re-inserted, i.e. that `transformStats` emptied `liftedDefs`; the
 *  assertion message lists any leftover classes.
 */
override def transformUnit(unit: CompilationUnit) {
computeFreeVars
atPhase(phase.next)(super.transformUnit(unit))
assert(liftedDefs.isEmpty, liftedDefs.keys mkString ", ")
}
}
}