package scala.tools.nsc
package transform
import symtab.Flags._
import scala.collection.{ mutable, immutable }
abstract class UnCurry extends InfoTransform with TypingTransformers with ast.TreeDSL {
import global._
import definitions._
import CODE._
/** Name under which this phase is registered. */
val phaseName: String = "uncurry"

/** Creates the tree transformer applied to a single compilation unit. */
def newTransformer(unit: CompilationUnit): Transformer =
  new UnCurryTransformer(unit)

/** Declares (through the `changesBaseClasses` hook) that this info transform leaves base classes unchanged. */
override def changesBaseClasses: Boolean = false
/** Returns `tp.normalize`, except for higher-kinded types, which are returned unchanged. */
private def expandAlias(tp: Type): Type =
  if (tp.isHigherKinded) tp
  else tp.normalize
/** True for a type reference to an abstract type (e.g. a type parameter) that does not conform to
 *  `AnyRef`. Used to pick `Array[Object]` as the array type of such Java varargs.
 */
private def isUnboundedGeneric(tp: Type): Boolean = tp match {
  case ref @ TypeRef(_, tsym, _) if tsym.isAbstractType => !(ref <:< AnyRefClass.tpe)
  case _                                                => false
}
// The main type map of this phase. It is used for the infos of non-type symbols and for tree types:
//  - curried method types are merged into one parameter list,
//  - the first parameter of an implicit parameter list is replaced by a clone without IMPLICIT
//    (only the head parameter is checked and cloned),
//  - nullary method types become method types with an empty parameter list,
//  - by-name parameter types `=> T` become `() => T` function types,
//  - Scala repeated parameter types `T*` become `Seq[T]`,
//  - Java repeated parameter types become arrays (of Object when the element type is an
//    abstract type not conforming to AnyRef, see isUnboundedGeneric).
// Every other type is mapped structurally (mapOver) and alias-expanded again.
private val uncurry: TypeMap = new TypeMap {
def apply(tp0: Type): Type = {
val tp = expandAlias(tp0)
tp match {
// (params1)(params2)R  ==>  (params1 ::: params2)R, applied repeatedly through recursion
case MethodType(params, MethodType(params1, restpe)) =>
apply(MethodType(params ::: params1, restpe))
// A curried method type hidden behind an existential is not expected here; returns the input unchanged.
case MethodType(params, ExistentialType(tparams, restpe @ MethodType(_, _))) =>
assert(false, "unexpected curried method types with intervening existential")
tp0
case MethodType(h :: t, restpe) if h.isImplicit =>
apply(MethodType(h.cloneSymbol.resetFlag(IMPLICIT) :: t, restpe))
case NullaryMethodType(restpe) =>
apply(MethodType(List(), restpe))
case TypeRef(pre, ByNameParamClass, List(arg)) =>
apply(functionType(List(), arg))
case TypeRef(pre, RepeatedParamClass, args) =>
apply(appliedType(SeqClass.typeConstructor, args))
case TypeRef(pre, JavaRepeatedParamClass, args) =>
apply(arrayType(
if (isUnboundedGeneric(args.head)) ObjectClass.tpe else args.head))
case _ =>
expandAlias(mapOver(tp))
}
}
}
/** The info map applied to type symbols: only the parents of class infos are uncurried (decls are
 *  kept), polymorphic types are mapped structurally with this same map, and anything else is
 *  returned after alias expansion.
 */
private val uncurryType = new TypeMap {
  def apply(tp0: Type): Type = expandAlias(tp0) match {
    case info @ ClassInfoType(parents, decls, clazz) =>
      val newParents = parents mapConserve uncurry
      // mapConserve returns the same list when nothing changed; keep the original type then.
      if (newParents eq parents) info
      else ClassInfoType(newParents, decls, clazz)
    case poly @ PolyType(_, _) =>
      mapOver(poly)
    case other =>
      other
  }
}
/** The info transform: type symbols go through `uncurryType`, all other symbols through `uncurry`. */
def transformInfo(sym: Symbol, tp: Type): Type = {
  val typeMap: TypeMap = if (sym.isType) uncurryType else uncurry
  typeMap(tp)
}
/** Detects whether a tree contains a `return`, without descending into nested method definitions. */
private object lookForReturns extends Traverser {
  var returnFound = false

  override def traverse(tree: Tree): Unit = tree match {
    case Return(_) => returnFound = true
    case _: DefDef => () // returns inside a nested method belong to that method
    case _         => super.traverse(tree)
  }

