package play.core.parsers
import java.io.FileOutputStream
import play.api.Play
import play.api.libs.Files.TemporaryFile
import play.api.libs.iteratee.Parsing.MatchInfo
import play.api.libs.iteratee._
import play.api.mvc._
import play.api.mvc.MultipartFormData._
import play.api.http.Status._
import scala.collection.mutable.ListBuffer
import scala.concurrent.Future
import play.api.libs.iteratee.Execution.Implicits.trampoline
/**
 * Iteratee-based parser for `multipart/form-data` request bodies.
 *
 * The body is split on the boundary declared in the request's `Content-Type` parameters. Each
 * part's header block is parsed into a header map with lower-cased names, and the part is then
 * routed, in order, to: the data-part handler, a missing-file handler (empty `filename`), the
 * caller's file-part handler, and finally a catch-all that records it as a `BadPart`.
 */
object Multipart {

  private val CRLF = "\r\n".getBytes("utf-8")
  private val CRLFCRLF = CRLF ++ CRLF

  /** Matches one quoted `key="value"` parameter of a `Content-Disposition` header. */
  private val KeyValue = """^([a-zA-Z_0-9]+)="(.*)"$""".r

  private type ParserInput = MatchInfo[Array[Byte]]
  private type Parser[T] = Iteratee[ParserInput, T]

  /**
   * Creates a body parser for `multipart/form-data`.
   *
   * @param maxDataLength   maximum number of bytes that all non-file (data) parts may occupy in
   *                        total; exceeding it completes the parser with `Left(EntityTooLarge)`.
   * @param filePartHandler handler for parts whose `Content-Disposition` carries a non-empty
   *                        `filename`.
   * @return a parser yielding the data, file, bad and missing-file parts in document order. If the
   *         request's media type has no `boundary` parameter, it yields the client-error result
   *         produced for "Missing boundary header".
   */
  def multipartParser[A](
    maxDataLength: Int,
    filePartHandler: PartHandler[FilePart[A]]): BodyParser[MultipartFormData[A]] = BodyParser("multipartFormData") { request =>

    // The delimiter searched for is CRLF + "--" + boundary.
    val maybeBoundary = for {
      mt <- request.mediaType
      (_, value) <- mt.parameters.find(_._1.equalsIgnoreCase("boundary"))
      boundary <- value
    } yield ("\r\n--" + boundary).getBytes("utf-8")

    maybeBoundary.map { boundary =>
      for {
        // The opening delimiter is not preceded by CRLF, so skip "--<boundary>" (2 bytes shorter).
        _ <- Traversable.take[Array[Byte]](boundary.size - 2) transform Iteratee.ignore
        result <- Parsing.search(boundary) transform parseParts(maxDataLength, filePartHandler)
      } yield {
        result.right.map { reversed =>
          // parseParts prepends each part, so reverse to restore document order.
          val parts = reversed.reverse
          // Strict map rather than mapValues, which returns a lazily re-evaluated view.
          val data = parts
            .collect { case DataPart(key, value) => (key, value) }
            .groupBy(_._1)
            .map { case (key, pairs) => key -> pairs.map(_._2) }
          // A is erased at runtime; only FilePart instances produced by filePartHandler exist here.
          val files = parts.collect { case file: FilePart[A @unchecked] => file }
          val bad = parts.collect { case bad: BadPart => bad }
          val missing = parts.collect { case missing: MissingFilePart => missing }
          MultipartFormData(data, files, bad, missing)
        }
      }
    }.getOrElse {
      Iteratee.flatten(createBadResult("Missing boundary header")(request).
        map(r => Done(Left(r))))
    }
  }

  /**
   * Decides how to consume a part's bytes, given its header map (names lower-cased).
   * Parts for which the function is not defined fall through to the next handler in the chain.
   */
  type PartHandler[A] = PartialFunction[Map[String, String], Iteratee[Array[Byte], A]]

  /**
   * File-part handler that streams each file part into a new `TemporaryFile`.
   *
   * The output stream is opened when the handler is invoked for a part. It is closed once the
   * part's bytes have been folded into it.
   * NOTE(review): the stream is not closed if a write fails.
   */
  def handleFilePartAsTemporaryFile: PartHandler[FilePart[TemporaryFile]] = {
    handleFilePart {
      case FileInfo(partName, filename, contentType) =>
        val tempFile = TemporaryFile("multipartBody", "asTemporaryFile")
        import play.core.Execution.Implicits.internalContext
        Iteratee.fold[Array[Byte], FileOutputStream](
          new java.io.FileOutputStream(tempFile.file)) { (os, data) =>
            os.write(data)
            os
          }(internalContext).map { os =>
            os.close()
            tempFile
          }(internalContext)
    }
  }

