package play.core.server.netty
import scala.language.reflectiveCalls
import org.jboss.netty.channel._
import org.jboss.netty.handler.codec.http._
import org.jboss.netty.handler.codec.http.websocketx._
import play.core._
import play.core.websocket._
import play.core.server.websocket.WebSocketHandshake
import play.api._
import play.api.mvc.WebSocket.FrameFormatter
import play.api.libs.iteratee._
import play.api.libs.iteratee.Input._
import scala.concurrent.{ Future, Promise }
import scala.concurrent.stm._
import play.core.Execution.Implicits.internalContext
import org.jboss.netty.buffer.{ ChannelBuffers, ChannelBuffer }
import java.util.concurrent.atomic.AtomicInteger
/**
 * Helpers for upgrading a Netty HTTP connection to a WebSocket and bridging the incoming
 * WebSocket frames to a Play [[Enumerator]].
 */
private[server] trait WebSocketHandler {

  import NettyFuture._
  import WebSocketHandler._

  /** Close status 1000 (normal closure). */
  val WebSocketNormalClose = 1000

  /** Close status 1003 (received a type of data that cannot be accepted). */
  val WebSocketUnacceptable = 1003

  /** Close status 1009 (message too big to process). */
  val WebSocketMessageTooLong = 1009

  /**
   * Number of inputs that may be queued for the iteratee before the channel is made
   * non-readable, pausing reads from the socket (back-pressure).
   */
  private val MaxInFlight = 3

  /**
   * Creates the inbound side of a WebSocket connection.
   *
   * @param frameFormatter converts complete frames to values of type `A`; it is cast unchecked to a
   *                       [[BasicFrameFormatter]]
   * @param bufferLimit    maximum number of bytes accumulated while reassembling a fragmented message;
   *                       exceeding it closes the socket with status 1009
   * @return the enumerator that receives decoded messages, and the Netty handler that feeds it
   */
  def newWebSocketInHandler[A](frameFormatter: FrameFormatter[A], bufferLimit: Long): (Enumerator[A], ChannelHandler) = {

    val basicFrameFormatter = frameFormatter.asInstanceOf[BasicFrameFormatter[A]]

    // Only called with text/binary frames (those accepted by definedForNettyFrame or built by a FrameCreator).
    def fromNettyFrame(nettyFrame: WebSocketFrame): A = nettyFrame match {
      case nettyTextFrame: TextWebSocketFrame =>
        basicFrameFormatter.fromFrame(TextFrame(nettyTextFrame.getText))
      case nettyBinaryFrame: BinaryWebSocketFrame =>
        basicFrameFormatter.fromFrame(BinaryFrame(channelBufferToArray(nettyBinaryFrame.getBinaryData)))
    }

    // True when the formatter can decode frames of this kind.
    def definedForNettyFrame(nettyFrame: WebSocketFrame): Boolean = nettyFrame match {
      case _: TextWebSocketFrame => basicFrameFormatter.fromFrameDefined(classOf[TextFrame])
      case _: BinaryWebSocketFrame => basicFrameFormatter.fromFrameDefined(classOf[BinaryFrame])
      case _ => false
    }

    // Initial capacity of a reassembly buffer: twice the first fragment, capped by bufferLimit and Int.MaxValue.
    // Done in Long arithmetic so a large bufferLimit (e.g. Long.MaxValue) cannot truncate to a negative
    // capacity, which ChannelBuffers.dynamicBuffer rejects.
    def initialBufferSize(firstFragmentBytes: Int): Int =
      math.max(0L, math.min(math.min(firstFragmentBytes.toLong * 2, bufferLimit), Int.MaxValue.toLong)).toInt

    val enumerator = new WebSocketEnumerator[A]

    val handler = new SimpleChannelUpstreamHandler {

      // Builds the final, unfragmented frame from the reassembled payload.
      type FrameCreator = ChannelBuffer => WebSocketFrame

      // The message currently being reassembled from fragments, or None when no fragmented message is open.
      private var continuationBuffer: Option[(FrameCreator, ChannelBuffer)] = None

      // Opens a fragmented message whose first (non-final) fragment is `first`.
      private def startFragmentedMessage(first: WebSocketFrame, creator: FrameCreator): Unit = {
        val buffer = ChannelBuffers.dynamicBuffer(initialBufferSize(first.getBinaryData.readableBytes()))
        buffer.writeBytes(first.getBinaryData)
        continuationBuffer = Some((creator, buffer))
      }

      override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent) {
        (e.getMessage, continuationBuffer) match {
          case (frame: ContinuationWebSocketFrame, Some((_, buffer))) if frame.getBinaryData.readableBytes().toLong + buffer.readableBytes() > bufferLimit =>
            closeWebSocket(ctx, WebSocketMessageTooLong, "Fragmented message too long, configured limit is " + bufferLimit)
          case (frame: ContinuationWebSocketFrame, Some((_, buffer))) if !frame.isFinalFragment =>
            buffer.writeBytes(frame.getBinaryData)
          case (frame: ContinuationWebSocketFrame, Some((creator, buffer))) =>
            // Final fragment: assemble the whole message and hand it to the enumerator
            buffer.writeBytes(frame.getBinaryData)
            continuationBuffer = None
            enumerator.frameReceived(ctx, El(fromNettyFrame(creator(buffer))))
          case (frame: TextWebSocketFrame, None) if !frame.isFinalFragment && definedForNettyFrame(frame) =>
            startFragmentedMessage(frame, payload => new TextWebSocketFrame(true, frame.getRsv, payload))
          case (frame: BinaryWebSocketFrame, None) if !frame.isFinalFragment && definedForNettyFrame(frame) =>
            startFragmentedMessage(frame, payload => new BinaryWebSocketFrame(true, frame.getRsv, payload))
          case (frame: WebSocketFrame, None) if definedForNettyFrame(frame) =>
            enumerator.frameReceived(ctx, El(fromNettyFrame(frame)))
          case (frame: CloseWebSocketFrame, _) =>
            // Netty's getStatusCode returns -1 when the peer's close frame carried no status; echoing it
            // would put the invalid code 0xFFFF on the wire, so reply with a normal closure instead.
            val status = if (frame.getStatusCode == -1) WebSocketNormalClose else frame.getStatusCode
            closeWebSocket(ctx, status, "")
          case (frame: PingWebSocketFrame, _) =>
            ctx.getChannel.write(new PongWebSocketFrame(frame.getBinaryData))
          case (_: PongWebSocketFrame, _) =>
          // Pongs are ignored
          case (_: WebSocketFrame, _) =>
            // Frame kinds the formatter cannot decode, or a protocol-level sequence this handler does not expect
            closeWebSocket(ctx, WebSocketUnacceptable, "This WebSocket does not handle frames of that type")
          case _ =>
        }
      }

      override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent) {
        // Route through the configured logger rather than printing to stderr
        logger.error("Exception caught in WebSocket handler", e.getCause)
        e.getChannel.close()
      }

      override def channelDisconnected(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
        enumerator.frameReceived(ctx, EOF)
        logger.trace("disconnected socket")
      }

      /**
       * Sends a close frame with `status` and `reason`, then closes the channel and pushes EOF to the
       * enumerator once both operations succeed. Does nothing when the channel is already closed.
       */
      private def closeWebSocket(ctx: ChannelHandlerContext, status: Int, reason: String): Unit = {
        if (!reason.isEmpty) {
          logger.trace("Closing WebSocket because " + reason)
        }
        if (ctx.getChannel.isOpen) {
          for {
            _ <- ctx.getChannel.write(new CloseWebSocketFrame(status, reason)).toScala
            _ <- ctx.getChannel.close().toScala
          } yield {
            enumerator.frameReceived(ctx, EOF)
          }
        }
      }
    }

    (enumerator, handler)
  }

  /**
   * Enumerator fed frame by frame from the Netty handler.
   *
   * Each input is chained behind the previous one through `iterateeRef`, so inputs reach the iteratee
   * in arrival order. Reading from the channel is paused once `MaxInFlight` inputs are queued.
   * Only one iteratee can be attached: a second call to `apply` throws, because `eventuallyIteratee`
   * can be completed only once.
   */
  private class WebSocketEnumerator[A] extends Enumerator[A] {

    // Completed with the iteratee passed to `apply`; inputs received before then wait on it.
    val eventuallyIteratee = Promise[Iteratee[A, Any]]()

    // The iteratee that will receive the next input.
    val iterateeRef = Ref[Iteratee[A, Any]](Iteratee.flatten(eventuallyIteratee.future))

    // Completed when the attached iteratee reaches Done or Error; returned by `apply`.
    private val promise: scala.concurrent.Promise[Iteratee[A, Any]] = Promise[Iteratee[A, Any]]()

    // Inputs passed to frameReceived that have not yet been handed to the iteratee.
    private val inFlight = new AtomicInteger(0)

    def apply[R](i: Iteratee[A, R]) = {
      eventuallyIteratee.success(i)
      promise.asInstanceOf[scala.concurrent.Promise[Iteratee[A, R]]].future
    }

    /** Toggles reading on `channel`, skipping channels that are already closed. */
    def setReadable(channel: Channel, readable: Boolean) {
      if (channel.isOpen) {
        channel.setReadable(readable)
      }
    }

    /**
     * Queues `input` for the iteratee. When the iteratee finishes, a normal close frame is sent and the
     * channel closed.
     */
    def frameReceived(ctx: ChannelHandlerContext, input: Input[A]) {
      val channel = ctx.getChannel

      // Back-pressure: stop reading once too many inputs are waiting for the iteratee
      if (inFlight.incrementAndGet() >= MaxInFlight) {
        setReadable(channel, false)
      }

      // Install a placeholder for the iteratee that results from this input, so the next
      // frame is chained behind this one
      val eventuallyNext = Promise[Iteratee[A, Any]]()
      val current = iterateeRef.single.swap(Iteratee.flatten(eventuallyNext.future))

      val next = current.flatFold(
        (a, e) => {
          // Iteratee already done: the input is dropped, but it must still leave the in-flight count
          inFlight.decrementAndGet()
          setReadable(channel, true)
          Future.successful(current)
        },
        k => {
          if (inFlight.decrementAndGet() < MaxInFlight) {
            setReadable(channel, true)
          }
          val afterInput = k(input)
          afterInput.fold {
            case Step.Done(a, e) =>
              promise.trySuccess(afterInput)
              if (channel.isOpen) {
                for {
                  _ <- channel.write(new CloseWebSocketFrame(WebSocketNormalClose, "")).toScala
                  _ <- channel.close().toScala
                } yield afterInput
              } else {
                Future.successful(afterInput)
              }
            case Step.Cont(_) =>
              Future.successful(afterInput)
            case Step.Error(msg, e) =>
              // Complete the enumeration so whoever waits on `apply` is not left hanging.
              // NOTE(review): the socket is left open here; consider closing it with an error status.
              promise.trySuccess(afterInput)
              Future.successful(afterInput)
          }
        },
        (err, e) => {
          // Iteratee already in error: drop the input, keep the in-flight count accurate
          inFlight.decrementAndGet()
          setReadable(channel, true)
          Future.successful(current)
        })

      eventuallyNext.success(next)
    }
  }

  /**
   * Returns the readable bytes of `buffer` as an array.
   *
   * The backing array is shared without copying only when it is exactly the readable region. Otherwise
   * the readable bytes are copied, which advances the buffer's reader index.
   */
  private def channelBufferToArray(buffer: ChannelBuffer): Array[Byte] = {
    val readable = buffer.readableBytes()
    // Direct and composite buffers have no accessible array. Sliced or offset heap buffers can report
    // readableBytes == capacity while their backing array is larger, so check the array itself.
    val canShareArray =
      buffer.hasArray() && buffer.arrayOffset() == 0 && buffer.readerIndex() == 0 && buffer.array().length == readable
    if (canShareArray) {
      buffer.array()
    } else {
      val bytes = new Array[Byte](readable)
      buffer.readBytes(bytes)
      bytes
    }
  }

  /**
   * Replaces the pipeline's "handler" with a WebSocket inbound handler, then delegates the upgrade
   * handshake to `WebSocketHandshake.shake`.
   *
   * @param e unused by this implementation
   * @return the enumerator of messages received on the socket
   */
  def websocketHandshake[A](ctx: ChannelHandlerContext, req: HttpRequest, e: MessageEvent, bufferLimit: Long)(frameFormatter: FrameFormatter[A]): Enumerator[A] = {
    val (enumerator, handler) = newWebSocketInHandler(frameFormatter, bufferLimit)
    val p: ChannelPipeline = ctx.getChannel.getPipeline
    p.replace("handler", "handler", handler)
    WebSocketHandshake.shake(ctx, req, bufferLimit)
    enumerator
  }

  /**
   * Exposes `req` as a WebSocketable. `check` is true when the Upgrade header matches
   * `HttpHeaders.Values.WEBSOCKET` ignoring case.
   */
  def websocketable(req: HttpRequest) = new server.WebSocketable {
    def check = HttpHeaders.Values.WEBSOCKET.equalsIgnoreCase(req.headers().get(HttpHeaders.Names.UPGRADE))
    def getHeader(header: String) = req.headers().get(header)
  }
}
/**
 * Companion of [[WebSocketHandler]]. It holds the logger that the trait imports; the logger is
 * private but still visible to the trait because they are companions.
 */
object WebSocketHandler {

  // Named after the WebSocketHandler class
  private val logger: Logger = Logger(classOf[WebSocketHandler])

}