  /** Resets the flag, scans `tree`, and reports whether a `return` was seen. */
  def found(tree: Tree): Boolean = {
    returnFound = false
    traverse(tree)
    returnFound
  }
}
class UnCurryTransformer(unit: CompilationUnit) extends TypingTransformer(unit) {
// When true, a `try` met by mainTransform is lifted into its own method (see the Try case there).
private var needTryLift: Boolean = false
// True while pattern trees are being transformed.
private var inPattern: Boolean = false
// Extra flags OR-ed into anonymous function classes; INCONSTRUCTOR while transforming the
// pre-super fields and the super call of a class constructor, 0 otherwise.
private var inConstructorFlag: Long = 0L
// By-name parameter references that are passed on directly as by-name arguments;
// isByNameRef excludes them so no `apply()` call is inserted for them.
private val byNameArgs: mutable.HashSet[Tree] = new mutable.HashSet
// By-name references produced by deEta that must stay function values (no `apply()` inserted).
private val noApply: mutable.HashSet[Tree] = new mutable.HashSet
// Members (varargs forwarders) to prepend to the enclosing template in postTransform.
private val newMembers: mutable.ArrayBuffer[Tree] = new mutable.ArrayBuffer
// Repeated parameters of @varargs methods, recorded by saveRepeatedParams.
private val repeatedParams: mutable.Map[Symbol, List[ValDef]] = mutable.Map()
// `@SerialVersionUID(0)`, attached to every anonymous function class created here.
private lazy val serialVersionUIDAnnotation =
  AnnotationInfo(SerialVersionUIDAttr.tpe, List(Literal(Constant(0))), Nil)
/** Recomputes `freeMutableVars` for the unit's body, then transforms the unit. */
override def transformUnit(unit: CompilationUnit): Unit = {
  freeMutableVars.clear()
  freeLocalsTraverser traverse unit.body
  super.transformUnit(unit)
}
// Number of failing trees already printed by `transform` (capped at 10).
private var nprinted = 0

/** Runs `mainTransform` followed by `postTransform`. If either throws, the tree being traversed is
 *  printed to the console (for at most 10 trees) and the exception is rethrown unchanged.
 */
override def transform(tree: Tree): Tree =
  try {
    val transformed = mainTransform(tree)
    postTransform(transformed)
  } catch {
    case ex: Throwable =>
      val stillReporting = nprinted < 10
      if (stillReporting) {
        Console.println("exception when traversing " + tree)
        nprinted += 1
      }
      throw ex
  }
/** True if `tree` is a term whose symbol has a by-name parameter type, unless the tree was recorded
 *  in `byNameArgs` (a by-name reference passed on as a by-name argument).
 */
def isByNameRef(tree: Tree): Boolean = {
  def refersToByNameParam = tree.hasSymbol && isByNameParamType(tree.symbol.tpe)
  tree.isTerm && refersToByNameParam && !byNameArgs.contains(tree)
}
/** Uncurries the type of a tree. Inside patterns, the second parameter list of a curried method
 *  type is dropped (rather than merged) before uncurrying.
 */
def uncurryTreeType(tp: Type): Type =
  if (!inPattern) uncurry(tp)
  else tp match {
    case MethodType(params, MethodType(_, restpe)) => uncurryTreeType(MethodType(params, restpe))
    case _                                         => uncurry(tp)
  }
/** `NonLocalReturnControl[argtype]`, the exception type used to implement a non-local return. */
private def nonLocalReturnExceptionType(argtype: Type) =
  appliedType(NonLocalReturnControlClass.typeConstructor, argtype :: Nil)

// One key symbol per method that is the target of a non-local return.
private val nonLocalReturnKeys = new mutable.HashMap[Symbol, Symbol]

/** The synthetic `Object`-typed key value for `meth`, created and cached on first request. */
private def nonLocalReturnKey(meth: Symbol) = {
  def freshKey = {
    val key = meth.newValue(meth.pos, unit.freshTermName("nonLocalReturnKey"))
    key setFlag SYNTHETIC setInfo ObjectClass.tpe
  }
  nonLocalReturnKeys.getOrElseUpdate(meth, freshKey)
}

/** A typed `throw new NonLocalReturnControl(key, expr)` that returns `expr` from `meth`. */
private def nonLocalReturnThrow(expr: Tree, meth: Symbol) = {
  val exType = nonLocalReturnExceptionType(expr.tpe)
  localTyper.typed {
    Throw(New(TypeTree(exType), List(List(Ident(nonLocalReturnKey(meth)), expr))))
  }
}
// Wraps `body` (the rhs of `meth`) so that non-local returns targeting `meth` are caught. The
// typed result has the shape
//   { val key = new Object()
//     try body
//     catch { case ex: NonLocalReturnControl[_] =>
//       if (ex.key() eq key) ex.value().asInstanceOf[R] else throw ex } }
// where R is the final result type of `meth`. Exceptions carrying another method's key are rethrown.
private def nonLocalReturnTry(body: Tree, key: Symbol, meth: Symbol) = {
localTyper.typed {
val extpe = nonLocalReturnExceptionType(meth.tpe.finalResultType)
val ex = meth.newValue(body.pos, nme.ex) setInfo extpe
// Pattern `ex @ (_: NonLocalReturnControl[_])`
val pat = Bind(ex,
Typed(Ident(nme.WILDCARD),
AppliedTypeTree(Ident(NonLocalReturnControlClass),
List(Bind(tpnme.WILDCARD,
EmptyTree)))))
// Handler: return the carried value if the key is ours, otherwise propagate the exception.
val rhs =
If(
Apply(
Select(
Apply(Select(Ident(ex), "key"), List()),
Object_eq),
List(Ident(key))),
Apply(
TypeApply(
Select(
Apply(Select(Ident(ex), "value"), List()),
Any_asInstanceOf),
List(TypeTree(meth.tpe.finalResultType))),
List()),
Throw(Ident(ex)))
// The key is a fresh Object instance, so each invocation of `meth` gets its own identity.
val keyDef = ValDef(key, New(TypeTree(ObjectClass.tpe), List(List())))
val tryCatch = Try(body, List(CaseDef(pat, EmptyTree, rhs)), EmptyTree)
Block(List(keyDef), tryCatch)
}
}
/** Undoes trivial eta-expansions of nullary functions:
 *   - `() => e()` with pure `e` becomes `e`, unless `e` has a lazy symbol;
 *   - `() => x` with `x` a by-name reference becomes `x`, recorded in `noApply`.
 *  Any other function is returned unchanged.
 */
def deEta(fun: Function): Tree = fun match {
  case Function(Nil, Apply(expr, Nil)) if treeInfo.isPureExpr(expr) =>
    if (expr hasSymbolWhich (_.isLazy)) fun else expr
  case Function(Nil, expr) if isByNameRef(expr) =>
    noApply += expr
    expr
  case _ =>
    fun
}
// Replaces a function literal by an instance of a synthetic anonymous class. First tries deEta; if
// that simplified the tree, the simplified tree is returned. Otherwise the result is the typed block
//   { final class $anon extends <parents> { final def apply(<params>) = <body>
//                                            [final def isDefinedAt(x) = ...] }
//     (new $anon(): <fun.tpe>) }
// The class is FINAL | SYNTHETIC (plus INCONSTRUCTOR inside constructors, via inConstructorFlag)
// and is annotated with `@SerialVersionUID(0)`.
def transformFunction(fun: Function): Tree = {
val fun1 = deEta(fun)
def owner = fun.symbol.owner
def targs = fun.tpe.typeArgs
def isPartial = fun.tpe.typeSymbol == PartialFunctionClass
if (fun1 ne fun) fun1
else {
// The last type argument of the function type is the result type, the rest are parameter types.
val (formals, restpe) = (targs.init, targs.last)
val anonClass = owner newAnonymousFunctionClass fun.pos setFlag (FINAL | SYNTHETIC | inConstructorFlag)
// FunctionN types extend the corresponding abstract function class; other function-like types
// (presumably PartialFunction) extend Object and the type itself. Both are made Serializable.
def parents =
if (isFunctionType(fun.tpe)) List(abstractFunctionForFunctionType(fun.tpe), SerializableClass.tpe)
else List(ObjectClass.tpe, fun.tpe, SerializableClass.tpe)
anonClass setInfo ClassInfoType(parents, new Scope, anonClass)
val applyMethod = anonClass.newMethod(fun.pos, nme.apply) setFlag FINAL
applyMethod setInfo MethodType(applyMethod newSyntheticValueParams formals, restpe)
anonClass.info.decls enter applyMethod
anonClass.addAnnotation(serialVersionUIDAnnotation)
// The function's parameters and body now belong to the new apply method.
fun.vparams foreach (_.symbol.owner = applyMethod)
new ChangeOwnerTraverser(fun.symbol, applyMethod) traverse fun.body
// Annotates the selector of a top-level Match with @unchecked; other trees are left as is.
def mkUnchecked(tree: Tree) = {
def newUnchecked(expr: Tree) = Annotated(New(gen.scalaDot(UncheckedClass.name), List(Nil)), expr)
tree match {
case Match(selector, cases) => atPos(tree.pos) { Match(newUnchecked(selector), cases) }
case _ => tree
}
}
def applyMethodDef() = {
val body = if (isPartial) mkUnchecked(fun.body) else fun.body
DefDef(Modifiers(FINAL), nme.apply, Nil, List(fun.vparams), TypeTree(restpe), body) setSymbol applyMethod
}
// For a partial function: isDefinedAt re-runs the original match on its own parameter, with each
// case's pattern and guard copied (local attributes reset) and a `true` body, plus a final
// `case _ => false`. If the match already has a default case the body is just `true`.
// Assumes fun.body is a Match (the pattern below fails otherwise).
def isDefinedAtMethodDef() = {
val m = anonClass.newMethod(fun.pos, nme.isDefinedAt) setFlag FINAL
m setInfo MethodType(m newSyntheticValueParams formals, BooleanClass.tpe)
anonClass.info.decls enter m
val Match(selector, cases) = fun.body
val vparam = fun.vparams.head.symbol
val idparam = m.paramss.head.head
val substParam = new TreeSymSubstituter(List(vparam), List(idparam))
def substTree[T <: Tree](t: T): T = substParam(resetLocalAttrs(t))
def transformCase(cdef: CaseDef): CaseDef =
substTree(CaseDef(cdef.pat.duplicate, cdef.guard.duplicate, Literal(true)))
def defaultCase = CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(false))
DefDef(m, mkUnchecked(
if (cases exists treeInfo.isDefaultCase) Literal(true)
else Match(substTree(selector.duplicate), (cases map transformCase) :+ defaultCase)
))
}
val members =
if (isPartial) List(applyMethodDef, isDefinedAtMethodDef)
else List(applyMethodDef)
// The instance is ascribed the original function type so the result keeps fun.tpe.
localTyper.typedPos(fun.pos) {
Block(
List(ClassDef(anonClass, NoMods, List(List()), List(List()), members, fun.pos)),
Typed(
New(TypeTree(anonClass.tpe), List(List())),
TypeTree(fun.tpe)))
}
}
}
// Adapts the arguments `args` of a call to `fun` to the parameter types `formals`:
//  - when `formals` ends in a repeated parameter, the trailing arguments are packed into a single
//    argument of type `formals.last` (see transformVarargs);
//  - for each by-name formal, a by-name reference argument is passed along as is (re-typed as
//    `() => T` and recorded in byNameArgs); any other argument is wrapped in `Function(Nil, arg)`,
//    moved to that function's owner, and expanded by transformFunction.
// `pos` positions the trees synthesized while packing varargs.
def transformArgs(pos: Position, fun: Symbol, args: List[Tree], formals: List[Type]) = {
val isJava = fun.isJavaDefined
// Returns the non-varargs prefix of `args` followed by one packed varargs argument.
def transformVarargs(varargsElemType: Type) = {
def mkArrayValue(ts: List[Tree], elemtp: Type) =
ArrayValue(TypeTree(elemtp), ts) setType arrayType(elemtp)
// Wraps an array with gen.mkWrapArray (casting it first when its type does not conform to
// Array[elemtp]); typed at the next phase.
def arrayToSequence(tree: Tree, elemtp: Type) = {
atPhase(phase.next) {
localTyper.typedPos(pos) {
val pt = arrayType(elemtp)
val adaptedTree =
if (tree.tpe <:< pt) tree
else gen.mkCastArray(tree, elemtp, pt)
gen.mkWrapArray(adaptedTree, elemtp)
}
}
}
// Calls the collection's `toArray` member, passing a manifest for its Traversable element type.
def sequenceToArray(tree: Tree) = {
val toArraySym = tree.tpe member nme.toArray
assert(toArraySym != NoSymbol)
// Searches a manifest for `tp`, then for its upper bound, finally falling back to getManifestTree.
def getManifest(tp: Type): Tree = {
val manifestOpt = localTyper.findManifest(tp, false)
if (!manifestOpt.tree.isEmpty) manifestOpt.tree
else if (tp.bounds.hi ne tp) getManifest(tp.bounds.hi)
else localTyper.getManifestTree(tree.pos, tp, false)
}
atPhase(phase.next) {
localTyper.typedPos(pos) {
Apply(gen.mkAttributedSelect(tree, toArraySym),
List(getManifest(tree.tpe.baseType(TraversableClass).typeArgs.head)))
}
}
}
// A `xs: _*` argument is converted if needed (Java callees take arrays, Scala callees take
// collections); otherwise the trailing arguments become an array value, which is wrapped as a
// sequence for Scala callees outside patterns (Nil when there are no arguments at all).
var suffix: Tree =
if (treeInfo isWildcardStarArgList args) {
val Typed(tree, _) = args.last;
if (isJava)
if (tree.tpe.typeSymbol == ArrayClass) tree
else sequenceToArray(tree)
else
if (tree.tpe.typeSymbol isSubClass TraversableClass) tree
else arrayToSequence(tree, varargsElemType)
}
else {
def mkArray = mkArrayValue(args drop (formals.length - 1), varargsElemType)
if (isJava || inPattern) mkArray
else if (args.isEmpty) gen.mkNil
else arrayToSequence(mkArray, varargsElemType)
}
// A primitive array passed to a Java method whose last parameter is an Object array is
// converted with the runtime helper `toObjectArray`.
atPhase(phase.next) {
if (isJava && isPrimitiveArray(suffix.tpe) && isArrayOfSymbol(fun.tpe.params.last.tpe, ObjectClass)) {
suffix = localTyper.typedPos(pos) {
gen.mkRuntimeCall("toObjectArray", List(suffix))
}
}
}
args.take(formals.length - 1) :+ (suffix setType formals.last)
}
val args1 = if (isVarArgTypes(formals)) transformVarargs(formals.last.typeArgs.head) else args
(formals, args1).zipped map { (formal, arg) =>
if (!isByNameParamType(formal)) {
arg
} else if (isByNameRef(arg)) {
byNameArgs += arg
arg setType functionType(List(), arg.tpe)
} else {
if (opt.verboseDebug) {
val posstr = arg.pos.source.path + ":" + arg.pos.line
val permstr = if (fun.isPrivate) "private" else "notprivate"
log("byname | %s | %s | %s".format(posstr, fun.fullName, permstr))
}
// Turn the argument into a nullary function literal owned by the new function symbol.
val result = localTyper.typed(
Function(Nil, arg) setPos arg.pos).asInstanceOf[Function]
new ChangeOwnerTraverser(currentOwner, result.symbol).traverse(arg)
transformFunction(result)
}
}
}
/** The replacement for an elided call: a unit literal `()` at the call's position, typed `Unit`. */
def elideIntoUnit(tree: Tree): Tree =
  Literal(()).setPos(tree.pos).setType(UnitClass.tpe)
/** True if `tree` is a call whose method has an elision level, and that level is below the
 *  `elidebelow` setting or the `noassertions` setting is on. Each elided call is logged.
 */
def isElidable(tree: Tree): Boolean = {
  val sym = treeInfo.methPart(tree).symbol
  (sym != null) && (sym.elisionLevel match {
    case Some(level) if level < settings.elidebelow.value || settings.noassertions.value =>
      // Report the call site (the owner currently being transformed); `tree.symbol.owner`
      // is the callee's owner, which is not where the call is coming from.
      log("Eliding call from " + currentOwner + " to " + sym + " based on its elision threshold of " + level)
      true
    case _ =>
      false
  })
}
// First half of the per-tree transform (postTransform runs on its result). Elidable calls become
// `()`; otherwise the tree is transformed according to its shape (see the cases below). The
// result's type is set to `uncurryTreeType(tree.tpe)`, except on the early `return` path of the
// default case.
def mainTransform(tree: Tree): Tree = {
// Runs `f` with needTryLift set to `needLift`, restoring the previous value afterwards.
@inline def withNeedLift(needLift: Boolean)(f: => Tree): Tree = {
val saved = needTryLift
needTryLift = needLift
try f
finally needTryLift = saved
}
// Currently always false: the leading `false &&` disables the MSIL return-based lifting.
def shouldBeLiftedAnyway(tree: Tree) = false &&
forMSIL && lookForReturns.found(tree)
// Moves `tree` into a fresh local method `liftedTree<n>` owned by currentOwner and returns the
// typed block `{ def liftedTree<n>() = tree; liftedTree<n>() }`.
def liftTree(tree: Tree) = {
if (settings.debug.value)
log("lifting tree at: " + (tree.pos))
val sym = currentOwner.newMethod(tree.pos, unit.freshTermName("liftedTree"))
sym.setInfo(MethodType(List(), tree.tpe))
new ChangeOwnerTraverser(currentOwner, sym).traverse(tree)
localTyper.typedPos(tree.pos)(Block(
List(DefDef(sym, List(Nil), tree)),
Apply(Ident(sym), Nil)
))
}
// Runs `f` with this.inConstructorFlag set to the given value, restoring it afterwards.
def withInConstructorFlag(inConstructorFlag: Long)(f: => Tree): Tree = {
val saved = this.inConstructorFlag
this.inConstructorFlag = inConstructorFlag
try f
finally this.inConstructorFlag = saved
}
// Cases (when not elided):
//  - DefDef: records @varargs repeated params; class constructors transform their pre-super
//    fields and super call with INCONSTRUCTOR set (rhs is assumed to be a Block);
//  - ValDef: try-lifting is enabled when the owner is not a source method or the var is captured;
//  - UnApply: the extractor is transformed outside pattern mode, its args inside it;
//  - Apply: arguments are adapted by transformArgs (varargs packing, by-name wrapping);
//  - Assign to a field, a non-local/lazy/accessor lhs: try-lifting enabled;
//  - Try: lifted into its own method when needTryLift is set;
//  - CaseDef: patterns in pattern mode; Function: expanded by transformFunction;
//  - Template: inConstructorFlag reset to 0;
//  - otherwise: references to by-name params get `.apply()` (unless recorded in noApply).
if (isElidable(tree)) elideIntoUnit(tree)
else tree match {
case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
if (dd.symbol hasAnnotation VarargsClass) saveRepeatedParams(dd)
withNeedLift(false) {
if (tree.symbol.isClassConstructor) {
atOwner(tree.symbol) {
val rhs1 = (rhs: @unchecked) match {
case Block(stats, expr) =>
def transformInConstructor(stat: Tree) =
withInConstructorFlag(INCONSTRUCTOR) { transform(stat) }
val presupers = treeInfo.preSuperFields(stats) map transformInConstructor
val rest = stats drop presupers.length
// The first statement after the pre-super fields is treated as the super call.
val supercalls = rest take 1 map transformInConstructor
val others = rest drop 1 map transform
treeCopy.Block(rhs, presupers ::: supercalls ::: others, transform(expr))
}
treeCopy.DefDef(
tree, mods, name, transformTypeDefs(tparams),
transformValDefss(vparamss), transform(tpt), rhs1)
}
} else {
super.transform(tree)
}
}
case ValDef(_, _, _, rhs) =>
val sym = tree.symbol
if (!sym.owner.isSourceMethod || (sym.isVariable && freeMutableVars(sym)))
withNeedLift(true) { super.transform(tree) }
else
super.transform(tree)
// Note: inPattern is left true after this case.
case UnApply(fn, args) =>
inPattern = false
val fn1 = transform(fn)
inPattern = true
val args1 = transformTrees(fn.symbol.name match {
case nme.unapply => args
case nme.unapplySeq => transformArgs(tree.pos, fn.symbol, args, analyzer.unapplyTypeListFromReturnTypeSeq(fn.tpe))
case _ => sys.error("internal error: UnApply node has wrong symbol")
})
treeCopy.UnApply(tree, fn1, args1)
case Apply(fn, args) =>
if (fn.symbol == Object_synchronized && shouldBeLiftedAnyway(args.head))
transform(treeCopy.Apply(tree, fn, List(liftTree(args.head))))
else
withNeedLift(true) {
val formals = fn.tpe.paramTypes
treeCopy.Apply(tree, transform(fn), transformTrees(transformArgs(tree.pos, fn.symbol, args, formals)))
}
case Assign(Select(_, _), _) =>
withNeedLift(true) { super.transform(tree) }
case Assign(lhs, _) if lhs.symbol.owner != currentMethod || lhs.symbol.hasFlag(LAZY | ACCESSOR) =>
withNeedLift(true) { super.transform(tree) }
case Try(block, catches, finalizer) =>
if (needTryLift || shouldBeLiftedAnyway(tree)) transform(liftTree(tree))
else super.transform(tree)
case CaseDef(pat, guard, body) =>
inPattern = true
val pat1 = transform(pat)
inPattern = false
treeCopy.CaseDef(tree, pat1, transform(guard), transform(body))
// The generated class/instance tree is itself transformed again.
case fun @ Function(_, _) =>
mainTransform(transformFunction(fun))
case Template(_, _, _) =>
withInConstructorFlag(0) { super.transform(tree) }
case _ =>
val tree1 = super.transform(tree)
if (isByNameRef(tree1)) {
val tree2 = tree1 setType functionType(Nil, tree1.tpe)
// `return` leaves mainTransform directly, so the trailing setType is skipped for this result.
return {
if (noApply contains tree2) tree2
else localTyper.typedPos(tree1.pos)(Apply(Select(tree2, nme.apply), Nil))
}
}
tree1
}
} setType {
assert(tree.tpe != null, tree + " tpe is null")
uncurryTreeType(tree.tpe)
}
/** Second half of the per-tree transform, run at the next phase on mainTransform's result:
 *   - templates get the queued `newMembers` (varargs forwarders) prepended;
 *   - method definitions get their parameter lists flattened, their body wrapped for non-local
 *     returns when needed, and Java varargs forwarders when annotated with @varargs;
 *   - `try` blocks whose catches are not all simple catch cases are rewritten into one catch-all
 *     case that matches on the caught exception and rethrows when nothing else applies;
 *   - nested applications `f(a)(b)` are flattened into `f(a, b)`;
 *   - references to parameterless methods become explicit `()` calls;
 *   - non-local returns become throws of NonLocalReturnControl;
 *   - type trees are replaced by plain TypeTrees.
 */
def postTransform(tree: Tree): Tree = atPhase(phase.next) {
  // Turns a reference to a non-polymorphic method into an explicit empty-argument call, first
  // forcing the reference's type to a MethodType if needed; type trees become TypeTrees.
  def applyUnary(): Tree = {
    def needsParens = tree.symbol.isMethod && !tree.tpe.isInstanceOf[PolyType]
    def repair = {
      if (!tree.tpe.isInstanceOf[MethodType])
        tree.tpe = MethodType(Nil, tree.tpe.resultType)
      atPos(tree.pos)(Apply(tree, Nil) setType tree.tpe.resultType)
    }
    if (needsParens) repair
    else if (tree.isType) TypeTree(tree.tpe) setPos tree.pos
    else tree
  }
  tree match {
    case Template(parents, self, body) =>
      localTyper = typer.atOwner(tree, currentClass)
      // Prepend the members queued while transforming this template's methods, then reset the queue.
      val tmpl = treeCopy.Template(tree, parents, self, transformTrees(newMembers.toList) ::: body)
      newMembers.clear()
      tmpl
    case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
      // Methods that are the target of a non-local return get their body wrapped in the key/try/catch.
      val rhs1 = nonLocalReturnKeys.get(tree.symbol) match {
        case None => rhs
        case Some(k) => atPos(rhs.pos)(nonLocalReturnTry(rhs, k, tree.symbol))
      }
      val flatdd = treeCopy.DefDef(tree, mods, name, tparams, List(vparamss.flatten), tpt, rhs1)
      if (dd.symbol hasAnnotation VarargsClass) addJavaVarargsForwarders(dd, flatdd, tree)
      flatdd
    case Try(body, catches, finalizer) =>
      if (catches forall treeInfo.isCatchCase) tree
      else {
        val exname = unit.freshTermName("ex$")
        // Keep the cases as they are when they already end in a default / Throwable catch-all;
        // otherwise append `case _ => throw ex$`.
        val cases =
          if ((catches exists treeInfo.isDefaultCase) || (catches.last match {
            case CaseDef(Typed(Ident(nme.WILDCARD), tpt), EmptyTree, _) if (tpt.tpe =:= ThrowableClass.tpe) =>
              true
            case CaseDef(Bind(_, Typed(Ident(nme.WILDCARD), tpt)), EmptyTree, _) if (tpt.tpe =:= ThrowableClass.tpe) =>
              true
            case _ =>
              false
          })) catches
          else catches :+ CaseDef(Ident(nme.WILDCARD), EmptyTree, Throw(Ident(exname)))
        val catchall =
          atPos(tree.pos) {
            CaseDef(
              Bind(exname, Ident(nme.WILDCARD)),
              EmptyTree,
              Match(Ident(exname), cases))
          }
        if (settings.debug.value) log("rewrote try: " + catches + " ==> " + catchall);
        val catches1 = localTyper.typedCases(
          tree, List(catchall), ThrowableClass.tpe, WildcardType)
        treeCopy.Try(tree, body, catches1, finalizer)
      }
    case Apply(Apply(fn, args), args1) =>
      treeCopy.Apply(tree, fn, args ::: args1)
    case Ident(name) =>
      assert(name != tpnme.WILDCARD_STAR)
      applyUnary()
    case Select(_, _) | TypeApply(_, _) =>
      applyUnary()
    case Return(expr) if (tree.symbol != currentOwner.enclMethod || currentOwner.isLazy) =>
      if (settings.debug.value) log("non local return in "+tree.symbol+" from "+currentOwner.enclMethod)
      atPos(tree.pos)(nonLocalReturnThrow(expr, tree.symbol))
    case TypeTree() =>
      tree
    case _ =>
      if (tree.isType) TypeTree(tree.tpe) setPos tree.pos else tree
  }
}
/** Records the repeated parameters of a `@varargs`-annotated method in `repeatedParams`, so a
 *  Java varargs forwarder can be generated for it later. Reports an error instead when the
 *  method is a constructor or has no repeated parameter.
 */
private def saveRepeatedParams(dd: DefDef): Unit = {
  val meth = dd.symbol
  if (meth.isConstructor) {
    unit.error(meth.pos, "A constructor cannot be annotated with a `varargs` annotation.")
  } else {
    val reps = treeInfo.repeatedParams(dd)
    if (reps.isEmpty)
      unit.error(meth.pos, "A method without repeated parameters cannot be annotated with the `varargs` annotation.")
    else
      repeatedParams(meth) = reps
  }
}
// For a @varargs method whose repeated parameters were recorded by saveRepeatedParams, creates a
// forwarder with the same name in which those (now Seq-typed) parameters take arrays instead
// (Array[Object] when the element type is a type parameter or skolem). The forwarder wraps each
// array argument with gen.mkWrapArray, casts it to the Seq type, and calls the flattened method
// `flatdd`. It is entered into currentClass and queued in newMembers, unless a different existing
// member with a matching type already exists, in which case an error is reported instead.
private def addJavaVarargsForwarders(dd: DefDef, flatdd: DefDef, tree: Tree): Unit = {
if (!repeatedParams.contains(dd.symbol))
return
// NOTE(review): toSeqType is not referenced anywhere in this method.
def toSeqType(tp: Type): Type = {
val arg = elementType(ArrayClass, tp)
seqType(arg)
}
def toArrayType(tp: Type): Type = {
val arg = elementType(SeqClass, tp)
arrayType(
if (arg.typeSymbol.isTypeParameterOrSkolem) ObjectClass.tpe
else arg
)
}
val reps = repeatedParams(dd.symbol)
val rpsymbols = reps.map(_.symbol).toSet
val theTyper = typer.atOwner(tree, currentClass)
val flatparams = flatdd.vparamss.head
// Forwarder parameter types: arrays for the repeated parameters, unchanged types otherwise.
val forwformals = flatparams map {
case p if rpsymbols(p.symbol) => toArrayType(p.symbol.tpe)
case p => p.symbol.tpe
}
val forwresult = dd.symbol.tpe.finalResultType
val forwformsyms = (forwformals, flatparams).zipped map ((tp, oldparam) =>
currentClass.newValueParameter(oldparam.symbol.pos, oldparam.name).setInfo(tp)
)
def mono = MethodType(forwformsyms, forwresult)
// Keeps the type parameters of a polymorphic method.
val forwtype = dd.symbol.tpe match {
case MethodType(_, _) => mono
case PolyType(tps, _) => PolyType(tps, mono)
}
val forwsym = (
currentClass.newMethod(dd.pos, dd.name)
. setFlag (VARARGS | SYNTHETIC | flatdd.symbol.flags)
. setInfo (forwtype)
)
val forwtree = theTyper.typedPos(dd.pos) {
// For each repeated parameter, the wrapped-and-cast array argument; null marks a plain parameter.
val locals = (forwsym ARGS, flatparams).zipped map {
case (_, fp) if !rpsymbols(fp.symbol) => null
case (argsym, fp) =>
Block(Nil,
gen.mkCast(
gen.mkWrapArray(Ident(argsym), elementType(ArrayClass, argsym.tpe)),
seqType(elementType(SeqClass, fp.symbol.tpe))
)
)
}
// Plain parameters are passed through by reference.
val seqargs = (locals, forwsym ARGS).zipped map {
case (null, argsym) => Ident(argsym)
case (l, _) => l
}
val end = if (forwsym.isConstructor) List(UNIT) else Nil
DEF(forwsym) === BLOCK(
Apply(gen.mkAttributedRef(flatdd.symbol), seqargs) :: end : _*
)
}
// Refuse to add the forwarder if it would clash with an existing member of the same signature.
currentClass.info.member(forwsym.name).alternatives.find(s => s != forwsym && s.tpe.matches(forwsym.tpe)) match {
case Some(s) => unit.error(dd.symbol.pos,
"A method with a varargs annotation produces a forwarder method with the same signature "
+ s.tpe + " as an existing method.")
case None =>
currentClass.info.decls enter forwsym
newMembers += forwtree
}
}
}
// Local `var`s (owned by a method) that are referenced from another method, or from inside a
// function literal body or a by-name argument. Recomputed per unit in transformUnit; mainTransform
// transforms the definitions of these vars with try-lifting enabled.
private val freeMutableVars: mutable.Set[Symbol] = new mutable.HashSet

