package scala.tools.nsc
package transform
import symtab.Flags
import Flags.SYNTHETIC
/** Compiler phase that replaces self-recursive calls in tail position by
 *  applications of a label wrapping the method body.
 *
 *  Only effectively final methods are rewritten (see `Context.isEligible`).
 *  Methods annotated with `@tailrec` (`TailrecClass`) that cannot be fully
 *  rewritten are reported as errors, except when compiling for MSIL.
 */
abstract class TailCalls extends Transform {
  import global._
  import definitions._
  import typer.{ typed, typedPos }

  val phaseName: String = "tailcalls"

  def newTransformer(unit: CompilationUnit): Transformer =
    new TailCallElimination(unit)

  /** Creates the phase that runs this transform. */
  override def newPhase(prev: scala.tools.nsc.Phase): StdPhase = new Phase(prev)

  /** Runs the tail-call transformer over each compilation unit, unless the
   *  `debuginfo` setting has the value `"notailcalls"`.
   */
  class Phase(prev: scala.tools.nsc.Phase) extends StdPhase(prev) {
    def apply(unit: global.CompilationUnit) {
      if (settings.debuginfo.value != "notailcalls")
        newTransformer(unit).transformUnit(unit)
    }
  }

  /** Rewrites self-recursive tail calls of each method into applications of a
   *  label. When at least one call of `def f(x1, ..., xn) = rhs` is rewritten,
   *  the body becomes, schematically,
   *  {{{
   *    { val <this> = this; _f(<this>, x1, ..., xn) { rhs' } }
   *  }}}
   *  where `_f` is a `LabelDef` whose first parameter stands for the receiver,
   *  and every rewritten call in `rhs'` applies `_f` instead of `f`.
   */
  class TailCallElimination(unit: CompilationUnit) extends Transformer {
    private val defaultReason = "it contains a recursive call not in tail position"

    /** Mutable state for the method currently being transformed and for the
     *  position (tail or not) of the tree currently being visited.
     */
    class Context() {
      /** The method whose body is being transformed. */
      var method: Symbol = NoSymbol
      /** The label that rewritten recursive calls apply. */
      var label: Symbol = NoSymbol
      /** Type parameters of `method`; recursive calls must pass exactly these. */
      var tparams: List[Symbol] = Nil
      /** Whether the tree currently being transformed is in tail position. */
      var tailPos = false
      /** Last recorded reason why a call could not be rewritten. */
      var failReason = defaultReason
      /** Position reported together with `failReason` by `tailrecFailure`. */
      var failPos = method.pos
      /** Set once at least one call has been rewritten to apply `label`. */
      var accessed = false

      /** Copy of `that`. Updates made on the copy are not propagated back to
       *  `that`, so failure information recorded on a copy is discarded.
       */
      def this(that: Context) = {
        this()
        this.method     = that.method
        this.tparams    = that.tparams
        this.tailPos    = that.tailPos
        this.accessed   = that.accessed
        this.failReason = that.failReason
        this.failPos    = that.failPos
        this.label      = that.label
      }

      /** Fresh context for the body of `dd`; the body starts in tail position. */
      def this(dd: DefDef) {
        this()
        this.method   = dd.symbol
        this.tparams  = dd.tparams map (_.symbol)
        this.tailPos  = true
        this.accessed = false
        this.failPos  = dd.pos
        // The label takes a receiver parameter followed by the method's own
        // parameters, and has the method's final result type.
        this.label = {
          val label     = method.newLabel(method.pos, "_" + method.name)
          val thisParam = method.newSyntheticValueParam(currentClass.typeOfThis)
          label setInfo MethodType(thisParam :: method.tpe.params, method.tpe.finalResultType)
        }
        // Make the label's type refer to the DefDef's own type parameter symbols.
        if (isEligible)
          label setInfo label.tpe.substSym(method.tpe.typeParams, tparams)
      }