  /**
   * Parses parts one at a time, accumulating them in reverse order.
   *
   * Each data part's encoded size is deducted from `dataPartLimit`, so the limit bounds all data
   * parts together.
   *
   * @return `Right(parts)` (reversed) when the closing delimiter is reached, or
   *         `Left(EntityTooLarge)` as soon as a data part exceeds the remaining budget.
   */
  private def parseParts(dataPartLimit: Int,
    filePartHandler: PartHandler[Part],
    parts: List[Part] = Nil): Parser[Either[Result, List[Part]]] = {
    parsePart(dataPartLimit, filePartHandler).flatMap {
      case None => Done(Right(parts))
      case Some(MaxDataPartSizeExceeded(_)) => Done(Left(Results.EntityTooLarge))
      case Some(part) =>
        val remaining = part match {
          // The limit is enforced over bytes (takeUpTo on Array[Byte]). So deduct the part's UTF-8
          // byte length, not its character count, which under-counts multi-byte text.
          case DataPart(_, value) => dataPartLimit - value.getBytes("utf-8").length
          case _ => dataPartLimit
        }
        for {
          // Consume the boundary match that terminated this part.
          _ <- Iteratee.head[ParserInput]
          result <- parseParts(remaining, filePartHandler, part :: parts)
        } yield result
    }
  }

  /**
   * Parses a single part: its header block, then its body via the first matching handler.
   *
   * @return `None` when there are no further parts, otherwise the parsed part.
   */
  private def parsePart(dataPartLimit: Int,
    filePartHandler: PartHandler[Part]): Parser[Option[Part]] = {

    // Restrict this part to the content preceding the next boundary match.
    val takeUpToBoundary = Enumeratee.takeWhile[ParserInput](!_.isMatch)

    // Buffer at most 4 KiB in which to find the blank line ending the header block.
    val maxHeaderBuffer = Traversable.takeUpTo[Array[Byte]](4 * 1024) transform Iteratee.consume[Array[Byte]]()

    val collectHeaders: Iteratee[Array[Byte], Option[(Map[String, String], Array[Byte])]] = maxHeaderBuffer.map { buffer =>
      // NOTE(review): if no CRLFCRLF is found within the buffer, indexOfSlice is -1. In that case
      // headerBytes is empty and the part is treated as the end of the body rather than rejected.
      val (headerBytes, rest) = buffer.splitAt(buffer.indexOfSlice(CRLFCRLF))
      val headerString = new String(headerBytes, "utf-8").trim
      if (headerString.startsWith("--") || headerString.isEmpty) {
        // Presumably the closing delimiter's trailing "--"; either way there are no more parts.
        None
      } else {
        val headers = headerString.lines.map(parseHeaderLine).toMap
        // Body bytes already buffered after the blank line are replayed into the part handler.
        val left = rest.drop(CRLFCRLF.length)
        Some((headers, left))
      }
    }

    val readPart: PartHandler[Part] = handleDataPart(dataPartLimit)
      .orElse[Map[String, String], Iteratee[Array[Byte], Part]]({
        case FileInfoMatcher(partName, fileName, _) if fileName.trim.isEmpty =>
          Done(MissingFilePart(partName), Input.Empty)
      })
      .orElse(filePartHandler)
      .orElse({
        case headers => Done(BadPart(headers), Input.Empty)
      })

    takeUpToBoundary compose Enumeratee.map[MatchInfo[Array[Byte]]](_.content) transform collectHeaders.flatMap {
      case Some((headers, left)) => Iteratee.flatten(readPart(headers).feed(Input.El(left))).map(Some.apply)
      case _ => Done(None)
    }
  }

  /**
   * Splits a `Name: value` header line at its first colon.
   *
   * The name is lower-cased locale-independently. A line without a colon becomes a name with an
   * empty value.
   */
  private def parseHeaderLine(line: String): (String, String) = {
    val trimmed = line.trim
    trimmed.indexOf(':') match {
      case -1 => (trimmed.toLowerCase(java.util.Locale.ROOT), "")
      case colon =>
        (trimmed.substring(0, colon).trim.toLowerCase(java.util.Locale.ROOT),
          trimmed.substring(colon + 1).trim)
    }
  }

