/*                     __                                               *\
**     ________ ___   / /  ___     Scala API                            **
**    / __/ __// _ | / /  / _ |    (c) 2003-2013, LAMP/EPFL             **
**  __\ \/ /__/ __ |/ /__/ __ |    http://scala-lang.org/               **
** /____/\___/_/ |_/____/_/ | |                                         **
**                          |/                                          **
\*                                                                      */

package scala.collection.parallel.mutable



import scala.collection.generic._
import scala.collection.mutable.FlatHashTable
import scala.collection.parallel.Combiner
import scala.collection.mutable.UnrolledBuffer
import scala.collection.parallel.Task



/** A parallel hash set.
 *
 *  `ParHashSet` is a parallel set which internally keeps elements within a hash table.
 *  It uses linear probing to resolve collisions.
 *
 *  @tparam T        type of the elements in the $coll.
 *
 *  @define Coll `ParHashSet`
 *  @define coll parallel hash set
 *
 *  @author Aleksandar Prokopec
 *  @see  [[http://docs.scala-lang.org/overviews/parallel-collections/concrete-parallel-collections.html#parallel_hash_tables Scala's Parallel Collections Library overview]]
 *  section on Parallel Hash Tables for more information.
 */
@SerialVersionUID(1L)
class ParHashSet[T] private[collection] (contents: FlatHashTable.Contents[T])
extends ParSet[T]
   with GenericParTemplate[T, ParHashSet]
   with ParSetLike[T, ParHashSet[T], scala.collection.mutable.HashSet[T]]
   with ParFlatHashTable[T]
   with Serializable
{
  // Adopt the supplied table contents; the no-argument constructor passes `null` here.
  initWithContents(contents)

  /** Creates a $coll without handing over any existing table contents. */
  def this() = this(null)

  override def companion = ParHashSet

  /** A fresh, empty $coll. */
  override def empty: ParHashSet[T] = new ParHashSet[T]

  /** Iterating sequentially simply reuses the splitter. */
  override def iterator = splitter

  /** The number of elements, as tracked by the underlying flat hash table. */
  override def size: Int = tableSize

  /** Removes all elements by clearing the underlying table. */
  def clear(): Unit = clearTable()

  /** A sequential `HashSet` constructed from this set's hash table contents. */
  override def seq: scala.collection.mutable.HashSet[T] =
    new scala.collection.mutable.HashSet(hashTableContents)

  /** Adds `elem` to this set and returns the set itself. */
  def +=(elem: T): this.type = {
    addEntry(elem)
    this
  }

  /** Removes `elem` from this set and returns the set itself. */
  def -=(elem: T): this.type = {
    removeEntry(elem)
    this
  }

  override def stringPrefix: String = "ParHashSet"

  /** Tests whether `elem` is stored in the underlying table. */
  def contains(elem: T): Boolean = containsEntry(elem)

  /** A splitter spanning the whole table, i.e. positions `0` until `table.length`. */
  def splitter = new ParHashSetIterator(0, table.length, size)

  /** Splitter over the table range `[start, iteratesUntil)` that is expected to
   *  yield `totalElements` elements; sub-splitters are of the same class.
   */
  class ParHashSetIterator(start: Int, iteratesUntil: Int, totalElements: Int)
  extends ParFlatHashTableIterator(start, iteratesUntil, totalElements) {
    def newIterator(start: Int, until: Int, total: Int) = new ParHashSetIterator(start, until, total)
  }

  // Java serialization hook: delegates to the hash table's own serialization routine.
  private def writeObject(s: java.io.ObjectOutputStream): Unit = {
    serializeTo(s)
  }

  // Java serialization hook: rebuilds the table from the stream, passing the identity function to `init`.
  private def readObject(in: java.io.ObjectInputStream): Unit = {
    init(in, x => x)
  }

  import scala.collection.DebugUtils._

  /** A multi-line dump of the element count, the table and the size map, for debugging. */
  override def debugInformation: String = buildString {
    append =>
    append("Parallel flat hash table set")
    append("No. elems: " + tableSize)
    append("Table length: " + table.length)
    append("Table: ")
    append(arrayString(table, 0, table.length))
    append("Sizemap: ")
    append(arrayString(sizemap, 0, sizemap.length))
  }

}


