/*                     __                                               *\
**     ________ ___   / /  ___     Scala API                            **
**    / __/ __// _ | / /  / _ |    (c) 2003-2013, LAMP/EPFL             **
**  __\ \/ /__/ __ |/ /__/ __ |    http://scala-lang.org/               **
** /____/\___/_/ |_/____/_/ | |                                         **
**                          |/                                          **
\*                                                                      */

package scala.collection.parallel
package mutable

import scala.collection.generic._
import scala.collection.mutable.DefaultEntry
import scala.collection.mutable.HashEntry
import scala.collection.mutable.HashTable
import scala.collection.mutable.UnrolledBuffer
import scala.collection.parallel.Task

/** A parallel hash map.
 *
 *  `ParHashMap` is a parallel map which internally keeps elements within a hash table.
 *  It uses chaining to resolve collisions.
 *
 *  @tparam K        type of the keys in the parallel hash map
 *  @tparam V        type of the values in the parallel hash map
 *
 *  @define Coll `ParHashMap`
 *  @define coll parallel hash map
 *
 *  @author Aleksandar Prokopec
 *  @see  [[http://docs.scala-lang.org/overviews/parallel-collections/concrete-parallel-collections.html#parallel_hash_tables Scala's Parallel Collections Library overview]]
 *  section on Parallel Hash Tables for more information.
 */
@SerialVersionUID(1L)
class ParHashMap[K, V] private[collection] (contents: HashTable.Contents[K, DefaultEntry[K, V]])
extends ParMap[K, V]
   with GenericParMapTemplate[K, V, ParHashMap]
   with ParMapLike[K, V, ParHashMap[K, V], scala.collection.mutable.HashMap[K, V]]
   with ParHashTable[K, DefaultEntry[K, V]]
   with Serializable
{
self =>
  initWithContents(contents)

  type Entry = scala.collection.mutable.DefaultEntry[K, V]

  def this() = this(null)

  override def mapCompanion: GenericParMapCompanion[ParHashMap] = ParHashMap

  override def empty: ParHashMap[K, V] = new ParHashMap[K, V]

  protected[this] override def newCombiner = ParHashMapCombiner[K, V]

  override def seq = new scala.collection.mutable.HashMap[K, V](hashTableContents)

  def splitter = new ParHashMapIterator(1, table.length, size, table(0).asInstanceOf[DefaultEntry[K, V]])

  override def size = tableSize

  override def clear() = clearTable()

  def get(key: K): Option[V] = {
    val e = findEntry(key)
    if (e eq null) None
    else Some(e.value)
  }

  def put(key: K, value: V): Option[V] = {
    val e = findOrAddEntry(key, value)
    if (e eq null) None
    else { val v = e.value; e.value = value; Some(v) }
  }

  def update(key: K, value: V): Unit = put(key, value)

  def remove(key: K): Option[V] = {
    val e = removeEntry(key)
    if (e ne null) Some(e.value)
    else None
  }

  def += (kv: (K, V)): this.type = {
    val e = findOrAddEntry(kv._1, kv._2)
    if (e ne null) e.value = kv._2
    this
  }

  def -= (key: K): this.type = { removeEntry(key); this }

  override def stringPrefix = "ParHashMap"

  class ParHashMapIterator(start: Int, untilIdx: Int, totalSize: Int, e: DefaultEntry[K, V])
  extends EntryIterator[(K, V), ParHashMapIterator](start, untilIdx, totalSize, e) {
    def entry2item(entry: DefaultEntry[K, V]) = (entry.key, entry.value)

    def newIterator(idxFrom: Int, idxUntil: Int, totalSz: Int, es: DefaultEntry[K, V]) =
      new ParHashMapIterator(idxFrom, idxUntil, totalSz, es)
  }

  protected def createNewEntry[V1](key: K, value: V1): Entry = {
    new Entry(key, value.asInstanceOf[V])
  }

  private def writeObject(out: java.io.ObjectOutputStream) {
    serializeTo(out, { entry =>
      out.writeObject(entry.key)
      out.writeObject(entry.value)
    })
  }

  private def readObject(in: java.io.ObjectInputStream) {
    init(in, createNewEntry(in.readObject().asInstanceOf[K], in.readObject()))
  }
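  /** Consistency check used for debugging: compares the size map against the
   *  actual per-bucket entry counts, and verifies that every entry resides in
   *  the table slot its hash code maps to. Returns a description of each
   *  violated invariant found.
   */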
  private[parallel] override def brokenInvariants = {
    // bucket by bucket, count elements
    val buckets = for (i <- 0 until (table.length / sizeMapBucketSize)) yield checkBucket(i)

    // check if each element is in the position corresponding to its key
    val elems = for (i <- 0 until table.length) yield checkEntry(i)

    buckets.flatMap(x => x) ++ elems.flatMap(x => x)
  }

  private def checkBucket(i: Int) = {
    def count(e: HashEntry[K, DefaultEntry[K, V]]): Int = if (e eq null) 0 else 1 + count(e.next)
    val expected = sizemap(i)
    val found = ((i * sizeMapBucketSize) until ((i + 1) * sizeMapBucketSize)).foldLeft(0) {
      (acc, c) => acc + count(table(c))
    }
    if (found != expected) List("Found " + found + " elements, while sizemap showed " + expected)
    else Nil
  }

  private def checkEntry(i: Int) = {
    def check(e: HashEntry[K, DefaultEntry[K, V]]): List[String] =
      if (e eq null) Nil
      else if (index(elemHashCode(e.key)) == i) check(e.next)
      else ("Element " + e.key + " at " + i + " with " + elemHashCode(e.key) + " maps to " + index(elemHashCode(e.key))) :: check(e.next)
    check(table(i))
  }
}

/** $factoryInfo
 *  @define Coll `mutable.ParHashMap`
 *  @define coll parallel hash map
 */
object ParHashMap extends ParMapFactory[ParHashMap] {
  var iters = 0

  def empty[K, V]: ParHashMap[K, V] = new ParHashMap[K, V]

  def newCombiner[K, V]: Combiner[(K, V), ParHashMap[K, V]] = ParHashMapCombiner.apply[K, V]

  implicit def canBuildFrom[K, V]: CanCombineFrom[Coll, (K, V), ParHashMap[K, V]] = new CanCombineFromMap[K, V]
}

private[mutable] abstract class ParHashMapCombiner[K, V](private val tableLoadFactor: Int)
extends scala.collection.parallel.BucketCombiner[(K, V), ParHashMap[K, V], DefaultEntry[K, V], ParHashMapCombiner[K, V]](ParHashMapCombiner.numblocks)
   with scala.collection.mutable.HashTable.HashUtils[K]
{
  private var mask = ParHashMapCombiner.discriminantmask
  private var nonmasklen = ParHashMapCombiner.nonmasklength
  private var seedvalue = 27

  def +=(elem: (K, V)) = {
    sz += 1
    val hc = improve(elemHashCode(elem._1), seedvalue)
    val pos = (hc >>> nonmasklen)
    if (buckets(pos) eq null) {
      // initialize bucket
      buckets(pos) = new UnrolledBuffer[DefaultEntry[K, V]]()
    }
    // add to bucket
    buckets(pos) += new DefaultEntry(elem._1, elem._2)
    this
  }

  def result: ParHashMap[K, V] = if (size >= (ParHashMapCombiner.numblocks * sizeMapBucketSize)) { // 1024
    // construct table
    val table = new AddingHashTable(size, tableLoadFactor, seedvalue)
    val bucks = buckets.map(b => if (b ne null) b.headPtr else null)
    val insertcount = combinerTaskSupport.executeAndWaitResult(new FillBlocks(bucks, table, 0, bucks.length))
    table.setSize(insertcount)
    // TODO compare insertcount and size to see if compression is needed
    val c = table.hashTableContents
    new ParHashMap(c)
  } else {
    // construct a normal table and fill it sequentially
    // TODO parallelize by keeping separate sizemaps and merging them
    object table extends HashTable[K, DefaultEntry[K, V]] {
      type Entry = DefaultEntry[K, V]
      def insertEntry(e: Entry) { super.findOrAddEntry(e.key, e) }
      def createNewEntry[E](key: K, entry: E): Entry = entry.asInstanceOf[Entry]
      sizeMapInit(table.length)
    }
    var i = 0
    while (i < ParHashMapCombiner.numblocks) {
      if (buckets(i) ne null) {
        for (elem <- buckets(i)) table.insertEntry(elem)
      }
      i += 1
    }
    new ParHashMap(table.hashTableContents)
  }
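  // Note on block assignment: with discriminantbits = 5 there are 2^5 = 32
  // blocks, and an element whose improved hash code is `hc` is routed to
  // block `hc >>> nonmasklen` (nonmasklen = 32 - 5 = 27), i.e. its top five
  // bits select the bucket. For illustration, hc = 0xC0000000 has top bits
  // 11000, so it lands in block 24.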
It does not update the number of elements * in the table. */ private[ParHashMapCombiner] class AddingHashTable(numelems: Int, lf: Int, _seedvalue: Int) extends HashTable[K, DefaultEntry[K, V]] { import HashTable._ _loadFactor = lf table = new Array[HashEntry[K, DefaultEntry[K, V]]](capacity(sizeForThreshold(_loadFactor, numelems))) tableSize = 0 seedvalue = _seedvalue threshold = newThreshold(_loadFactor, table.length) sizeMapInit(table.length) def setSize(sz: Int) = tableSize = sz def insertEntry(/*block: Int, */e: DefaultEntry[K, V]) = { var h = index(elemHashCode(e.key)) // assertCorrectBlock(h, block) var olde = table(h).asInstanceOf[DefaultEntry[K, V]] // check if key already exists var ce = olde while (ce ne null) { if (ce.key == e.key) { h = -1 ce = null } else ce = ce.next } // if key does not already exist if (h != -1) { e.next = olde table(h) = e nnSizeMapAdd(h) true } else false } private def assertCorrectBlock(h: Int, block: Int) { val blocksize = table.length / (1 << ParHashMapCombiner.discriminantbits) if (!(h >= block * blocksize && h < (block + 1) * blocksize)) { println("trying to put " + h + " into block no.: " + block + ", range: [" + block * blocksize + ", " + (block + 1) * blocksize + ">") assert(h >= block * blocksize && h < (block + 1) * blocksize) } } protected def createNewEntry[X](key: K, x: X) = ??? } /* tasks */ import UnrolledBuffer.Unrolled class FillBlocks(buckets: Array[Unrolled[DefaultEntry[K, V]]], table: AddingHashTable, offset: Int, howmany: Int) extends Task[Int, FillBlocks] { var result = Int.MinValue def leaf(prev: Option[Int]) = { var i = offset val until = offset + howmany result = 0 while (i < until) { result += fillBlock(i, buckets(i)) i += 1 } } private def fillBlock(block: Int, elems: Unrolled[DefaultEntry[K, V]]) = { var insertcount = 0 var unrolled = elems var i = 0 val t = table while (unrolled ne null) { val chunkarr = unrolled.array val chunksz = unrolled.size while (i < chunksz) { val elem = chunkarr(i) // assertCorrectBlock(block, elem.key) if (t.insertEntry(elem)) insertcount += 1 i += 1 } i = 0 unrolled = unrolled.next } insertcount } private def assertCorrectBlock(block: Int, k: K) { val hc = improve(elemHashCode(k), seedvalue) if ((hc >>> nonmasklen) != block) { println(hc + " goes to " + (hc >>> nonmasklen) + ", while expected block is " + block) assert((hc >>> nonmasklen) == block) } } def split = { val fp = howmany / 2 List(new FillBlocks(buckets, table, offset, fp), new FillBlocks(buckets, table, offset + fp, howmany - fp)) } override def merge(that: FillBlocks) { this.result += that.result } def shouldSplitFurther = howmany > scala.collection.parallel.thresholdFromSize(ParHashMapCombiner.numblocks, combinerTaskSupport.parallelismLevel) } } private[parallel] object ParHashMapCombiner { private[mutable] val discriminantbits = 5 private[mutable] val numblocks = 1 << discriminantbits private[mutable] val discriminantmask = ((1 << discriminantbits) - 1); private[mutable] val nonmasklength = 32 - discriminantbits def apply[K, V] = new ParHashMapCombiner[K, V](HashTable.defaultLoadFactor) {} // was: with EnvironmentPassingCombiner[(K, V), ParHashMap[K, V]] }