      def enclosingType    = method.enclClass.typeOfThis
      def methodTypeParams = method.tpe.typeParams
      /** Only methods that cannot be overridden may be rewritten. */
      def isEligible       = method.isEffectivelyFinal
      /** `@tailrec` methods must be fully rewritten (not enforced on MSIL). */
      def isMandatory      = method.hasAnnotation(TailrecClass) && !forMSIL
      /** True when the body must be wrapped in the label definition. */
      def isTransformed    = isEligible && accessed

      def tailrecFailure() = unit.error(failPos, "could not optimize @tailrec annotated " + method + ": " + failReason)

      /** A synthetic local value named `nme.THIS` typed as `this` of the current class. */
      def newThis(pos: Position) = method.newValue(pos, nme.THIS) setInfo currentClass.typeOfThis setFlag SYNTHETIC

      override def toString(): String = (
        "" + method.name + " tparams: " + tparams + " tailPos: " + tailPos +
        " accessed: " + accessed + "\nLabel: " + label + "\nLabel type: " + label.info
      )
    }

    /** The context of the tree currently being transformed. */
    private var ctx: Context = new Context()

    /** A copy of the current context that is not in tail position. */
    private def noTailContext() = {
      val t = new Context(ctx)
      t.tailPos = false
      t
    }

    /** Transforms `tree` under `nctx`, restoring the previous context afterwards. */
    def transform(tree: Tree, nctx: Context): Tree = {
      val saved = ctx
      ctx = nctx
      try transform(tree)
      finally this.ctx = saved
    }

    /** Transforms `tree` as a sub-expression that is not in tail position. */
    def noTailTransform(tree: Tree): Tree = transform(tree, noTailContext())

    /** Transforms `trees`, none of which is in tail position, sharing one context. */
    def noTailTransforms(trees: List[Tree]) = {
      val nctx = noTailContext()
      trees map (t => transform(t, nctx))
    }