  /**
   * Splits a `Content-Disposition` value on `;`, ignoring semicolons inside double quotes.
   *
   * A backslash escapes the following character for quote tracking, so `\"` does not end a quoted
   * string and `\\` is a literal backslash. Escape characters are kept in the output, and each
   * segment is trimmed.
   */
  private def splitDisposition(str: String): List[String] = {
    val result = new ListBuffer[String]
    var buffer = new java.lang.StringBuilder
    var escape = false
    var quote = false
    def addPart(): Unit = {
      result += buffer.toString.trim
      buffer = new java.lang.StringBuilder
    }
    str.foreach {
      case ';' if !quote =>
        addPart()
        escape = false
      case c =>
        buffer.append(c)
        if (c == '"' && !escape) quote = !quote
        // Toggle so that an escaped backslash does not escape the next character.
        escape = c == '\\' && !escape
    }
    addPart()
    result.toList
  }

  /**
   * Parses the `content-disposition` header, if present, into a parameter map.
   *
   * Quoted values are unquoted and have `\"` unescaped. Bare tokens such as `form-data` map to the
   * empty string. Both matchers below share this parser so data and file parts are classified
   * consistently.
   */
  private def dispositionParams(headers: Map[String, String]): Option[Map[String, String]] =
    headers.get("content-disposition").map { disposition =>
      splitDisposition(disposition).map(_.trim).map {
        case KeyValue(key, v) => (key.trim, v.trim.replaceAll("""\\"""", "\""))
        case key => (key.trim, "")
      }.toMap
    }

  case class FileInfo(
    partName: String,
    fileName: String,
    contentType: Option[String])

  /**
   * Extracts `(name, filename, content-type)` from a part whose `Content-Disposition` is
   * `form-data` and carries both `name` and `filename` parameters.
   */
  private[play] object FileInfoMatcher {
    def unapply(headers: Map[String, String]): Option[(String, String, Option[String])] =
      for {
        values <- dispositionParams(headers)
        _ <- values.get("form-data")
        partName <- values.get("name")
        fileName <- values.get("filename")
      } yield (partName, fileName, headers.get("content-type"))
  }

  /**
   * Builds a file-part handler from a function that consumes the file's bytes.
   *
   * Only the segment after the last backslash of the client-supplied file name is kept.
   * NOTE(review): forward-slash path segments are not stripped.
   */
  def handleFilePart[A](handler: FileInfo => Iteratee[Array[Byte], A]): PartHandler[FilePart[A]] = {
    case FileInfoMatcher(partName, fileName, contentType) =>
      val safeFileName = fileName.split('\\').takeRight(1).mkString
      handler(FileInfo(partName, safeFileName, contentType)).
        map(a => FilePart(partName, safeFileName, contentType, a))
  }

  /** Extracts the `name` of a part whose `Content-Disposition` is `form-data` with a `name`. */
  private object PartInfoMatcher {
    def unapply(headers: Map[String, String]): Option[String] =
      for {
        values <- dispositionParams(headers)
        _ <- values.get("form-data")
        partName <- values.get("name")
      } yield partName
  }

  /**
   * Handles non-file parts by reading up to `maxLength` bytes and decoding them as UTF-8.
   *
   * If any further chunk arrives before the part ends, the result is `MaxDataPartSizeExceeded`.
   */
  private def handleDataPart(maxLength: Int): PartHandler[Part] = {
    case headers @ PartInfoMatcher(partName) if FileInfoMatcher.unapply(headers).isEmpty =>
      Traversable.takeUpTo[Array[Byte]](maxLength)
        .transform(Iteratee.consume[Array[Byte]]().map(bytes => DataPart(partName, new String(bytes, "utf-8"))))
        .flatMap { data =>
          Cont({
            case Input.El(_) => Done(MaxDataPartSizeExceeded(partName), Input.Empty)
            case in => Done(data, in)
          })
        }
  }

  /** Returns the application's client-error result for `msg`, or a plain `400` with no application. */
  private def createBadResult(msg: String): RequestHeader => Future[Result] =
    { request =>
      Play.maybeApplication.fold(Future.successful[Result](Results.BadRequest))(
        _.errorHandler.onClientError(request, BAD_REQUEST, msg))
    }
}