package scala.collection.parallel.mutable
import collection.generic._
import collection.mutable.HashSet
import collection.mutable.FlatHashTable
import collection.parallel.Combiner
import collection.mutable.UnrolledBuffer
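/** A mutable parallel hash set.
 *
 *  Elements are kept in a flat (open-addressing) hash table and collisions are
 *  resolved by linear probing, mirroring the sequential `mutable.HashSet`.
 *  Splitters partition the underlying table into ranges so that traversal and
 *  transformation can proceed in parallel.
 *
 *  A set can be built through the companion factory, for instance:
 *  {{{
 *  val ps = ParHashSet(1, 2, 3)
 *  ps += 4
 *  }}}
 *
 *  @tparam T    the element type of the set
 */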
@SerialVersionUID(1L)
class ParHashSet[T] private[collection] (contents: FlatHashTable.Contents[T])
extends ParSet[T]
   with GenericParTemplate[T, ParHashSet]
   with ParSetLike[T, ParHashSet[T], collection.mutable.HashSet[T]]
   with ParFlatHashTable[T]
   with Serializable
{
  initWithContents(contents)

  def this() = this(null)
  
  override def companion = ParHashSet
  
  override def empty = new ParHashSet
  
  override def iterator = splitter
  
  override def size = tableSize
  
  def clear() = clearTable()
  
  override def seq = new HashSet(hashTableContents)
  
  def +=(elem: T) = {
    addEntry(elem)
    this
  }
  
  def -=(elem: T) = {
    removeEntry(elem)
    this
  }
  
  override def stringPrefix = "ParHashSet"
  
  def contains(elem: T) = containsEntry(elem)
  
  def splitter = new ParHashSetIterator(0, table.length, size) with SCPI
  
  type SCPI = SignalContextPassingIterator[ParHashSetIterator]
  
  class ParHashSetIterator(start: Int, iteratesUntil: Int, totalElements: Int)
  extends ParFlatHashTableIterator(start, iteratesUntil, totalElements) with ParIterator {
  me: SCPI =>
    def newIterator(start: Int, until: Int, total: Int) = new ParHashSetIterator(start, until, total) with SCPI
  }
  
  private def writeObject(s: java.io.ObjectOutputStream) {
    serializeTo(s)
  }
  
  private def readObject(in: java.io.ObjectInputStream) {
    init(in, x => x)
  }
  
  import collection.DebugUtils._
  override def debugInformation = buildString {
    append =>
    append("Parallel flat hash table set")
    append("No. elems: " + tableSize)
    append("Table length: " + table.length)
    append("Table: ")
    append(arrayString(table, 0, table.length))
    append("Sizemap: ")
    append(arrayString(sizemap, 0, sizemap.length))
  }
  
}
object ParHashSet extends ParSetFactory[ParHashSet] {
  implicit def canBuildFrom[T]: CanCombineFrom[Coll, T, ParHashSet[T]] = new GenericCanCombineFrom[T]
  
  override def newBuilder[T]: Combiner[T, ParHashSet[T]] = newCombiner
  
  override def newCombiner[T]: Combiner[T, ParHashSet[T]] = ParHashSetCombiner.apply[T]
}
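/** A combiner for parallel hash sets.
 *
 *  Elements added to the combiner are partitioned into `numblocks` unrolled-buffer
 *  buckets according to the top `discriminantbits` bits of their improved hash code.
 *  When the result is requested, the buckets are written into a single flat hash
 *  table, in parallel for sufficiently large sets.
 */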
private[mutable] abstract class ParHashSetCombiner[T](private val tableLoadFactor: Int)
extends collection.parallel.BucketCombiner[T, ParHashSet[T], Any, ParHashSetCombiner[T]](ParHashSetCombiner.numblocks)
with collection.mutable.FlatHashTable.HashUtils[T] {
  import collection.parallel.tasksupport._
  private var mask = ParHashSetCombiner.discriminantmask
  private var nonmasklen = ParHashSetCombiner.nonmasklength
  
  def +=(elem: T) = {
    sz += 1
    val hc = improve(elemHashCode(elem))
    val pos = hc >>> nonmasklen
    if (buckets(pos) eq null) {
      // lazily create the bucket for this hash-code prefix
      buckets(pos) = new UnrolledBuffer[Any]
    }
    // append the element to its bucket
    buckets(pos) += elem
    this
  }
  
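  /** Populates a flat hash table with the collected elements and wraps it in a `ParHashSet`.
   *  Large sets are populated in parallel, small ones sequentially.
   */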
  def result: ParHashSet[T] = {
    val contents = if (size >= ParHashSetCombiner.numblocks * sizeMapBucketSize) parPopulate else seqPopulate
    new ParHashSet(contents)
  }
  
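  /** Fills a pre-sized table in parallel, then sequentially inserts any leftovers
   *  that did not fit into their designated blocks.
   */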
  private def parPopulate: FlatHashTable.Contents[T] = {
    val table = new AddingFlatHashTable(size, tableLoadFactor)
    val (inserted, leftovers) = executeAndWaitResult(new FillBlocks(buckets, table, 0, buckets.length))
    var leftinserts = 0
    for (elem <- leftovers) leftinserts += table.insertEntry(0, table.tableLength, elem.asInstanceOf[T])
    table.setSize(leftinserts + inserted)
    table.hashTableContents
  }
  
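  /** Builds the table sequentially by adding the bucketed elements one by one. */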
  private def seqPopulate: FlatHashTable.Contents[T] = {
    val tbl = new FlatHashTable[T] {
      sizeMapInit(table.length)
    }
    for {
      buffer <- buckets
      if buffer ne null
      elem <- buffer
    } tbl.addEntry(elem.asInstanceOf[T])
    tbl.hashTableContents
  }

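  /** A flat hash table that does not resize itself.
   *
   *  The table is allocated up front for `numelems` elements at load factor `lf`;
   *  entries can only be added, and the final size must be set explicitly via `setSize`.
   */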
  class AddingFlatHashTable(numelems: Int, lf: Int) extends FlatHashTable[T] {
    _loadFactor = lf
    table = new Array[AnyRef](capacity(FlatHashTable.sizeForThreshold(numelems, _loadFactor)))
    tableSize = 0
    threshold = FlatHashTable.newThreshold(_loadFactor, table.length)
    sizeMapInit(table.length)
    
    override def toString = "AFHT(%s)".format(table.length)
    
    def tableLength = table.length
    
    def setSize(sz: Int) = tableSize = sz
    
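    /** Attempts to insert `elem` by linear probing, starting at `insertAt`
     *  (or at the element's hash position if `insertAt` is -1) and probing no
     *  further than `comesBefore`.
     *
     *  @return 1 if the element was inserted, 0 if it was already present,
     *          and -1 if probing reached `comesBefore` (the element must be
     *          retried later as a leftover).
     */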
    def insertEntry(insertAt: Int, comesBefore: Int, elem: T): Int = {
      var h = insertAt
      if (h == -1) h = index(elemHashCode(elem))
      var entry = table(h)
      while (null != entry) {
        if (entry == elem) return 0
        h = h + 1
        if (h >= comesBefore) return -1
        entry = table(h)
      }
      table(h) = elem.asInstanceOf[AnyRef]
      // the shared table size is NOT incremented here: blocks are filled by several
      // tasks at once, so the combiner sets the final size in a single call to setSize
      nnSizeMapAdd(h)
      1
    }
  }

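  /** A task that copies the elements of a range of buckets into the corresponding
   *  blocks of the shared `AddingFlatHashTable`.
   *
   *  Each task reports the number of elements it inserted together with the elements
   *  that probed past its last block; those leftovers are reinserted when sibling
   *  tasks are merged, or sequentially at the very end.
   */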
  class FillBlocks(buckets: Array[UnrolledBuffer[Any]], table: AddingFlatHashTable, val offset: Int, val howmany: Int)
  extends Task[(Int, UnrolledBuffer[Any]), FillBlocks] {
    var result = (Int.MinValue, new UnrolledBuffer[Any])
    def leaf(prev: Option[(Int, UnrolledBuffer[Any])]) {
      var i = offset
      var totalinserts = 0
      var leftover = new UnrolledBuffer[Any]()
      while (i < (offset + howmany)) {
        val (inserted, intonextblock) = fillBlock(i, buckets(i), leftover)
        totalinserts += inserted
        leftover = intonextblock
        i += 1
      }
      result = (totalinserts, leftover)
    }
    private val blocksize = table.tableLength >> ParHashSetCombiner.discriminantbits
    private def blockStart(block: Int) = block * blocksize
    private def nextBlockStart(block: Int) = (block + 1) * blocksize
    private def fillBlock(block: Int, elems: UnrolledBuffer[Any], leftovers: UnrolledBuffer[Any]): (Int, UnrolledBuffer[Any]) = {
      val beforePos = nextBlockStart(block)

      // store this block's own bucket, hashing each element to its probe position
      val (elemsIn, elemsLeft) = if (elems != null) insertAll(-1, beforePos, elems) else (0, UnrolledBuffer[Any]())

      // then retry the leftovers carried over from the previous block, starting at the block boundary
      val (leftoversIn, leftoversLeft) = insertAll(blockStart(block), beforePos, leftovers)

      // number of stored elements, together with everything that spilled past this block
      (elemsIn + leftoversIn, elemsLeft concat leftoversLeft)
    }
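    /** Inserts the elements of `elems` into the table between `atPos` and `beforePos`,
     *  returning the number of successful insertions and the buffer of elements that
     *  could not be placed within the bound.
     */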
    private def insertAll(atPos: Int, beforePos: Int, elems: UnrolledBuffer[Any]): (Int, UnrolledBuffer[Any]) = {
      var leftovers = new UnrolledBuffer[Any]
      var inserted = 0
      // walk the unrolled buffer chunk by chunk, avoiding the iterator overhead
      var unrolled = elems.headPtr
      var i = 0
      val t = table
      while (unrolled ne null) {
        val chunkarr = unrolled.array
        val chunksz = unrolled.size
        while (i < chunksz) {
          val elem = chunkarr(i)
          val res = t.insertEntry(atPos, beforePos, elem.asInstanceOf[T])
          if (res >= 0) inserted += res
          else leftovers += elem
          i += 1
        }
        i = 0
        unrolled = unrolled.next
      }

      (inserted, leftovers)
    }
    def split = {
      val fp = howmany / 2
      List(new FillBlocks(buckets, table, offset, fp), new FillBlocks(buckets, table, offset + fp, howmany - fp))
    }
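    /** Merges the partial results of two adjacent tasks. */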
    override def merge(that: FillBlocks) {
      // try to store this task's leftovers in the blocks owned by the task on the right
      val atPos = blockStart(that.offset)
      val beforePos = blockStart(that.offset + that.howmany)
      val (inserted, remainingLeftovers) = insertAll(atPos, beforePos, this.result._2)

      // the combined insert count gains whatever was successfully placed; anything still
      // unplaced joins the right task's leftovers to form the merged leftover buffer
      result = (this.result._1 + that.result._1 + inserted, remainingLeftovers concat that.result._2)
    }
    def shouldSplitFurther = howmany > collection.parallel.thresholdFromSize(ParHashSetCombiner.numblocks, parallelismLevel)
  }
  
}
private[parallel] object ParHashSetCombiner {
  private[mutable] val discriminantbits = 5
  private[mutable] val numblocks = 1 << discriminantbits
  private[mutable] val discriminantmask = (1 << discriminantbits) - 1
  private[mutable] val nonmasklength = 32 - discriminantbits
  
  def apply[T] = new ParHashSetCombiner[T](FlatHashTable.defaultLoadFactor) {}
}