package kafka.message
import kafka.utils.Logging
import java.nio.ByteBuffer
import java.nio.channels._
import java.io.{InputStream, ByteArrayOutputStream, DataOutputStream}
import java.util.concurrent.atomic.AtomicLong
import kafka.utils.IteratorTemplate
object ByteBufferMessageSet {
private def create(offsetCounter: AtomicLong, compressionCodec: CompressionCodec, messages: Message*): ByteBuffer = {
if(messages.size == 0) {
MessageSet.Empty.buffer
} else if(compressionCodec == NoCompressionCodec) {
val buffer = ByteBuffer.allocate(MessageSet.messageSetSize(messages))
for(message <- messages)
writeMessage(buffer, message, offsetCounter.getAndIncrement)
buffer.rewind()
buffer
} else {
val byteArrayStream = new ByteArrayOutputStream(MessageSet.messageSetSize(messages))
val output = new DataOutputStream(CompressionFactory(compressionCodec, byteArrayStream))
var offset = -1L
try {
for(message <- messages) {
offset = offsetCounter.getAndIncrement
output.writeLong(offset)
output.writeInt(message.size)
output.write(message.buffer.array, message.buffer.arrayOffset, message.buffer.limit)
}
} finally {
output.close()
}
val bytes = byteArrayStream.toByteArray
val message = new Message(bytes, compressionCodec)
val buffer = ByteBuffer.allocate(message.size + MessageSet.LogOverhead)
writeMessage(buffer, message, offset)
buffer.rewind()
buffer
}
}
def decompress(message: Message): ByteBufferMessageSet = {
val outputStream: ByteArrayOutputStream = new ByteArrayOutputStream
val inputStream: InputStream = new ByteBufferBackedInputStream(message.payload)
val intermediateBuffer = new Array[Byte](1024)
val compressed = CompressionFactory(message.compressionCodec, inputStream)
try {
Stream.continually(compressed.read(intermediateBuffer)).takeWhile(_ > 0).foreach { dataRead =>
outputStream.write(intermediateBuffer, 0, dataRead)
}
} finally {
compressed.close()
}
val outputBuffer = ByteBuffer.allocate(outputStream.size)
outputBuffer.put(outputStream.toByteArray)
outputBuffer.rewind
new ByteBufferMessageSet(outputBuffer)
}
private[kafka] def writeMessage(buffer: ByteBuffer, message: Message, offset: Long) {
buffer.putLong(offset)
buffer.putInt(message.size)
buffer.put(message.buffer)
message.buffer.rewind()
}
}
class ByteBufferMessageSet(val buffer: ByteBuffer) extends MessageSet with Logging {
private var shallowValidByteCount = -1
def this(compressionCodec: CompressionCodec, messages: Message*) {
this(ByteBufferMessageSet.create(new AtomicLong(0), compressionCodec, messages:_*))
}
def this(compressionCodec: CompressionCodec, offsetCounter: AtomicLong, messages: Message*) {
this(ByteBufferMessageSet.create(offsetCounter, compressionCodec, messages:_*))
}
def this(messages: Message*) {
this(NoCompressionCodec, new AtomicLong(0), messages: _*)
}
def getBuffer = buffer
private def shallowValidBytes: Int = {
if(shallowValidByteCount < 0) {
var bytes = 0
val iter = this.internalIterator(true)
while(iter.hasNext) {
val messageAndOffset = iter.next
bytes += MessageSet.entrySize(messageAndOffset.message)
}
this.shallowValidByteCount = bytes
}
shallowValidByteCount
}
def writeTo(channel: GatheringByteChannel, offset: Long, size: Int): Int = {
buffer.mark()
var written = 0
while(written < sizeInBytes)
written += channel.write(buffer)
buffer.reset()
written
}
override def iterator: Iterator[MessageAndOffset] = internalIterator()
def shallowIterator: Iterator[MessageAndOffset] = internalIterator(true)
private def internalIterator(isShallow: Boolean = false): Iterator[MessageAndOffset] = {
new IteratorTemplate[MessageAndOffset] {
var topIter = buffer.slice()
var innerIter: Iterator[MessageAndOffset] = null
def innerDone():Boolean = (innerIter == null || !innerIter.hasNext)
def makeNextOuter: MessageAndOffset = {
if (topIter.remaining < 12)
return allDone()
val offset = topIter.getLong()
val size = topIter.getInt()
if(size < Message.MinHeaderSize)
throw new InvalidMessageException("Message found with corrupt size (" + size + ")")
if(topIter.remaining < size)
return allDone()
val message = topIter.slice()
message.limit(size)
topIter.position(topIter.position + size)
val newMessage = new Message(message)
if(isShallow) {
new MessageAndOffset(newMessage, offset)
} else {
newMessage.compressionCodec match {
case NoCompressionCodec =>
innerIter = null
new MessageAndOffset(newMessage, offset)
case _ =>
innerIter = ByteBufferMessageSet.decompress(newMessage).internalIterator()
if(!innerIter.hasNext)
innerIter = null
makeNext()
}
}
}
override def makeNext(): MessageAndOffset = {
if(isShallow){
makeNextOuter
} else {
if(innerDone())
makeNextOuter
else
innerIter.next
}
}
}
}
private[kafka] def assignOffsets(offsetCounter: AtomicLong, codec: CompressionCodec): ByteBufferMessageSet = {
if(codec == NoCompressionCodec) {
var position = 0
buffer.mark()
while(position < sizeInBytes - MessageSet.LogOverhead) {
buffer.position(position)
buffer.putLong(offsetCounter.getAndIncrement())
position += MessageSet.LogOverhead + buffer.getInt()
}
buffer.reset()
this
} else {
val messages = this.internalIterator(isShallow = false).map(_.message)
new ByteBufferMessageSet(compressionCodec = codec, offsetCounter = offsetCounter, messages = messages.toBuffer:_*)
}
}
def sizeInBytes: Int = buffer.limit
def validBytes: Int = shallowValidBytes
override def equals(other: Any): Boolean = {
other match {
case that: ByteBufferMessageSet =>
buffer.equals(that.buffer)
case _ => false
}
}
override def hashCode: Int = buffer.hashCode
}