package scala.collection
package immutable
import HashMap.{ HashTrieMap, HashMapCollision1, HashMap1 }
import HashSet.{ HashTrieSet, HashSetCollision1, HashSet1 }
import annotation.unchecked.{ uncheckedVariance => uV }
import scala.annotation.tailrec
/** Iterates over the elements stored in a hash trie (`HashMap` or `HashSet`).
 *
 *  The trie is walked depth-first without recursion. `arrayD` is the node array
 *  currently being traversed and `posD` the index of its next entry to visit;
 *  `arrayStack` / `posStack` hold the arrays and positions of the ancestor levels
 *  that still have to be resumed. Single-element leaves (`HashMap1` / `HashSet1`)
 *  are turned into elements by the abstract `getElem`; any other non-trie node is
 *  drained through the nested iterator `subIter`.
 *
 *  @param elems child array of the root node to iterate over
 *  @tparam T    element type yielded by this iterator
 */
private[collection] abstract class TrieIterator[+T](elems: Array[Iterable[T]]) extends Iterator[T] {
  outer =>

  /** Converts a single-element leaf node (`HashMap1` / `HashSet1`) into the element to yield. */
  private[immutable] def getElem(x: AnyRef): T

  // Initial traversal state. These are overridable so that `DupIterator` can start
  // from a snapshot of another iterator's state instead of from the root.
  def initDepth = 0
  // NOTE(review): the stacks have a fixed capacity of 6 ancestor levels and nothing here
  // checks for overflow; presumably this matches the maximum trie depth — confirm.
  def initArrayStack: Array[Array[Iterable[T @uV]]] = new Array[Array[Iterable[T]]](6)
  def initPosStack = new Array[Int](6)
  def initArrayD: Array[Iterable[T @uV]] = elems
  def initPosD = 0
  def initSubIter: Iterator[T] = null

  // Number of ancestor levels saved in `arrayStack` / `posStack`; -1 once every level
  // has been consumed (after that only `subIter` may still hold elements).
  private[this] var depth = initDepth
  private[this] var arrayStack: Array[Array[Iterable[T @uV]]] = initArrayStack
  private[this] var posStack = initPosStack
  // Node array of the current level and the index of its next entry to visit.
  private[this] var arrayD: Array[Iterable[T @uV]] = initArrayD
  private[this] var posD = initPosD
  // Iterator over a node being drained element by element; reset to null as soon as
  // it runs out, so a non-null value means elements remain.
  private[this] var subIter = initSubIter

  /** Returns the child array of a trie node; throws `MatchError` for any other node. */
  private[this] def getElems(x: Iterable[T]): Array[Iterable[T]] = (x match {
    case x: HashTrieMap[_, _] => x.elems
    case x: HashTrieSet[_]    => x.elems
  }).asInstanceOf[Array[Iterable[T]]]

  /** Wraps every element of a collision node in a one-element `HashMap` / `HashSet`;
   *  throws `MatchError` for any other node.
   */
  private[this] def collisionToArray(x: Iterable[T]): Array[Iterable[T]] = (x match {
    case x: HashMapCollision1[_, _] => x.kvs.map(x => HashMap(x)).toArray
    case x: HashSetCollision1[_]    => x.ks.map(x => HashSet(x)).toArray
  }).asInstanceOf[Array[Iterable[T]]]

  /** `((first iterator, number of elements it yields), second iterator)`. */
  private type SplitIterators = ((Iterator[T], Int), Iterator[T])

  private def isTrie(x: AnyRef) = x match {
    case _: HashTrieMap[_, _] | _: HashTrieSet[_] => true
    case _                                        => false
  }

  private def isContainer(x: AnyRef) = x match {
    case _: HashMap1[_, _] | _: HashSet1[_] => true
    case _                                  => false
  }

  /** Detaches the partially consumed `subIter` (if any) for a duplicate iterator.
   *
   *  An `Iterator` cannot be consumed by two owners, so its remaining elements are
   *  buffered; this iterator and the duplicate each get their own iterator over the buffer.
   */
  private[this] def dupSubIter(): Iterator[T] =
    if (subIter eq null) null
    else {
      val buff = subIter.toBuffer
      subIter = buff.iterator
      buff.iterator
    }

  /** An iterator that resumes from `outer`'s current position.
   *
   *  The mutable traversal stacks are copied and any pending `subIter` elements are
   *  buffered, so the duplicate and `outer` can afterwards be advanced independently.
   *  Sharing them would let one iterator's pushes and pops corrupt the other's state.
   *  The node arrays themselves are only read, never written, so they can be shared.
   */
  final class DupIterator(xs: Array[Iterable[T]]) extends {
    override val initDepth                                     = outer.depth
    override val initArrayStack: Array[Array[Iterable[T @uV]]] = outer.arrayStack.clone()
    override val initPosStack                                  = outer.posStack.clone()
    override val initArrayD: Array[Iterable[T @uV]]            = outer.arrayD
    override val initPosD                                      = outer.posD
    override val initSubIter                                   = outer.dupSubIter()
  } with TrieIterator[T](xs) {
    final override def getElem(x: AnyRef): T = outer.getElem(x)
  }

  /** Returns a copy of this iterator positioned at its current element that can be
   *  advanced independently of this one.
   */
  def dupIterator: TrieIterator[T] = new DupIterator(elems)

  /** Creates an iterator over `xs` that converts leaves with this iterator's `getElem`. */
  private[this] def newIterator(xs: Array[Iterable[T]]) = new TrieIterator(xs) {
    final override def getElem(x: AnyRef): T = outer.getElem(x)
  }

  /** Iterator over `arr` paired with the total number of elements it yields. */
  private[this] def iteratorWithSize(arr: Array[Iterable[T]]): (Iterator[T], Int) =
    (newIterator(arr), arr.map(_.size).sum)

  /** Splits `arr` in half; the second half becomes the sized (first) result component. */
  private[this] def arrayToIterators(arr: Array[Iterable[T]]): SplitIterators = {
    val (fst, snd) = arr.splitAt(arr.length / 2)
    (iteratorWithSize(snd), newIterator(fst))
  }

  /** Splits a node array none of whose entries has been visited yet.
   *
   *  A single-entry array is unwrapped first: a collision node is expanded into
   *  its elements, and any other node is replaced by its children via `getElems`.
   *  NOTE(review): a lone single-element leaf would make `getElems` throw `MatchError`.
   */
  private[this] def splitArray(ad: Array[Iterable[T]]): SplitIterators =
    if (ad.length > 1) arrayToIterators(ad)
    else ad(0) match {
      case _: HashMapCollision1[_, _] | _: HashSetCollision1[_] =>
        arrayToIterators(collisionToArray(ad(0)))
      case _ =>
        splitArray(getElems(ad(0)))
    }

  def hasNext = (subIter ne null) || depth >= 0

  def next(): T =
    if (subIter ne null) {
      val el = subIter.next()
      if (!subIter.hasNext) subIter = null
      el
    } else next0(arrayD, posD)

  /** Visits entry `i` of `elems` and returns the next element, descending into
   *  sub-tries as needed. Called with the current level (`arrayD`, `posD`).
   */
  @tailrec private[this] def next0(elems: Array[Iterable[T]], i: Int): T = {
    // Move the cursor past entry `i` before visiting it. If `i` is the level's last
    // entry, the level is popped right away, so the stack never holds exhausted levels.
    if (i == elems.length - 1) {
      depth -= 1
      if (depth >= 0) {
        arrayD = arrayStack(depth)
        posD = posStack(depth)
        arrayStack(depth) = null
      } else {
        arrayD = null
        posD = 0
      }
    } else
      posD += 1

    val m = elems(i)
    if (isContainer(m))
      getElem(m)
    else if (isTrie(m)) {
      // Save the (already advanced) current level, then descend into the sub-trie.
      if (depth >= 0) {
        arrayStack(depth) = arrayD
        posStack(depth) = posD
      }
      depth += 1
      arrayD = getElems(m)
      posD = 0
      next0(arrayD, 0)
    } else {
      // Any other node (presumably a collision node) is drained via a nested iterator.
      subIter = m.iterator
      next()
    }
  }

  /** Splits the remaining elements between two iterators.
   *
   *  @return `((first, firstSize), second)`, where `firstSize` is the number of elements
   *          `first` yields. `second` is either a fresh iterator or `this`; in the latter
   *          case this iterator is modified so it no longer yields `first`'s elements.
   *
   *  NOTE(review): appears to assume at least two elements remain; e.g. case 3a throws
   *  `MatchError` if the only remaining entry is a single-element leaf.
   */
  def split: SplitIterators =
    if (arrayD != null && depth == 0 && posD == 0) {
      // 0) Nothing visited yet: divide the top-level array.
      splitArray(arrayD)
    } else if (subIter ne null) {
      // 1) Part-way through a nested iterator: hand out its remaining elements,
      //    and this iterator continues with everything after it.
      val buff = subIter.toBuffer
      subIter = null
      ((buff.iterator, buff.length), this)
    } else if (depth > 0) {
      // 2) Steal unvisited entries from the outermost saved level.
      val top = arrayStack(0)
      if (posStack(0) == top.length - 1) {
        // 2a) Only its last entry is left: hand that entry out and drop the level.
        //     `this` stays non-empty because saved levels always have entries left.
        val snd = Array[Iterable[T]](top.last)
        val szsnd = snd(0).size
        depth -= 1
        System.arraycopy(arrayStack, 1, arrayStack, 0, arrayStack.length - 1)
        arrayStack(arrayStack.length - 1) = null
        System.arraycopy(posStack, 1, posStack, 0, posStack.length - 1)
        posStack(posStack.length - 1) = 0
        ((newIterator(snd), szsnd), this)
      } else {
        // 2b) Hand out the second half of its unvisited entries.
        val (fst, snd) = top.splitAt(top.length - (top.length - posStack(0) + 1) / 2)
        arrayStack(0) = fst
        (iteratorWithSize(snd), this)
      }
    } else if (posD == arrayD.length - 1) {
      // 3a) No saved levels and at the last entry of `arrayD`: split that entry's contents.
      val m = arrayD(posD)
      arrayToIterators(if (isTrie(m)) getElems(m) else collisionToArray(m))
    } else {
      // 3b) No saved levels: hand out the second half of `arrayD`'s unvisited entries.
      val (fst, snd) = arrayD.splitAt(arrayD.length - (arrayD.length - posD + 1) / 2)
      arrayD = fst
      (iteratorWithSize(snd), this)
    }
}