/** $factoryInfo
 *  @define Coll `mutable.ParHashSet`
 *  @define coll parallel hash set
 */
object ParHashSet extends ParSetFactory[ParHashSet] {

  /** The implicit `CanCombineFrom` instance for this collection, backed by `GenericCanCombineFrom`. */
  implicit def canBuildFrom[T]: CanCombineFrom[Coll, T, ParHashSet[T]] = {
    new GenericCanCombineFrom[T]
  }

  /** A builder for this collection is just one of its combiners. */
  override def newBuilder[T]: Combiner[T, ParHashSet[T]] = newCombiner[T]

  /** Obtains a new combiner from the `ParHashSetCombiner` companion. */
  override def newCombiner[T]: Combiner[T, ParHashSet[T]] = {
    ParHashSetCombiner.apply[T]
  }
}


/** Combiner that builds a `ParHashSet`.
 *
 *  Added elements are first distributed into `ParHashSetCombiner.numblocks` buckets
 *  according to the most significant bits of their improved hash code. `result` then
 *  moves them into a flat hash table, either sequentially or with parallel
 *  `FillBlocks` tasks, depending on how many elements were added.
 *
 *  @param tableLoadFactor  load factor handed to the table that is built in parallel
 */
private[mutable] abstract class ParHashSetCombiner[T](private val tableLoadFactor: Int)
extends scala.collection.parallel.BucketCombiner[T, ParHashSet[T], Any, ParHashSetCombiner[T]](ParHashSetCombiner.numblocks)
with scala.collection.mutable.FlatHashTable.HashUtils[T] {
//self: EnvironmentPassingCombiner[T, ParHashSet[T]] =>
  // Shifting an improved hash code right by this many bits leaves only the bucket index.
  private val nonmasklen = ParHashSetCombiner.nonmasklength
  // Seed used both for bucketing here and for the hash tables built in `result`.
  private val seedvalue = 27

  /** Appends `elem` to the bucket selected by the top bits of its improved hash code.
   *
   *  The element counter is incremented on every call, whether or not an equal
   *  element was added before.
   */
  def +=(elem: T): this.type = {
    sz += 1
    val hc = improve(elemHashCode(elem), seedvalue)
    val pos = hc >>> nonmasklen
    if (buckets(pos) eq null) {
      // initialize bucket
      buckets(pos) = new UnrolledBuffer[Any]
    }
    // add to bucket
    buckets(pos) += elem
    this
  }

  /** Builds the resulting set; the table is populated in parallel only when at least
   *  `numblocks * sizeMapBucketSize` elements have been added.
   */
  def result: ParHashSet[T] = {
    val contents = if (size >= ParHashSetCombiner.numblocks * sizeMapBucketSize) parPopulate else seqPopulate
    new ParHashSet(contents)
  }

  private def parPopulate: FlatHashTable.Contents[T] = {
    // construct it in parallel
    val table = new AddingFlatHashTable(size, tableLoadFactor, seedvalue)
    val (inserted, leftovers) = combinerTaskSupport.executeAndWaitResult(new FillBlocks(buckets, table, 0, buckets.length))
    // whatever did not fit before the end of the table is inserted starting from position 0
    var leftinserts = 0
    for (elem <- leftovers) leftinserts += table.insertEntry(0, table.tableLength, elem.asInstanceOf[T])
    table.setSize(leftinserts + inserted)
    table.hashTableContents
  }

  private def seqPopulate: FlatHashTable.Contents[T] = {
    // construct it sequentially
    // TODO parallelize by keeping separate size maps and merging them
    val tbl = new FlatHashTable[T] {
      sizeMapInit(table.length)
      // the seed must be in place before any entry is added
      seedvalue = ParHashSetCombiner.this.seedvalue
      for {
        buffer <- buckets
        if buffer ne null
        elem <- buffer
      } addEntry(elem.asInstanceOf[T])
    }
    tbl.hashTableContents
  }

  /* classes */

  /** A flat hash table which doesn't resize itself. It accepts the number of elements
   *  it has to take and allocates the underlying hash table in advance.
   *  Elements can only be added to it. The final size has to be adjusted manually.
   *  It is internal to `ParHashSet` combiners.
   */
  class AddingFlatHashTable(numelems: Int, lf: Int, inseedvalue: Int) extends FlatHashTable[T] {
    _loadFactor = lf
    table = new Array[AnyRef](capacity(FlatHashTable.sizeForThreshold(numelems, _loadFactor)))
    tableSize = 0
    threshold = FlatHashTable.newThreshold(_loadFactor, table.length)
    seedvalue = inseedvalue
    sizeMapInit(table.length)

    override def toString = "AFHT(%s)".format(table.length)

    /** Length of the underlying array. */
    def tableLength = table.length

    /** Sets the element count; `insertEntry` deliberately does not maintain it. */
    def setSize(sz: Int) = tableSize = sz

    /**
     *  The elements are added using the `insertEntry` method. This method accepts three
     *  arguments:
     *
     *  @param insertAt      where to add the element (set to -1 to use its hashcode)
     *  @param comesBefore   the position before which the element should be added to
     *  @param elem          the element to be added
     *
     *  If the element is to be inserted at the position corresponding to its hash code,
     *  the table will try to add the element in such a position if possible. Collisions are resolved
     *  using linear probing, so the element may actually have to be added to a position
     *  that follows the specified one. In the case that the first unoccupied position
     *  comes after `comesBefore`, the element is not added and the method simply returns -1,
     *  indicating that it couldn't add the element in a position that comes before the
     *  specified one.
     *  If the element is already present in the hash table, it is not added, and this method
     *  returns 0. If the element is added, it returns 1.
     */
    def insertEntry(insertAt: Int, comesBefore: Int, elem: T): Int = {
      var h = insertAt
      if (h == -1) h = index(elemHashCode(elem))
      var entry = table(h)
      while (null != entry) {
        if (entry == elem) return 0
        h = h + 1 // we *do not* do `(h + 1) % table.length` here, because we'll never overflow!!
        if (h >= comesBefore) return -1
        entry = table(h)
      }
      table(h) = elem.asInstanceOf[AnyRef]

      // `tableSize` is intentionally not incremented here: the size is set once at the
      // end through `setSize`, a shared counter would not be correct with several
      // workers, and maintaining it bogs down the parallel execution.

      nnSizeMapAdd(h)
      1
    }
  }

  /* tasks */

  /** Task that moves the buckets `offset` until `offset + howmany` into `table`.
   *
   *  Every bucket corresponds to one contiguous block of the table. The task result
   *  is the number of elements actually inserted, paired with the elements that could
   *  not be placed before the end of the task's last block (the leftovers).
   */
  class FillBlocks(buckets: Array[UnrolledBuffer[Any]], table: AddingFlatHashTable, val offset: Int, val howmany: Int)
  extends Task[(Int, UnrolledBuffer[Any]), FillBlocks] {
    var result = (Int.MinValue, new UnrolledBuffer[Any])

    /** Fills this task's blocks one by one, handing the leftovers of each block to the next. */
    def leaf(prev: Option[(Int, UnrolledBuffer[Any])]): Unit = {
      var i = offset
      var totalinserts = 0
      var leftover = new UnrolledBuffer[Any]()
      while (i < (offset + howmany)) {
        val (inserted, intonextblock) = fillBlock(i, buckets(i), leftover)
        totalinserts += inserted
        leftover = intonextblock
        i += 1
      }
      result = (totalinserts, leftover)
    }

    // number of table slots belonging to a single bucket
    private val blocksize = table.tableLength >> ParHashSetCombiner.discriminantbits
    private def blockStart(block: Int) = block * blocksize
    private def nextBlockStart(block: Int) = (block + 1) * blocksize

    /** Inserts the bucket `elems` (which may be `null`) and then the incoming `leftovers`
     *  into the given block; returns the insert count and the elements that spilled over.
     */
    private def fillBlock(block: Int, elems: UnrolledBuffer[Any], leftovers: UnrolledBuffer[Any]): (Int, UnrolledBuffer[Any]) = {
      val beforePos = nextBlockStart(block)

      // store the elems
      val (elemsIn, elemsLeft) = if (elems != null) insertAll(-1, beforePos, elems) else (0, UnrolledBuffer[Any]())

      // store the leftovers
      val (leftoversIn, leftoversLeft) = insertAll(blockStart(block), beforePos, leftovers)

      // return the no. of stored elements tupled with leftovers
      (elemsIn + leftoversIn, elemsLeft concat leftoversLeft)
    }

    /** Tries to insert every element of `elems` starting at `atPos` (or at its hash
     *  position when `atPos` is -1) and strictly before `beforePos`.
     *
     *  @return the number of newly inserted elements and the buffer of elements
     *          for which `insertEntry` reported that there was no room
     */
    private def insertAll(atPos: Int, beforePos: Int, elems: UnrolledBuffer[Any]): (Int, UnrolledBuffer[Any]) = {
      val leftovers = new UnrolledBuffer[Any]
      var inserted = 0

      // Walk the unrolled chunks directly; the original code notes that traversing
      // through `elems.iterator` is slower.
      var unrolled = elems.headPtr
      while (unrolled ne null) {
        val chunkarr = unrolled.array
        val chunksz = unrolled.size
        var i = 0
        while (i < chunksz) {
          val elem = chunkarr(i)
          val res = table.insertEntry(atPos, beforePos, elem.asInstanceOf[T])
          if (res >= 0) inserted += res
          else leftovers += elem
          i += 1
        }
        unrolled = unrolled.next
      }

      (inserted, leftovers)
    }

    /** Splits the bucket range into two halves that share the same table. */
    def split = {
      val fp = howmany / 2
      List(new FillBlocks(buckets, table, offset, fp), new FillBlocks(buckets, table, offset + fp, howmany - fp))
    }

    override def merge(that: FillBlocks): Unit = {
      // take the leftovers from the left task, store them into the block of the right task
      val atPos = blockStart(that.offset)
      val beforePos = blockStart(that.offset + that.howmany)
      val (inserted, remainingLeftovers) = insertAll(atPos, beforePos, this.result._2)

      // anything left after trying to store the left leftovers is added to the right task leftovers
      // and a new leftovers set is produced in this way
      // the total number of successfully inserted elements is adjusted accordingly
      result = (this.result._1 + that.result._1 + inserted, remainingLeftovers concat that.result._2)
    }

    /** Keeps splitting while the task covers more buckets than the threshold derived
     *  from this combiner's bucket count and the parallelism level.
     */
    def shouldSplitFurther: Boolean =
      howmany > scala.collection.parallel.thresholdFromSize(ParHashSetCombiner.numblocks, combinerTaskSupport.parallelismLevel)
  }

}


private[parallel] object ParHashSetCombiner {
  // Number of bits used to discriminate between buckets.
  private[mutable] val discriminantbits = 5
  // One bucket per possible value of the discriminant bits.
  private[mutable] val numblocks = 1 << discriminantbits
  // Bit mask with the `discriminantbits` lowest bits set.
  private[mutable] val discriminantmask = numblocks - 1
  // Number of remaining bits of a 32-bit hash code once the discriminant bits are taken away.
  private[mutable] val nonmasklength = 32 - discriminantbits

  /** Instantiates a combiner configured with `FlatHashTable.defaultLoadFactor`. */
  def apply[T] =
    new ParHashSetCombiner[T](FlatHashTable.defaultLoadFactor) {} //with EnvironmentPassingCombiner[T, ParHashSet[T]]
}