package scala.collection.parallel.mutable
import collection.generic._
import collection.mutable.HashSet
import collection.mutable.FlatHashTable
import collection.parallel.Combiner
import collection.mutable.UnrolledBuffer
/** A mutable parallel set whose elements live in a flat hash table
 *  (see the `ParFlatHashTable` mixin).
 *
 *  The primary constructor hands `contents` to `initWithContents`; the
 *  no-argument constructor passes `null` (presumably yielding an empty
 *  table -- confirm against `initWithContents`).
 *
 *  @tparam T  the element type
 */
@SerialVersionUID(1L)
class ParHashSet[T] private[collection] (contents: FlatHashTable.Contents[T])
extends ParSet[T]
   with GenericParTemplate[T, ParHashSet]
   with ParSetLike[T, ParHashSet[T], collection.mutable.HashSet[T]]
   with ParFlatHashTable[T]
   with Serializable
{
  // Must run before any member touches the table.
  initWithContents(contents)

  /** Creates a set by passing `null` contents to the primary constructor. */
  def this() = this(null)

  override def companion = ParHashSet

  override def empty = new ParHashSet

  /** Iterating is the same as obtaining a fresh splitter. */
  override def iterator = splitter

  /** Number of elements, as tracked by the hash table's `tableSize`. */
  override def size = tableSize

  /** Removes all elements by delegating to `clearTable()`. */
  def clear() = clearTable()

  /** A sequential `HashSet` constructed from this set's `hashTableContents`. */
  override def seq = new HashSet(hashTableContents)

  /** Adds `elem` via `addEntry` and returns this set. */
  def +=(elem: T) = { addEntry(elem); this }

  /** Removes `elem` via `removeEntry` and returns this set. */
  def -=(elem: T) = { removeEntry(elem); this }

  override def stringPrefix = "ParHashSet"

  /** Tests membership by delegating to `containsEntry`. */
  def contains(elem: T) = containsEntry(elem)

  /** Shorthand for the signal-context-passing mixin applied to every iterator created here. */
  type SCPI = SignalContextPassingIterator[ParHashSetIterator]

  /** A splitter spanning table slots `0` until `table.length`, expecting `size` elements. */
  def splitter = new ParHashSetIterator(0, table.length, size) with SCPI

  /** Iterator over a slice of the table; `newIterator` produces sub-iterators of the same kind. */
  class ParHashSetIterator(start: Int, iteratesUntil: Int, totalElements: Int)
  extends ParFlatHashTableIterator(start, iteratesUntil, totalElements) with ParIterator {
  me: SCPI =>
    def newIterator(start: Int, until: Int, total: Int) =
      new ParHashSetIterator(start, until, total) with SCPI
  }

  /** Java serialization hook: writes the set by delegating to `serializeTo`. */
  private def writeObject(s: java.io.ObjectOutputStream): Unit = {
    serializeTo(s)
  }

  /** Java serialization hook: restores the set through `init`, passing the identity function. */
  private def readObject(in: java.io.ObjectInputStream): Unit = {
    init(in, x => x)
  }

  import collection.DebugUtils._

  /** A multi-line dump of the element count, table length, table slots and size map. */
  override def debugInformation = buildString { out =>
    out("Parallel flat hash table set")
    out("No. elems: " + tableSize)
    out("Table length: " + table.length)
    out("Table: ")
    out(arrayString(table, 0, table.length))
    out("Sizemap: ")
    out(arrayString(sizemap, 0, sizemap.length))
  }
}
/** Factory for `ParHashSet`: supplies the implicit `CanCombineFrom` and the
 *  combiner used to build new instances.
 */
object ParHashSet extends ParSetFactory[ParHashSet] {
  implicit def canBuildFrom[T]: CanCombineFrom[Coll, T, ParHashSet[T]] =
    new GenericCanCombineFrom[T]

  /** A builder is simply a combiner. */
  override def newBuilder[T]: Combiner[T, ParHashSet[T]] =
    newCombiner[T]

  /** A fresh `ParHashSetCombiner` obtained from its companion's `apply`. */
  override def newCombiner[T]: Combiner[T, ParHashSet[T]] =
    ParHashSetCombiner.apply[T]
}
/** Combiner for `ParHashSet`.
 *
 *  Elements are first collected into `ParHashSetCombiner.numblocks` buckets,
 *  selected by the topmost `discriminantbits` bits of their improved hash code.
 *  `result` then copies the buckets into a flat hash table, either in parallel
 *  (`parPopulate`) or sequentially (`seqPopulate`) depending on the element count.
 *
 *  @param tableLoadFactor  load factor given to the table built by `parPopulate`
 *  @tparam T               the element type
 */