/** Populates `freeMutableVars` while traversing a compilation unit. */
private val freeLocalsTraverser = new Traverser {
  // The method whose body is currently being traversed.
  var currentMethod: Symbol = NoSymbol
  // True inside contexts whose code may run outside the current frame (closures, by-name args).
  var maybeEscaping = false

  def withEscaping(body: => Unit) {
    val saved = maybeEscaping
    maybeEscaping = true
    try body
    finally maybeEscaping = saved
  }

  override def traverse(tree: Tree) = tree match {
    case DefDef(_, _, _, _, _, _) =>
      val lastMethod = currentMethod
      currentMethod = tree.symbol
      super.traverse(tree)
      currentMethod = lastMethod
    case Apply(fn, args) if fn.symbol.paramss.nonEmpty =>
      traverse(fn)
      val params = fn.symbol.paramss.head
      (params, args).zipped foreach { (param, arg) =>
        if (param.tpe != null && isByNameParamType(param.tpe))
          withEscaping(traverse(arg))
        else
          traverse(arg)
      }
      // `zipped` stops at the shorter list, so arguments beyond the parameter count (e.g. the
      // extra arguments passed to a repeated parameter) would otherwise never be visited, and
      // vars captured inside them would be missed.
      args.drop(params.length).foreach(arg => traverse(arg))
    case Function(vparams, body) =>
      vparams foreach traverse
      withEscaping(traverse(body))
    case Ident(_) =>
      val sym = tree.symbol
      if (sym.isVariable && sym.owner.isMethod && (maybeEscaping || sym.owner != currentMethod))
        freeMutableVars += sym
    case _ =>
      super.traverse(tree)
  }
}
}