package scala.tools.nsc
package transform
import scala.collection.{ mutable, immutable }
import scala.collection.mutable.ListBuffer
import symtab.Flags._
import util.TreeSet
/** The `constructors` phase.
 *
 *  For every class that is neither an interface nor a value class, the template is
 *  rewritten so that:
 *  - field initializers and other free-standing template statements become statements
 *    of the primary constructor;
 *  - field definitions lose their right-hand sides;
 *  - private-local parameter accessors and (in final classes) outer accessors are
 *    dropped when nothing selects them.
 *
 *  It also packages post-super-call constructor code of `DelayedInit` subclasses into a
 *  closure class. In addition, it guards or merges constructor statements between a
 *  generic class and its specialized subclasses.
 */
abstract class Constructors extends Transform with ast.TreeDSL {
  import global._
  import definitions._

  /** The name of this phase. */
  val phaseName: String = "constructors"

  /** Creates the per-compilation-unit transformer of this phase. */
  protected def newTransformer(unit: CompilationUnit): Transformer =
    new ConstructorTransformer(unit)

  /** Constructor statements of a generic class that were wrapped in a specialization
   *  guard by `guardSpecializedInitializer`, keyed by that class. Specialized classes
   *  look up their generic class here to merge in its initialization code.
   */
  private val guardedCtorStats: mutable.Map[Symbol, List[Tree]] = new mutable.HashMap[Symbol, List[Tree]]

  /** Primary-constructor parameters of each class recorded in `guardedCtorStats`. */
  private val ctorParams: mutable.Map[Symbol, List[Symbol]] = new mutable.HashMap[Symbol, List[Symbol]]

  /** Rewrites every eligible `ClassDef` of one compilation unit. */
  class ConstructorTransformer(unit: CompilationUnit) extends Transformer {