    override def transform(tree: Tree): Tree = {
      // Decides whether the application `tree` is a self-recursive tail call
      // that can apply `ctx.label` instead. `target` is the function part to
      // keep when the call is left alone (including any TypeApply), `fun` is
      // the function without type arguments, whose symbol is compared against
      // the current method.
      def rewriteApply(target: Tree, fun: Tree, targs: List[Tree], args: List[Tree]) = {
        val receiver: Tree = fun match {
          case Select(qual, _) => qual
          case _               => EmptyTree
        }
        def receiverIsSame   = ctx.enclosingType.widen =:= receiver.tpe.widen
        // NOTE(review): `<:<` also holds for the enclosing type itself, so this
        // is true for ordinary calls on `this` as well; confirm the resulting
        // "supertype" message is intended for non-recursive calls.
        def receiverIsSuper  = ctx.enclosingType.widen <:< receiver.tpe.widen
        def isRecursiveCall  = (ctx.method eq fun.symbol) && ctx.tailPos
        // Arguments are never in tail position.
        def transformArgs    = noTailTransforms(args)
        def matchesTypeArgs  = ctx.tparams sameElements (targs map (_.tpe.typeSymbol))

        // Records `reason` on the current context and keeps the call as is
        // (with its arguments transformed). If the current context is a
        // non-tail copy, the recorded reason is discarded with the copy.
        def fail(reason: String) = {
          if (settings.debug.value)
            log("Cannot rewrite recursive call at: " + fun.pos + " because: " + reason)
          ctx.failReason = reason
          treeCopy.Apply(tree, target, transformArgs)
        }
        // Like `fail`, but also records this call's position as the failure position.
        def failHere(reason: String) = {
          ctx.failPos = fun.pos
          fail(reason)
        }
        // Replaces the call by an application of the label, passing `recv`
        // as the receiver argument.
        def rewriteTailCall(recv: Tree): Tree = {
          log("Rewriting tail recursive method call at: " + fun.pos)
          ctx.accessed = true
          typedPos(fun.pos)(Apply(Ident(ctx.label), recv :: transformArgs))
        }

        // Every call that is not a recursive tail call reaches `failHere`, so
        // `failPos` ends up at the last such call visited in this context.
        if (!ctx.isEligible) fail("it is neither private nor final so can be overridden")
        else if (!isRecursiveCall) {
          if (receiverIsSuper) failHere("it contains a recursive call targetting a supertype")
          else failHere(defaultReason)
        }
        else if (!matchesTypeArgs) failHere("it is called recursively with different type arguments")
        else if (receiver == EmptyTree) rewriteTailCall(This(currentClass))
        else if (forMSIL) fail("it cannot be optimized on MSIL")
        else if (!receiverIsSame) failHere("it changes type of 'this' on a polymorphic recursive call")
        else rewriteTailCall(receiver)
      }

      tree match {
        case dd @ DefDef(mods, name, tparams, vparams, tpt, rhs) =>
          log("Entering DefDef: " + name)
          val newCtx = new Context(dd)
          log("Considering " + name + " for tailcalls")
          val newRHS = transform(rhs, newCtx)

          treeCopy.DefDef(tree, mods, name, tparams, vparams, tpt, {
            if (newCtx.isTransformed) {
              // Some calls were rewritten. For @tailrec, every direct call to
              // the method still left in the body is reported at its position.
              if (newCtx.isMandatory) {
                for (t @ Apply(fn, _) <- newRHS ; if fn.symbol == newCtx.method) {
                  newCtx.failPos = t.pos
                  newCtx.tailrecFailure()
                }
              }
              // { val <this> = this; <label>(<this>, params...) { newRHS } }
              val newThis = newCtx.newThis(tree.pos)
              val vpSyms  = vparams.flatten map (_.symbol)

              typedPos(tree.pos)(Block(
                List(ValDef(newThis, This(currentClass))),
                LabelDef(newCtx.label, newThis :: vpSyms, newRHS)
              ))
            }
            else {
              if (newCtx.isMandatory)
                newCtx.tailrecFailure()
              newRHS
            }
          })

        case Block(stats, expr) =>
          // Only the result expression of a block is in tail position.
          treeCopy.Block(tree,
            noTailTransforms(stats),
            transform(expr)
          )

        case CaseDef(pat, guard, body) =>
          // Patterns are left untouched. The guard is not in tail position but
          // is still traversed so that local methods defined there get processed.
          treeCopy.CaseDef(tree,
            pat,
            noTailTransform(guard),
            transform(body)
          )

        case If(cond, thenp, elsep) =>
          // The condition is not in tail position but is still traversed so
          // that local methods defined there get processed.
          treeCopy.If(tree,
            noTailTransform(cond),
            transform(thenp),
            transform(elsep)
          )

        case Match(selector, cases) =>
          treeCopy.Match(tree,
            noTailTransform(selector),
            transformTrees(cases).asInstanceOf[List[CaseDef]]
          )

        case Try(block, catches, finalizer) =>
          // Nothing inside a try is treated as being in tail position.
          treeCopy.Try(tree,
            noTailTransform(block),
            noTailTransforms(catches).asInstanceOf[List[CaseDef]],
            noTailTransform(finalizer)
          )

        case Apply(tapply @ TypeApply(fun, targs), vargs) =>
          rewriteApply(tapply, fun, targs, vargs)

        case Apply(fun, args) if fun.symbol == Boolean_or || fun.symbol == Boolean_and =>
          // `a || b` / `a && b`: the right operand keeps the current tail
          // position; the left operand is traversed outside tail position.
          val newFun = fun match {
            case Select(qual, sel) => treeCopy.Select(fun, noTailTransform(qual), sel)
            case _                 => fun
          }
          treeCopy.Apply(tree, newFun, transformTrees(args))

        case Apply(fun, args) =>
          rewriteApply(fun, fun, Nil, args)

        case Alternative(_) | Star(_) | Bind(_, _) =>
          sys.error("We should've never gotten inside a pattern")

        case EmptyTree | Super(_, _) | This(_) | Select(_, _) | Ident(_) | Literal(_) | Function(_, _) | TypeTree() =>
          tree

        case _ =>
          super.transform(tree)
      }
    }
  }
}