package play.core.server.netty
import org.jboss.netty.channel._
import org.jboss.netty.handler.codec.http._
import play.api._
import play.api.libs.iteratee._
import play.api.libs.iteratee.Input._
import scala.concurrent.{ Future, Promise }
import scala.util.{ Try, Success }
private[server] trait RequestBodyHandler {

  import RequestBodyHandler._

  /**
   * Installs an upstream handler that feeds the chunks of an HTTP request body into `bodyHandler`.
   *
   * Chunks are fed one at a time, in arrival order: each chunk is fed to an iteratee that only
   * becomes available once the previous chunk has been consumed. Reading from the channel is
   * suspended while more than `MaxMessages` chunks are waiting to be consumed, and resumed once
   * at most `MinMessages` remain.
   *
   * When `bodyHandler` stops asking for input (it leaves the `Cont` state, or its future fails),
   * the remaining chunks of the body are discarded by an `IgnoreBodyHandler`.
   *
   * @param bodyHandler     the iteratee that consumes the request body
   * @param replaceHandler  called with the handler that should process subsequent channel events
   *                        (presumably swaps it into the pipeline; its implementation is not shown here)
   * @param handlerFinished evaluated when the last chunk of the body has been received
   * @return the result of running `bodyHandler` once it has stopped asking for input
   */
  def newRequestBodyUpstreamHandler[A](bodyHandler: Iteratee[Array[Byte], A],
    replaceHandler: ChannelUpstreamHandler => Unit,
    handlerFinished: => Unit): Future[A] = {

    implicit val internalContext = play.core.Execution.internalContext
    import scala.concurrent.stm._

    // Completed with the body iteratee once it is no longer in the Cont state (or has failed).
    val bodyHandlerResult = Promise[Iteratee[Array[Byte], A]]()

    // Back-pressure watermarks, counted in chunks received but not yet consumed.
    val MaxMessages = 10
    val MinMessages = 10
    val counter = Ref(0)

    // The iteratee the next incoming chunk will be fed to. Every pushed chunk swaps in a
    // placeholder that resolves once that chunk has been consumed, chaining chunks in order.
    val iteratee: Ref[Iteratee[Array[Byte], A]] = Ref(bodyHandler)

    // Re-enables reads on the channel, unless it has already been closed.
    def resumeReading(ctx: ChannelHandlerContext): Unit =
      if (ctx.getChannel.isOpen) ctx.getChannel.setReadable(true)

    def pushChunk(ctx: ChannelHandlerContext, chunk: Input[Array[Byte]]): Unit = {
      if (counter.single.transformAndGet(_ + 1) > MaxMessages && ctx.getChannel.isOpen && !bodyHandlerResult.isCompleted)
        ctx.getChannel.setReadable(false)

      // Completed with the iteratee that results from consuming `chunk`.
      val itPromise = Promise[Iteratee[Array[Byte], A]]()

      // Only decide inside the transaction; side effects such as replaceHandler are performed
      // after it, because an STM transaction may be re-executed.
      val current = atomic { implicit txn =>
        if (bodyHandlerResult.isCompleted) None
        else Some(iteratee.single.swap(Iteratee.flatten(itPromise.future)))
      }

      // The chunk was consumed and the iteratee wants more input.
      def continue(it: Iteratee[Array[Byte], A]): Unit = {
        if (counter.single.transformAndGet(_ - 1) <= MinMessages) resumeReading(ctx)
        itPromise.success(it)
      }

      // The iteratee stopped asking for input, or failed. Only the first such outcome becomes the
      // body handler result; later chunks that observe it only re-enable reading.
      def finish(result: Try[Iteratee[Array[Byte], A]]): Unit = {
        if (!bodyHandlerResult.tryComplete(result)) resumeReading(ctx)
        itPromise.complete(result)
      }

      current match {
        case Some(currentIteratee) =>
          currentIteratee.feed(chunk).flatMap(_.unflatten).onComplete {
            case Success(c @ Step.Cont(_)) => continue(c.it)
            case done => finish(done.map(_.it))
          }
        case None =>
          // The body handler no longer wants input: discard whatever is left of the body.
          if (chunk != Input.EOF) replaceHandler(new IgnoreBodyHandler(handlerFinished))
          // A completion racing with the setReadable(false) above could otherwise leave the
          // channel unreadable forever, so the rest of the body would never be drained.
          resumeReading(ctx)
      }
    }

    replaceHandler(new SimpleChannelUpstreamHandler {
      override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit =
        e.getMessage match {
          case chunk: HttpChunk if !chunk.isLast =>
            val cBuffer = chunk.getContent
            val bytes = new Array[Byte](cBuffer.readableBytes())
            cBuffer.readBytes(bytes)
            pushChunk(ctx, El(bytes))
          case chunk: HttpChunk if chunk.isLast =>
            pushChunk(ctx, EOF)
            handlerFinished
          case unexpected =>
            logger.error("Oops, unexpected message received in NettyServer/RequestBodyHandler" +
              " (please report this problem): " + unexpected)
        }

      override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
        logger.error("Exception caught in RequestBodyHandler", e.getCause)
        e.getChannel.close()
      }

      // A dropped connection ends the body: signal end of input to the iteratee.
      override def channelDisconnected(ctx: ChannelHandlerContext, e: ChannelStateEvent): Unit =
        pushChunk(ctx, EOF)
    })

    bodyHandlerResult.future.flatMap(_.run)
  }

  /** Discards body chunks, evaluating `handlerFinished` when the last chunk arrives. */
  private class IgnoreBodyHandler(handlerFinished: => Unit) extends SimpleChannelUpstreamHandler {
    override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent): Unit =
      e.getMessage match {
        case chunk: HttpChunk =>
          if (chunk.isLast) handlerFinished
        case unexpected =>
          logger.error("Oops, unexpected message received in NettyServer/IgnoreBodyHandler" +
            " (please report this problem): " + unexpected)
      }

    override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent): Unit = {
      logger.error("Exception caught in IgnoreBodyHandler", e.getCause)
      e.getChannel.close()
    }
  }
}
/** Companion of [[RequestBodyHandler]]; holds the logger used by the trait and its nested handlers. */
object RequestBodyHandler {
  private val logger: Logger = Logger(classOf[RequestBodyHandler])
}