    /** Builds the new template for the class owning `impl`.
     *
     *  @param impl the class template; its body must contain a primary constructor
     *              with exactly one parameter list and a `Block` body
     *  @return the template containing the rewritten primary constructor, the
     *          remaining definitions and the auxiliary constructors
     */
    def transformClassTemplate(impl: Template): Template = {
      val clazz = impl.symbol.owner  // the class being transformed
      val stats = impl.body
      val localTyper = typer.atOwner(impl, clazz)

      // The SPECIALIZED_INSTANCE member of this class, or NoSymbol. A class that declares
      // it but is not itself SPECIALIZED gets its constructor code guarded (see
      // guardSpecializedInitializer).
      val specializedFlag: Symbol = clazz.info.decl(nme.SPECIALIZED_INSTANCE)
      val shouldGuard = (specializedFlag != NoSymbol) && !clazz.hasFlag(SPECIALIZED)

      case class ConstrInfo(
        constr: DefDef,              // The primary constructor
        constrParams: List[Symbol],  // ... and its parameters
        constrBody: Block            // ... and its body
      )

      // Locate the primary constructor. The pattern below throws a MatchError unless it
      // has exactly one parameter list and a Block body.
      val constrInfo: ConstrInfo = {
        val primary = stats find (_.symbol.isPrimaryConstructor)
        assert(primary.isDefined, "no constructor in template: impl = " + impl)
        val ddef @ DefDef(_, _, _, List(vparams), _, rhs @ Block(_, _)) = primary.get
        ConstrInfo(ddef, vparams map (_.symbol), rhs)
      }
      import constrInfo._

      // The parameter accessor fields of the class.
      val paramAccessors = clazz.constrParamAccessors

      /** The constructor parameter corresponding to the parameter accessor `acc`. */
      def parameter(acc: Symbol): Symbol =
        parameterNamed(nme.getterName(acc.originalName))

      /** The first constructor parameter whose name is `name` or starts with `name + "$"`.
       *  Fails an assertion if there is no such parameter.
       */
      def parameterNamed(name: Name): Symbol = {
        def matchesName(param: Symbol) = param.name == name || param.name.startsWith(name + "$")

        (constrParams filter matchesName) match {
          case Nil => assert(false, name + " not in " + constrParams) ; null
          case p :: _ => p
        }
      }

      // Set by intoConstructorTransformer when a statement moved into the constructor
      // selects a symbol for which specializeTypes.specializedTypeVars is non-empty.
      var usesSpecializedField: Boolean = false

      /** Rewrites statements that are moved into the constructor body.
       *  - `this.p` or `this.p()`, for a parameter accessor `p` of this class whose type
       *    has no specialized type variables, becomes a reference to the matching
       *    constructor parameter. This is skipped for DelayedInit subclasses, getters of
       *    variables and setters.
       *  - `this.f()`, where `f.outerSource` is this class and the class is not an impl
       *    class, becomes a reference to the constructor parameter named `nme.OUTER`.
       *  - Any other `Select` with specialized type variables sets `usesSpecializedField`.
       */
      val intoConstructorTransformer = new Transformer {
        def isParamRef(sym: Symbol) =
          sym.isParamAccessor &&
          sym.owner == clazz &&
          !(clazz isSubClass DelayedInitClass) &&
          !(sym.isGetter && sym.accessed.isVariable) &&
          !sym.isSetter
        private def possiblySpecialized(s: Symbol) = specializeTypes.specializedTypeVars(s).nonEmpty
        override def transform(tree: Tree): Tree = tree match {
          case Apply(Select(This(_), _), List()) =>
            if (isParamRef(tree.symbol) && !possiblySpecialized(tree.symbol))
              gen.mkAttributedIdent(parameter(tree.symbol.accessed)) setPos tree.pos
            else if (tree.symbol.outerSource == clazz && !clazz.isImplClass)
              gen.mkAttributedIdent(parameterNamed(nme.OUTER)) setPos tree.pos
            else
              super.transform(tree)
          case Select(This(_), _) if (isParamRef(tree.symbol) && !possiblySpecialized(tree.symbol)) =>
            gen.mkAttributedIdent(parameter(tree.symbol)) setPos tree.pos
          case Select(_, _) =>
            if (specializeTypes.specializedTypeVars(tree.symbol).nonEmpty)
              usesSpecializedField = true
            super.transform(tree)
          case _ =>
            super.transform(tree)
        }
      }

      /** Re-owns `tree` from `oldowner` to the primary constructor, then rewrites it
       *  with `intoConstructorTransformer`.
       */
      def intoConstructor(oldowner: Symbol, tree: Tree) =
        intoConstructorTransformer.transform(
          new ChangeOwnerTraverser(oldowner, constr.symbol)(tree))

      /** True for a `ValDef` whose modifiers match the PRESUPER | PARAMACCESSOR mask
       *  (via `hasFlag`). Such initializers go into `constrPrefixBuf`, which is emitted
       *  ahead of the super call.
       */
      def canBeMoved(tree: Tree) = tree match {
        case ValDef(mods, _, _, _) => (mods hasFlag PRESUPER | PARAMACCESSOR)
        case _ => false
      }

      /** A typed assignment `this.to = from`, positioned at `to`. */
      def mkAssign(to: Symbol, from: Tree): Tree =
        localTyper.typedPos(to.pos) { Assign(Select(This(clazz), to), from) }

      /** A typed assignment of the constructor parameter `from` to the field `to`.
       *  For the outer parameter (`nme.OUTER`), the assignment is preceded by a null
       *  check that throws a NullPointerException.
       */
      def copyParam(to: Symbol, from: Symbol): Tree = {
        import CODE._
        val result = mkAssign(to, Ident(from))

        if (from.name != nme.OUTER) result
        else localTyper.typedPos(to.pos) {
          IF (from OBJ_EQ NULL) THEN THROW(NullPointerExceptionClass) ELSE result
        }
      }

      // Definitions of the new template (methods, emptied fields, nested classes).
      val defBuf = new ListBuffer[Tree]
      // Auxiliary constructors; appended after the primary constructor.
      val auxConstructorBuf = new ListBuffer[Tree]
      // Statements of the new primary-constructor body.
      val constrStatBuf = new ListBuffer[Tree]
      // Initializers selected by canBeMoved; emitted before the super call.
      val constrPrefixBuf = new ListBuffer[Tree]

      // Fields from treeInfo.preSuperFields (presumably early definitions - TODO confirm).
      val presupers = treeInfo.preSuperFields(stats)

      // Start from the existing constructor body. A PRESUPER local val in it is also
      // copied into its single matching field, unless that field has a constant type.
      for (stat <- constrBody.stats) {
        constrStatBuf += stat
        stat match {
          case ValDef(mods, name, _, _) if (mods hasFlag PRESUPER) =>
            val fields = presupers filter (
              vdef => nme.localToGetter(vdef.name) == name)
            assert(fields.length == 1, "expected exactly one pre-super field for " + name + ", found: " + fields)
            val to = fields.head.symbol
            if (!to.tpe.isInstanceOf[ConstantType])
              constrStatBuf += mkAssign(to, Ident(stat.symbol))
          case _ =>
        }
      }

      // Distribute the template statements:
      //  - a DefDef typed as a nullary method of constant type keeps its signature, but
      //    its body becomes the constant literal;
      //  - the primary constructor is skipped (rebuilt below), auxiliary constructors are
      //    collected separately, and other methods are kept;
      //  - a ValDef of non-constant type is kept with an empty rhs; a non-lazy initializer
      //    is moved into the constructor. ValDefs of constant type are dropped;
      //  - nested classes are transformed recursively;
      //  - any other statement is moved into the constructor.
      for (stat <- stats) stat match {
        case DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
          stat.symbol.tpe match {
            case MethodType(List(), tp @ ConstantType(c)) =>
              defBuf += treeCopy.DefDef(
                stat, mods, name, tparams, vparamss, tpt,
                Literal(c) setPos rhs.pos setType tp)
            case _ =>
              if (stat.symbol.isPrimaryConstructor) ()
              else if (stat.symbol.isConstructor) auxConstructorBuf += stat
              else defBuf += stat
          }
        case ValDef(mods, name, tpt, rhs) =>
          if (!stat.symbol.tpe.isInstanceOf[ConstantType]) {
            if (rhs != EmptyTree && !stat.symbol.isLazy) {
              val rhs1 = intoConstructor(stat.symbol, rhs);
              (if (canBeMoved(stat)) constrPrefixBuf else constrStatBuf) += mkAssign(
                stat.symbol, rhs1)
            }
            defBuf += treeCopy.ValDef(stat, mods, name, tpt, EmptyTree)
          }
        case ClassDef(_, _, _, _) =>
          defBuf += new ConstructorTransformer(unit).transform(stat)
        case _ =>
          constrStatBuf += intoConstructor(impl.symbol, stat)
      }

      // Omittable symbols that are selected somewhere and therefore must be kept.
      val accessedSyms = new TreeSet[Symbol]((x, y) => x isLess y)

      // Outer accessors of this (final) class with their bodies. The bodies are scanned
      // only if the accessor itself is kept.
      var outerAccessors: List[(Symbol, Tree)] = List()

      /** Members of this class that may be dropped when unused:
       *  - private-local parameter accessors;
       *  - outer accessors of a final class that override nothing, unless the class
       *    is a DelayedInit subclass.
       */
      def maybeOmittable(sym: Symbol) = sym.owner == clazz && (
        sym.isParamAccessor && sym.isPrivateLocal ||
        sym.isOuterAccessor && sym.owner.isFinal && sym.allOverriddenSymbols.isEmpty &&
        !(clazz isSubClass DelayedInitClass)
      )

      /** A symbol is kept unless it is omittable and has not been recorded as accessed. */
      def mustbeKept(sym: Symbol) = !maybeOmittable(sym) || (accessedSyms contains sym)

      // Records selected omittable symbols in accessedSyms. Outer accessors of this final
      // class are collected into outerAccessors without descending into their bodies.
      val accessTraverser = new Traverser {
        override def traverse(tree: Tree) = {
          tree match {
            case DefDef(_, _, _, _, _, body)
            if (tree.symbol.isOuterAccessor && tree.symbol.owner == clazz && clazz.isFinal) =>
              log("outerAccessors += " + tree.symbol.fullName)
              outerAccessors ::= (tree.symbol, body)
            case Select(_, _) =>
              if (!mustbeKept(tree.symbol)) {
                log("accessedSyms += " + tree.symbol.fullName)
                accessedSyms addEntry tree.symbol
              }
              super.traverse(tree)
            case _ =>
              super.traverse(tree)
          }
        }
      }

      // Scan the retained definitions and auxiliary constructors, then the bodies of the
      // outer accessors that turned out to be needed.
      for (stat <- defBuf.iterator ++ auxConstructorBuf.iterator)
        accessTraverser.traverse(stat)

      for ((accSym, accBody) <- outerAccessors)
        if (mustbeKept(accSym)) accessTraverser.traverse(accBody)

      // Names of non-private getters (excluding outer fields) inherited from trait
      // parents, each mapped to the parent that provides it.
      val parentSymbols = Map((for {
        p <- impl.parents
        if p.symbol.isTrait
        sym <- p.symbol.info.nonPrivateMembers
        if sym.isGetter && !sym.isOuterField
      } yield sym.name -> p): _*)

      // Field initializations for the parameter accessors that are kept. A kept accessor
      // whose name clashes with a getter from a trait parent is reported as an error.
      val paramInits =
        for (acc <- paramAccessors if mustbeKept(acc)) yield {
          if (parentSymbols contains acc.name)
            unit.error(acc.pos, "parameter '%s' requires field but conflicts with %s in '%s'".format(
              acc.name, acc.name, parentSymbols(acc.name)))
          copyParam(acc, parameter(acc))
        }

      /** Builds the post-super constructor statements of the specialized class `clazz`
       *  from `originalStats`, the guarded statements of `genericClazz`.
       *
       *  Each original statement is duplicated. An assignment `this.f = ...` is replaced
       *  by the assignment in `specializedStats` to the SPECIALIZED field whose
       *  unspecialized name matches `f`, if there is one. Every other statement is:
       *  - passed through specializeTypes.ImplementationAdapter;
       *  - rewritten by `rewriteArrayUpdate`;
       *  - retyped from `genericClazz` into `clazz` with specializeTypes.Duplicator.
       *
       *  Specialized statements that replaced nothing are printed to standard output.
       */
      def mergeConstructors(genericClazz: Symbol, originalStats: List[Tree], specializedStats: List[Tree]): List[Tree] = {
        val specBuf = new ListBuffer[Tree]
        specBuf ++= specializedStats

        /** The assignment in `specializedStats` to a SPECIALIZED field whose generic
         *  (unspecialized) getter name equals the getter name of `sym`, if any.
         */
        def specializedAssignFor(sym: Symbol): Option[Tree] =
          specializedStats.find {
            case Assign(sel @ Select(This(_), _), rhs) if sel.symbol.hasFlag(SPECIALIZED) =>
              val (generic, _, _) = nme.splitSpecializedName(nme.localToGetter(sel.symbol.name))
              generic == nme.localToGetter(sym.name)
            case _ => false
          }

        // ScalaRunTime's `array_update` member. Looked up at most once per merge, and only
        // when a candidate call is met, rather than once per rewritten statement.
        lazy val array_update = definitions.ScalaRunTimeModule.info.member("array_update")

        /** Rewrites calls `array_update(xs, idx, v)` into typed `xs.update(idx, v)` calls
         *  using `definitions.Array_update`.
         */
        def rewriteArrayUpdate(tree: Tree): Tree = {
          val adapter = new Transformer {
            override def transform(t: Tree): Tree = t match {
              case Apply(fun @ Select(receiver, method), List(xs, idx, v)) if fun.symbol == array_update =>
                localTyper.typed(Apply(gen.mkAttributedSelect(xs, definitions.Array_update), List(idx, v)))
              case _ => super.transform(t)
            }
          }
          adapter.transform(tree)
        }

        log("merging: " + originalStats.mkString("\n") + "\nwith\n" + specializedStats.mkString("\n"))
        val res = for (s <- originalStats; stat = s.duplicate) yield {
          log("merge: looking at " + stat)
          val stat1 = stat match {
            case Assign(sel @ Select(This(_), field), _) =>
              specializedAssignFor(sel.symbol).getOrElse(stat)
            case _ => stat
          }
          if (stat1 ne stat) {
            log("replaced " + stat + " with " + stat1)
            specBuf -= stat1
          }

          if (stat1 eq stat) {
            assert(ctorParams(genericClazz).length == constrParams.length)
            // The result of this application is discarded, so presumably the adapter
            // updates stat1 in place (generic -> specialized constructor params); TODO confirm.
            (new specializeTypes.ImplementationAdapter(ctorParams(genericClazz), constrParams, null, true))(stat1)

            val stat2 = rewriteArrayUpdate(stat1)
            if (settings.debug.value) log("retyping " + stat2)

            val d = new specializeTypes.Duplicator
            d.retyped(localTyper.context1.asInstanceOf[d.Context],
                      stat2,
                      genericClazz,
                      clazz,
                      Map.empty)
          } else
            stat1
        }
        // NOTE(review): this goes to stdout rather than through `log` or the reporter;
        // kept as-is to preserve the existing output.
        if (specBuf.nonEmpty)
          println("residual specialized constructor statements: " + specBuf)
        res
      }

      /** Guards or merges the constructor statements that follow the super call.
       *  - With specialization disabled (`settings.nospecialization`), `stats` is
       *    returned unchanged.
       *  - If this class should be guarded, the moved statements use a specialized field
       *    and `stats` is non-empty: the statements (and the constructor parameters) are
       *    recorded for specialized classes, and wrapped into
       *    `if (!specializedFlag()) { stats }`.
       *  - In a SPECIALIZED class: the recorded statements of the generic class (if any)
       *    are merged with `stats` via `mergeConstructors`.
       *  - Otherwise, `stats` is returned unchanged.
       */
      def guardSpecializedInitializer(stats: List[Tree]): List[Tree] = if (settings.nospecialization.value) stats else {
        if (usesSpecializedField && shouldGuard && stats.nonEmpty) {
          guardedCtorStats(clazz) = stats
          ctorParams(clazz) = constrParams

          val tree =
            If(
              Apply(
                Select(
                  Apply(gen.mkAttributedRef(specializedFlag), List()),
                  definitions.getMember(definitions.BooleanClass, nme.UNARY_!)),
                List()),
              Block(stats, Literal(())),
              EmptyTree)

          List(localTyper.typed(tree))
        } else if (clazz.hasFlag(SPECIALIZED)) {
          // The generic class is the sibling declaration named by the unspecialized name.
          val (genericName, _, _) = nme.splitSpecializedName(clazz.name)
          val genericClazz = clazz.owner.info.decl(genericName.toTypeName)
          assert(genericClazz != NoSymbol, "no generic class " + genericName + " found for specialized " + clazz)

          guardedCtorStats.get(genericClazz) match {
            case Some(stats1) => mergeConstructors(genericClazz, stats1, stats)
            case None => stats
          }
        } else stats
      }

      /** Creates a method `name` for `sym`, at `sym`'s position, and enters it into the
       *  class scope. The method gets `flags` without LOCAL and PRIVATE, with
       *  `privateWithin` set to `clazz`.
       */
      def addAccessor(sym: Symbol, name: TermName, flags: Long) = {
        val m = clazz.newMethod(sym.pos, name)
          .setFlag(flags & ~LOCAL & ~PRIVATE)
        m.privateWithin = clazz
        clazz.info.decls.enter(m)
        m
      }

      /** Adds a getter for the field `sym` whose body is `this.sym`. The typed DefDef
       *  is appended to `defBuf`.
       */
      def addGetter(sym: Symbol): Symbol = {
        val getr = addAccessor(
          sym, nme.getterName(sym.name), getterFlags(sym.flags))
        getr setInfo MethodType(List(), sym.tpe)
        defBuf += localTyper.typed {
          atPos(sym.pos) {
            DefDef(getr, Select(This(clazz), sym))
          }
        }
        getr
      }

      /** Marks the field `sym` MUTABLE and adds a Unit-returning setter that assigns its
       *  single parameter to `this.sym`. The typed DefDef is appended to `defBuf`.
       */
      def addSetter(sym: Symbol): Symbol = {
        sym setFlag MUTABLE
        val setr = addAccessor(
          sym, nme.getterToSetter(nme.getterName(sym.name)), setterFlags(sym.flags))
        setr setInfo MethodType(setr.newSyntheticValueParams(List(sym.tpe)), UnitClass.tpe)
        defBuf += localTyper.typed {
          atPos(sym.pos) {
            DefDef(setr, paramss =>
              Assign(Select(This(clazz), sym), Ident(paramss.head.head)))
          }
        }
        setr
      }

      /** For a private non-method member `sym` of `clazz`: evaluates `acc`, calls
       *  `makeNotPrivate` on the resulting accessor and returns it.
       *  Otherwise: calls `makeNotPrivate` on `sym` itself (when `clazz` owns it) and
       *  returns NoSymbol.
       */
      def ensureAccessor(sym: Symbol)(acc: => Symbol) =
        if (sym.owner == clazz && !sym.isMethod && sym.isPrivate) {
          val getr = acc
          getr makeNotPrivate clazz
          getr
        } else {
          if (sym.owner == clazz) sym makeNotPrivate clazz
          NoSymbol
        }

      /** The existing getter of `sym` in `clazz`, or a newly added one, if `ensureAccessor`
       *  requires an accessor; otherwise NoSymbol.
       */
      def ensureGetter(sym: Symbol): Symbol = ensureAccessor(sym) {
        val getr = sym.getter(clazz)
        if (getr != NoSymbol) getr else addGetter(sym)
      }

      /** The existing setter of `sym`, looked up under its plain and then its expanded
       *  name, or a newly added one, if `ensureAccessor` requires an accessor; otherwise
       *  NoSymbol.
       */
      def ensureSetter(sym: Symbol): Symbol = ensureAccessor(sym) {
        var setr = sym.setter(clazz, hasExpandedName = false)
        if (setr == NoSymbol) setr = sym.setter(clazz, hasExpandedName = true)
        if (setr == NoSymbol) setr = addSetter(sym)
        setr
      }

      /** Builds and types the closure class used for DelayedInit. It is a SYNTHETIC | FINAL
       *  class named `nme.delayedInitArg`, created in `clazz`, with parents
       *  AbstractFunction0 and ScalaObject.
       *
       *  Its constructor takes the `nme.OUTER` field (typed as `clazz`). Its `apply()`
       *  method runs `stats` and returns BoxedUnit.UNIT. Within `stats`:
       *  - `this` of `clazz` becomes the outer field;
       *  - selections and assignments of members for which `ensureGetter`/`ensureSetter`
       *    provide an accessor become accessor calls;
       *  - other trees are re-owned from `impl.symbol` to `apply`.
       */
      def delayedInitClosure(stats: List[Tree]) =
        localTyper.typed {
          atPos(impl.pos) {
            val closureClass = clazz.newClass(impl.pos, nme.delayedInitArg.toTypeName)
              .setFlag(SYNTHETIC | FINAL)
            val closureParents = List(AbstractFunctionClass(0).tpe, ScalaObjectClass.tpe)

            closureClass.setInfo(new ClassInfoType(closureParents, new Scope, closureClass))

            val outerField = closureClass.newValue(impl.pos, nme.OUTER)
              .setFlag(PRIVATE | LOCAL | PARAMACCESSOR)
              .setInfo(clazz.tpe)

            val applyMethod = closureClass.newMethod(impl.pos, nme.apply)
              .setFlag(FINAL)
              .setInfo(MethodType(List(), ObjectClass.tpe))

            closureClass.info.decls enter outerField
            closureClass.info.decls enter applyMethod

            val outerFieldDef = ValDef(outerField)

            val changeOwner = new ChangeOwnerTraverser(impl.symbol, applyMethod)

            val closureClassTyper = localTyper.atOwner(closureClass)
            val applyMethodTyper = closureClassTyper.atOwner(applyMethod)

            val constrStatTransformer = new Transformer {
              override def transform(tree: Tree): Tree = tree match {
                case This(_) if tree.symbol == clazz =>
                  applyMethodTyper.typed {
                    atPos(tree.pos) {
                      Select(This(closureClass), outerField)
                    }
                  }
                case _ =>
                  // Rewrite this node first, then recurse into the (possibly new) tree.
                  super.transform {
                    tree match {
                      case Select(qual, _) =>
                        val getter = ensureGetter(tree.symbol)
                        if (getter != NoSymbol)
                          applyMethodTyper.typed {
                            atPos(tree.pos) {
                              Apply(Select(qual, getter), List())
                            }
                          }
                        else tree
                      case Assign(lhs @ Select(qual, _), rhs) =>
                        val setter = ensureSetter(lhs.symbol)
                        if (setter != NoSymbol)
                          applyMethodTyper.typed {
                            atPos(tree.pos) {
                              Apply(Select(qual, setter), List(rhs))
                            }
                          }
                        else tree
                      case _ =>
                        changeOwner.changeOwner(tree)
                        tree
                    }
                  }
              }
            }

            def applyMethodStats = constrStatTransformer.transformTrees(stats)

            val applyMethodDef = DefDef(
              sym = applyMethod,
              vparamss = List(List()),
              rhs = Block(applyMethodStats, gen.mkAttributedRef(BoxedUnit_UNIT)))

            ClassDef(
              sym = closureClass,
              constrMods = Modifiers(0),
              vparamss = List(List(outerFieldDef)),
              argss = List(List()),
              body = List(applyMethodDef),
              superPos = impl.pos)
          }
        }

      /** A typed call of `delayedInitMethod` on `this`. Its argument is a new instance of
       *  the closure class, constructed with `this`.
       */
      def delayedInitCall(closure: Tree) =
        localTyper.typed {
          atPos(impl.pos) {
            Apply(
              Select(This(clazz), delayedInitMethod),
              List(New(TypeTree(closure.symbol.tpe), List(List(This(clazz))))))
          }
        }

      /** Splits `stats` after the first run of statements whose symbol is a constructor
       *  (presumably the super-constructor call). Returns (everything up to and including
       *  that run, the rest).
       */
      def splitAtSuper(stats: List[Tree]) = {
        def isConstr(tree: Tree) = (tree.symbol ne null) && tree.symbol.isConstructor
        val (pre, rest0) = stats span (!isConstr(_))
        val (supercalls, rest) = rest0 span (isConstr(_))
        (pre ::: supercalls, rest)
      }

      // Constructor statements up to and including the super call, and those after it.
      var (uptoSuperStats, remainingConstrStats) = splitAtSuper(constrStatBuf.toList)

      // In DelayedInit subclasses, the statements after the super call move into a closure
      // class (itself run through this transformer) and are replaced by one delayedInit call.
      val needsDelayedInit =
        (clazz isSubClass DelayedInitClass) && remainingConstrStats.nonEmpty

      if (needsDelayedInit) {
        val dicl = new ConstructorTransformer(unit) transform delayedInitClosure(remainingConstrStats)
        defBuf += dicl
        remainingConstrStats = List(delayedInitCall(dicl))
      }

      // The rebuilt primary constructor body, in order:
      //  1. parameter-to-field copies;
      //  2. the canBeMoved initializers;
      //  3. the statements up to the super call;
      //  4. the (possibly guarded or merged) remainder.
      defBuf += treeCopy.DefDef(
        constr, constr.mods, constr.name, constr.tparams, constr.vparamss, constr.tpt,
        treeCopy.Block(
          constrBody,
          paramInits ::: constrPrefixBuf.toList ::: uptoSuperStats :::
          guardSpecializedInitializer(remainingConstrStats),
          constrBody.expr));

      defBuf ++= auxConstructorBuf

      // Drop omittable, never-accessed members from the class scope and from the template.
      for (sym <- clazz.info.decls.toList)
        if (!mustbeKept(sym)) {
          clazz.info.decls unlink sym
        }

      treeCopy.Template(impl, impl.parents, impl.self,
        defBuf.toList filter (stat => mustbeKept(stat.symbol)))
    }

    /** Rewrites the template of every class that is neither an interface nor a class for
     *  which `isValueClass` holds. Other trees are traversed normally.
     */
    override def transform(tree: Tree): Tree =
      tree match {
        case ClassDef(mods, name, tparams, impl) if !tree.symbol.isInterface && !isValueClass(tree.symbol) =>
          treeCopy.ClassDef(tree, mods, name, tparams, transformClassTemplate(impl))
        case _ =>
          super.transform(tree)
      }
  }
}