package play.core.server.netty
import org.jboss.netty.buffer.ChannelBuffers
import org.jboss.netty.channel._
import org.jboss.netty.handler.codec.http._
import org.jboss.netty.handler.codec.http.HttpHeaders._
import org.jboss.netty.handler.codec.http.websocketx.{ WebSocketFrame, TextWebSocketFrame, BinaryWebSocketFrame }
import org.jboss.netty.handler.codec.frame.TooLongFrameException
import org.jboss.netty.handler.ssl._
import org.jboss.netty.channel.group._
import play.api._
import play.api.http.{ HttpErrorHandler, DefaultHttpErrorHandler }
import play.api.mvc._
import play.api.libs.iteratee._
import play.api.libs.iteratee.Input._
import play.core.server.Server
import play.core.server.common.{ ForwardedHeaderHandler, ServerRequestUtils, ServerResultUtils }
import play.core.system.RequestIdProvider
import play.core.websocket._
import scala.collection.JavaConverters._
import scala.util.control.Exception
import com.typesafe.netty.http.pipelining.{ OrderedDownstreamChannelEvent, OrderedUpstreamMessageEvent }
import scala.concurrent.Future
import java.net.{ InetSocketAddress, URI }
import java.io.IOException
import org.jboss.netty.handler.codec.http.websocketx.CloseWebSocketFrame
private[play] class PlayDefaultUpstreamHandler(server: Server, allChannels: DefaultChannelGroup) extends SimpleChannelUpstreamHandler with WebSocketHandler with RequestBodyHandler {
import PlayDefaultUpstreamHandler._
private lazy val forwardedHeaderHandler = new ForwardedHeaderHandler(
ForwardedHeaderHandler.ForwardedHeaderHandlerConfig(server.applicationProvider.get.toOption.map(_.configuration)))
private def sendSimpleErrorResponse(ctx: ChannelHandlerContext, status: HttpResponseStatus): Unit = {
val response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, status)
response.headers().set(Names.CONNECTION, "close")
response.headers().set(Names.CONTENT_LENGTH, "0")
ctx.getChannel.write(response).addListener(ChannelFutureListener.CLOSE)
}
override def exceptionCaught(ctx: ChannelHandlerContext, event: ExceptionEvent): Unit = {
event.getCause match {
case e: IOException =>
logger.trace("Benign IO exception caught in Netty", e)
event.getChannel.close()
case e: TooLongFrameException =>
logger.warn("Handling TooLongFrameException", e)
sendSimpleErrorResponse(ctx, HttpResponseStatus.REQUEST_URI_TOO_LONG)
case e: IllegalArgumentException if Option(e.getMessage).exists(_.contains("Header value contains a prohibited character")) =>
logger.debug("Handling Header value error", e)
sendSimpleErrorResponse(ctx, HttpResponseStatus.BAD_REQUEST)
case e =>
logger.error("Exception caught in Netty", e)
event.getChannel.close()
}
}
override def channelConnected(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
Option(ctx.getPipeline.get(classOf[SslHandler])).map { sslHandler =>
sslHandler.handshake()
}
}
override def channelOpen(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
allChannels.add(e.getChannel)
}
override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent) {
e.getMessage match {
case nettyHttpRequest: HttpRequest =>
logger.trace("Http request received by netty: " + nettyHttpRequest)
val websocketableRequest = websocketable(nettyHttpRequest)
var nettyVersion = nettyHttpRequest.getProtocolVersion
val nettyUri = new QueryStringDecoder(nettyHttpRequest.getUri)
val rHeaders: Headers = getHeaders(nettyHttpRequest)
def rRemoteAddress = ServerRequestUtils.findRemoteAddress(
forwardedHeaderHandler,
rHeaders,
connectionRemoteAddress = e.getRemoteAddress.asInstanceOf[InetSocketAddress])
def rSecure = ServerRequestUtils.findSecureProtocol(
forwardedHeaderHandler,
rHeaders,
connectionSecureProtocol = ctx.getPipeline.get(classOf[SslHandler]) != null
)
def tryToCreateRequest = {
val parameters = Map.empty[String, Seq[String]] ++ nettyUri.getParameters.asScala.mapValues(_.asScala)
val path = new URI(nettyUri.getPath).getRawPath
createRequestHeader(path, parameters)
}
def createRequestHeader(parsedPath: String, parameters: Map[String, Seq[String]] = Map.empty[String, Seq[String]]) = {
val untaggedRequestHeader = new RequestHeader {
override val id = RequestIdProvider.requestIDs.incrementAndGet
override val tags = Map.empty[String, String]
override def uri = nettyHttpRequest.getUri
override def path = parsedPath
override def method = nettyHttpRequest.getMethod.getName
override def version = nettyVersion.getText
override def queryString = parameters
override def headers = rHeaders
override lazy val remoteAddress = rRemoteAddress
override lazy val secure = rSecure
}
untaggedRequestHeader
}
val (requestHeader, handler: Either[Future[Result], (Handler, Application)] @unchecked) = Exception
.allCatch[RequestHeader].either(tryToCreateRequest).fold(
e => {
val rh = createRequestHeader(nettyUri.getPath)
val result = Future
.successful(())
.flatMap { _ =>
errorHandler(server.applicationProvider.get.toOption).onClientError(rh, 400, e.getMessage)
}(play.api.libs.iteratee.Execution.trampoline)
(rh, Left(result))
},
rh => server.getHandlerFor(rh) match {
case directResult @ Left(_) => (rh, directResult)
case Right((taggedRequestHeader, handler, application)) => (taggedRequestHeader, Right((handler, application)))
}
)
implicit val msgCtx = ctx
implicit val oue = e.asInstanceOf[OrderedUpstreamMessageEvent]
handler match {
case Right((action: EssentialAction, app)) =>
val a = EssentialAction { rh =>
import play.api.libs.iteratee.Execution.Implicits.trampoline
Iteratee.flatten(action(rh).unflatten.map(_.it).recover {
case error =>
Iteratee.flatten(
app.errorHandler.onServerError(requestHeader, error).map(result => Done(result, Input.Empty))
): Iteratee[Array[Byte], Result]
})
}
handleAction(a, Some(app))
case Right((ws @ WebSocket(f), app)) if websocketableRequest.check =>
logger.trace("Serving this request with: " + ws)
val executed = Future(f(requestHeader))(play.api.libs.concurrent.Execution.defaultContext)
import play.api.libs.iteratee.Execution.Implicits.trampoline
executed.flatMap(identity).map {
case Left(result) =>
val a = EssentialAction(_ => Done(result, Input.Empty))
handleAction(a, Some(app))
case Right(socket) =>
val bufferLimit = app.configuration.getBytes("play.websocket.buffer.limit").getOrElse(65536L)
val enumerator = websocketHandshake(ctx, nettyHttpRequest, e, bufferLimit)(ws.inFormatter)
socket(enumerator, socketOut(ctx)(ws.outFormatter))
}.recover {
case error =>
app.errorHandler.onServerError(requestHeader, error).map { result =>
val a = EssentialAction(_ => Done(result, Input.Empty))
handleAction(a, Some(app))
}
}
case Right((WebSocket(_), app)) =>
logger.trace("Bad websocket request")
val a = EssentialAction(_ => Done(Results.BadRequest, Input.Empty))
handleAction(a, Some(app))
case Left(e) =>
logger.trace("No handler, got direct result: " + e)
import play.api.libs.iteratee.Execution.Implicits.trampoline
val a = EssentialAction(_ => Iteratee.flatten(e.map(result => Done(result, Input.Empty))))
handleAction(a, None)
}
def handleAction(action: EssentialAction, app: Option[Application]) {
logger.trace("Serving this request with: " + action)
val bodyParser = Iteratee.flatten(
scala.concurrent.Future(action(requestHeader))(play.api.libs.concurrent.Execution.defaultContext)
)
import play.api.libs.iteratee.Execution.Implicits.trampoline
val expectContinue: Option[_] = requestHeader.headers.get("Expect").filter(_.equalsIgnoreCase("100-continue"))
val eventuallyResult: Future[Result] = if (nettyHttpRequest.isChunked) {
val pipeline = ctx.getChannel.getPipeline
val result = newRequestBodyUpstreamHandler(bodyParser, { handler =>
pipeline.replace("handler", "handler", handler)
}, {
pipeline.replace("handler", "handler", this)
})
result
} else {
val bodyEnumerator = {
val body = {
val cBuffer = nettyHttpRequest.getContent
val bytes = new Array[Byte](cBuffer.readableBytes())
cBuffer.readBytes(bytes)
bytes
}
Enumerator(body).andThen(Enumerator.enumInput(EOF))
}
bodyEnumerator |>>> bodyParser
}
val eventuallyResultWithSequence: Future[(Result, Int)] = expectContinue match {
case Some(_) => {
bodyParser.unflatten.flatMap {
case Step.Cont(k) =>
sendDownstream(0, false, new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE))
eventuallyResult.map((_, 1))
case Step.Done(result, _) => {
Future.successful((result.copy(connection = HttpConnection.Close), 0))
}
case Step.Error(msg, _) => {
e.getChannel.setReadable(true)
val error = new RuntimeException("Body parser iteratee in error: " + msg)
val result = errorHandler(app).onServerError(requestHeader, error)
result.map(r => (r.copy(connection = HttpConnection.Close), 0))
}
}
}
case None => eventuallyResult.map((_, 0))
}
val sent = eventuallyResultWithSequence.recoverWith {
case error =>
logger.error("Cannot invoke the action", error)
e.getChannel.setReadable(true)
errorHandler(app).onServerError(requestHeader, error)
.map((_, 0))
}.flatMap {
case (result, sequence) =>
val cleanedResult = ServerResultUtils.cleanFlashCookie(requestHeader, result)
NettyResultStreamer.sendResult(requestHeader, cleanedResult, nettyVersion, sequence)
}
}
case unexpected => logger.error("Oops, unexpected message received in NettyServer (please report this problem): " + unexpected)
}
}
private def errorHandler(app: Option[Application]) = app.fold[HttpErrorHandler](DefaultHttpErrorHandler)(_.errorHandler)
def socketOut[A](ctx: ChannelHandlerContext)(frameFormatter: play.api.mvc.WebSocket.FrameFormatter[A]): Iteratee[A, Unit] = {
import play.api.libs.iteratee.Execution.Implicits.trampoline
val channel = ctx.getChannel
val basicFrameFormatter = frameFormatter.asInstanceOf[BasicFrameFormatter[A]]
import NettyFuture._
def iteratee: Iteratee[A, _] = Cont {
case El(e) =>
val basicFrame: BasicFrame = basicFrameFormatter.toFrame(e)
val nettyFrame: WebSocketFrame = basicFrame match {
case TextFrame(text) => new TextWebSocketFrame(true, 0, text)
case BinaryFrame(bytes) => new BinaryWebSocketFrame(true, 0, ChannelBuffers.wrappedBuffer(bytes))
}
Iteratee.flatten(channel.write(nettyFrame).toScala.map(_ => iteratee))
case e @ EOF =>
if (channel.isOpen) {
Iteratee.flatten(for {
_ <- channel.write(new CloseWebSocketFrame(WebSocketNormalClose, "")).toScala
_ <- channel.close().toScala
} yield Done((), e))
} else Done((), e)
case Empty => iteratee
}
iteratee.map(_ => ())
}
def getHeaders(nettyRequest: HttpRequest): Headers = {
val pairs = nettyRequest.headers().entries().asScala.map(h => h.getKey -> h.getValue)
new Headers(pairs)
}
def sendDownstream(subSequence: Int, last: Boolean, message: Object)(implicit ctx: ChannelHandlerContext, oue: OrderedUpstreamMessageEvent) = {
val ode = new OrderedDownstreamChannelEvent(oue, subSequence, last, message)
ctx.sendDownstream(ode)
ode.getFuture
}
}
object PlayDefaultUpstreamHandler {
private val logger = Logger(classOf[PlayDefaultUpstreamHandler])
}