private[mutable] abstract class ParHashSetCombiner[T](private val tableLoadFactor: Int)
extends collection.parallel.BucketCombiner[T, ParHashSet[T], Any, ParHashSetCombiner[T]](ParHashSetCombiner.numblocks)
   with collection.mutable.FlatHashTable.HashUtils[T] {
  import collection.parallel.tasksupport._

  // Not read anywhere in this class; kept as a private member.
  private val mask = ParHashSetCombiner.discriminantmask
  // Shift that leaves only the top `discriminantbits` bits of a hash code.
  private val nonmasklen = ParHashSetCombiner.nonmasklength

  /** Appends `elem` to the bucket chosen by the top bits of its improved hash
   *  code, creating the bucket on first use, and bumps the element counter.
   *  Duplicates are not filtered here; they are dropped when the table is built.
   */
  def +=(elem: T) = {
    sz += 1
    val hc = improve(elemHashCode(elem))
    val pos = hc >>> nonmasklen
    if (buckets(pos) eq null) {
      buckets(pos) = new UnrolledBuffer[Any]
    }
    buckets(pos) += elem
    this
  }

  /** Builds the resulting set. Uses the parallel path once `size` reaches
   *  `numblocks * sizeMapBucketSize`, and the sequential path otherwise.
   */
  def result: ParHashSet[T] = {
    val contents =
      if (size >= ParHashSetCombiner.numblocks * sizeMapBucketSize) parPopulate
      else seqPopulate
    new ParHashSet(contents)
  }

  /** Fills a pre-sized table with a `FillBlocks` task tree, then inserts the
   *  elements that task reported as leftovers, scanning from slot 0 to the end
   *  of the table, and finally records the total number of distinct insertions.
   */
  private def parPopulate: FlatHashTable.Contents[T] = {
    val table = new AddingFlatHashTable(size, tableLoadFactor)
    val (inserted, leftovers) = executeAndWaitResult(new FillBlocks(buckets, table, 0, buckets.length))
    var leftinserts = 0
    for (elem <- leftovers) leftinserts += table.insertEntry(0, table.tableLength, elem.asInstanceOf[T])
    table.setSize(leftinserts + inserted)
    table.hashTableContents
  }

  /** Adds every element of every non-null bucket to a fresh `FlatHashTable`, one by one. */
  private def seqPopulate: FlatHashTable.Contents[T] = {
    val tbl = new FlatHashTable[T] {
      sizeMapInit(table.length)
    }
    for {
      buffer <- buckets
      if buffer ne null
      elem <- buffer
    } tbl.addEntry(elem.asInstanceOf[T])
    tbl.hashTableContents
  }

  /** A flat hash table allocated up front for `numelems` elements at load
   *  factor `lf`. Its `tableSize` starts at 0 and is only updated through
   *  `setSize`; `insertEntry` does not touch it.
   */
  class AddingFlatHashTable(numelems: Int, lf: Int) extends FlatHashTable[T] {
    _loadFactor = lf
    table = new Array[AnyRef](capacity(FlatHashTable.sizeForThreshold(numelems, _loadFactor)))
    tableSize = 0
    threshold = FlatHashTable.newThreshold(_loadFactor, table.length)
    sizeMapInit(table.length)

    override def toString = "AFHT(%s)".format(table.length)

    def tableLength = table.length

    def setSize(sz: Int) = tableSize = sz

    /** Linearly probes for a free slot for `elem`.
     *
     *  @param insertAt     slot to start probing at, or `-1` to start at the
     *                      element's own hashed index
     *  @param comesBefore  exclusive upper bound of the probe
     *  @return `1` if the element was stored, `0` if an equal element was
     *          found first, `-1` if the probe reached `comesBefore` without
     *          finding a free slot
     */
    def insertEntry(insertAt: Int, comesBefore: Int, elem: T): Int = {
      var h = insertAt
      if (h == -1) h = index(elemHashCode(elem))
      var entry = table(h)
      while (null != entry) {
        if (entry == elem) return 0
        h = h + 1
        if (h >= comesBefore) return -1
        entry = table(h)
      }
      table(h) = elem.asInstanceOf[AnyRef]
      nnSizeMapAdd(h)
      1
    }
  }

  /** Task that copies buckets `offset` until `offset + howmany` into `table`.
   *
   *  The table is viewed as `numblocks` contiguous blocks of `blocksize` slots,
   *  one per bucket. The task result pairs the number of elements stored with
   *  the elements that ran off the end of the task's last block.
   */
  class FillBlocks(buckets: Array[UnrolledBuffer[Any]], table: AddingFlatHashTable, val offset: Int, val howmany: Int)
  extends Task[(Int, UnrolledBuffer[Any]), FillBlocks] {
    var result = (Int.MinValue, new UnrolledBuffer[Any])

    /** Fills this task's blocks left to right, carrying the overflow of each block into the next. */
    def leaf(prev: Option[(Int, UnrolledBuffer[Any])]): Unit = {
      var i = offset
      var totalinserts = 0
      var leftover = new UnrolledBuffer[Any]()
      while (i < (offset + howmany)) {
        val (inserted, intonextblock) = fillBlock(i, buckets(i), leftover)
        totalinserts += inserted
        leftover = intonextblock
        i += 1
      }
      result = (totalinserts, leftover)
    }

    private val blocksize = table.tableLength >> ParHashSetCombiner.discriminantbits
    private def blockStart(block: Int) = block * blocksize
    private def nextBlockStart(block: Int) = (block + 1) * blocksize

    /** Inserts the block's own bucket at the elements' hashed positions, then
     *  the carried-over `leftovers` starting from the block's first slot. Both
     *  probes stop at the start of the next block.
     *
     *  @return the number stored and the elements that did not fit
     */
    private def fillBlock(block: Int, elems: UnrolledBuffer[Any], leftovers: UnrolledBuffer[Any]): (Int, UnrolledBuffer[Any]) = {
      val beforePos = nextBlockStart(block)
      val (elemsIn, elemsLeft) =
        if (elems ne null) insertAll(-1, beforePos, elems)
        else (0, UnrolledBuffer[Any]())
      val (leftoversIn, leftoversLeft) = insertAll(blockStart(block), beforePos, leftovers)
      (elemsIn + leftoversIn, elemsLeft concat leftoversLeft)
    }

    /** Calls `table.insertEntry(atPos, beforePos, _)` for every element of
     *  `elems`, walking the buffer's chunks directly.
     *
     *  @return the number stored and a buffer of the elements for which
     *          `insertEntry` returned a negative value
     */
    private def insertAll(atPos: Int, beforePos: Int, elems: UnrolledBuffer[Any]): (Int, UnrolledBuffer[Any]) = {
      val leftovers = new UnrolledBuffer[Any]
      var inserted = 0
      var unrolled = elems.headPtr
      var i = 0
      while (unrolled ne null) {
        val chunkarr = unrolled.array
        val chunksz = unrolled.size
        while (i < chunksz) {
          val elem = chunkarr(i)
          val res = table.insertEntry(atPos, beforePos, elem.asInstanceOf[T])
          if (res >= 0) inserted += res
          else leftovers += elem
          i += 1
        }
        i = 0
        unrolled = unrolled.next
      }
      (inserted, leftovers)
    }

    /** Splits the bucket range into two halves over the same table. */
    def split = {
      val fp = howmany / 2
      List(new FillBlocks(buckets, table, offset, fp), new FillBlocks(buckets, table, offset + fp, howmany - fp))
    }

    /** Merges the result of `that`: this task's leftovers are inserted into
     *  the table range `that` covered, and whatever still does not fit is
     *  placed in front of `that`'s own leftovers.
     */
    override def merge(that: FillBlocks): Unit = {
      val atPos = blockStart(that.offset)
      val beforePos = blockStart(that.offset + that.howmany)
      val (inserted, remainingLeftovers) = insertAll(atPos, beforePos, this.result._2)
      result = (this.result._1 + that.result._1 + inserted, remainingLeftovers concat that.result._2)
    }

    // Uses this combiner's own block count. The original referenced
    // ParHashMapCombiner.numblocks, which looks like a copy-paste leftover
    // (presumably the same value -- confirm against ParHashMapCombiner).
    def shouldSplitFurther = howmany > collection.parallel.thresholdFromSize(ParHashSetCombiner.numblocks, parallelismLevel)
  }
}
/** Bucketing constants for `ParHashSetCombiner`, together with a factory method. */
private[parallel] object ParHashSetCombiner {
  /** How many leading hash bits choose a bucket. */
  private[mutable] val discriminantbits = 5

  /** Number of buckets: two to the power of `discriminantbits`. */
  private[mutable] val numblocks = 1 << discriminantbits

  /** Bit mask with the lowest `discriminantbits` bits set. */
  private[mutable] val discriminantmask = numblocks - 1

  /** Bits of a 32-bit hash code that remain below the discriminant bits. */
  private[mutable] val nonmasklength = 32 - discriminantbits

  /** A new combiner that uses `FlatHashTable.defaultLoadFactor`. */
  def apply[T] = new ParHashSetCombiner[T](FlatHashTable.defaultLoadFactor) {}
}