package play.api.libs.ws.ning
import java.security.KeyStore
import java.security.cert.CertPathValidatorException
import javax.inject.{ Singleton, Inject, Provider }
import org.slf4j.LoggerFactory
import com.ning.http.client.AsyncHttpClientConfig
import javax.net.ssl._
import play.api.{ PlayConfig, Environment, Configuration }
import play.api.libs.ws.ssl._
import play.api.libs.ws.WSClientConfig
import scala.concurrent.duration._
/**
 * Ning AsyncHttpClient specific settings, layered on top of the generic [[WSClientConfig]].
 *
 * `NingAsyncHttpClientConfigBuilder.configureWS` copies every field onto an
 * `AsyncHttpClientConfig.Builder`. Durations are handed to Ning as milliseconds there, and
 * non-finite durations (such as the `Duration.Inf` default of `maxConnectionLifetime`) are passed as -1.
 *
 * @param wsClientConfig              generic WS client settings (timeouts, redirects, proxy, SSL, ...).
 * @param allowPoolingConnection      passed to `setAllowPoolingConnections`.
 * @param allowSslConnectionPool      passed to `setAllowPoolingSslConnections`.
 * @param ioThreadMultiplier          passed to `setIOThreadMultiplier`.
 * @param maxConnectionsPerHost       passed to `setMaxConnectionsPerHost`; -1 by default
 *                                    (presumably "no limit" in Ning — verify against the Ning docs).
 * @param maxConnectionsTotal         passed to `setMaxConnections`; -1 by default (see above).
 * @param maxConnectionLifetime       passed (in ms) to `setConnectionTTL`.
 * @param idleConnectionInPoolTimeout passed (in ms) to `setPooledConnectionIdleTimeout`.
 * @param webSocketIdleTimeout        passed (in ms) to `setWebSocketTimeout`.
 * @param maxNumberOfRedirects        passed to `setMaxRedirects`.
 * @param maxRequestRetry             passed to `setMaxRequestRetry`.
 * @param disableUrlEncoding          passed to `setDisableUrlEncodingForBoundedRequests`.
 */
case class NingWSClientConfig(
  wsClientConfig: WSClientConfig = WSClientConfig(),
  allowPoolingConnection: Boolean = true,
  allowSslConnectionPool: Boolean = true,
  ioThreadMultiplier: Int = 2,
  maxConnectionsPerHost: Int = -1,
  maxConnectionsTotal: Int = -1,
  maxConnectionLifetime: Duration = Duration.Inf,
  idleConnectionInPoolTimeout: Duration = Duration(1, MINUTES),
  webSocketIdleTimeout: Duration = Duration(15, MINUTES),
  maxNumberOfRedirects: Int = 5,
  maxRequestRetry: Int = 5,
  disableUrlEncoding: Boolean = false)
/** Convenience constructors for [[NingWSClientConfig]]. */
object NingWSClientConfigFactory {

  /**
   * Wraps a generic WS configuration, leaving every Ning-specific field at its default.
   *
   * @param config the generic WS client settings to embed.
   * @return a `NingWSClientConfig` whose `wsClientConfig` is `config`.
   */
  def forClientConfig(config: WSClientConfig): NingWSClientConfig =
    NingWSClientConfig(wsClientConfig = config)
}
/**
 * Reads the `play.ws.ning` section of the application configuration into a [[NingWSClientConfig]].
 *
 * @param wsClientConfig the already-parsed generic WS settings to embed in the result.
 * @param configuration  the application configuration holding the `play.ws.ning` section.
 * @param environment    injected but not read by this class.
 */
@Singleton
class NingWSClientConfigParser @Inject() (wsClientConfig: WSClientConfig,
    configuration: Configuration,
    environment: Environment) extends Provider[NingWSClientConfig] {

  /** `Provider` entry point; parses the configuration on every call. */
  def get: NingWSClientConfig = parse()

  /**
   * Builds a [[NingWSClientConfig]] from the keys under `play.ws.ning`.
   *
   * Every key is read with `PlayConfig.get`, so each one must be present in the configuration
   * (presumably supplied by `reference.conf` — confirm).
   */
  def parse(): NingWSClientConfig = {
    val ning = PlayConfig(configuration).get[PlayConfig]("play.ws.ning")

    // Named arguments are evaluated in the order written, so the keys are read in the same sequence as before.
    NingWSClientConfig(
      wsClientConfig = wsClientConfig,
      allowPoolingConnection = ning.get[Boolean]("allowPoolingConnection"),
      allowSslConnectionPool = ning.get[Boolean]("allowSslConnectionPool"),
      ioThreadMultiplier = ning.get[Int]("ioThreadMultiplier"),
      maxConnectionsPerHost = ning.get[Int]("maxConnectionsPerHost"),
      maxConnectionsTotal = ning.get[Int]("maxConnectionsTotal"),
      maxConnectionLifetime = ning.get[Duration]("maxConnectionLifetime"),
      idleConnectionInPoolTimeout = ning.get[Duration]("idleConnectionInPoolTimeout"),
      webSocketIdleTimeout = ning.get[Duration]("webSocketIdleTimeout"),
      maxNumberOfRedirects = ning.get[Int]("maxNumberOfRedirects"),
      maxRequestRetry = ning.get[Int]("maxRequestRetry"),
      disableUrlEncoding = ning.get[Boolean]("disableUrlEncoding")
    )
  }
}
/**
 * Translates a [[NingWSClientConfig]] into a Ning [[AsyncHttpClientConfig]].
 *
 * Every `configure*` method mutates the single `builder` held by this instance. `configure()` runs
 * the WS settings first, then the SSL settings, and finally applies `addCustomSettings`.
 *
 * @param ningConfig the settings to apply; defaults to `NingWSClientConfig()`.
 */
class NingAsyncHttpClientConfigBuilder(ningConfig: NingWSClientConfig = NingWSClientConfig()) {

  @deprecated("Use NingAsyncHttpClientConfigBuilder(NingWSClientConfig)", "2.4")
  def this(config: WSClientConfig) =
    this(NingWSClientConfig(wsClientConfig = config))

  /**
   * Extra customisation applied to `builder` at the end of `configure()`.
   * It is the identity here; `modifyUnderlying` returns an instance that composes a caller function onto it.
   */
  protected val addCustomSettings: AsyncHttpClientConfig.Builder => AsyncHttpClientConfig.Builder = identity

  /** The Ning builder mutated by every `configure*` method. */
  val builder: AsyncHttpClientConfig.Builder = new AsyncHttpClientConfig.Builder()

  private[ning] val logger = LoggerFactory.getLogger(this.getClass)

  /**
   * Applies the WS settings, then the SSL settings, then `addCustomSettings` to `builder`.
   *
   * @return the builder returned by `addCustomSettings`.
   * @throws IllegalStateException if weak protocols or ciphers are enabled without the matching
   *                               `loose` flag, or if the hostname verifier cannot be instantiated.
   */
  def configure(): AsyncHttpClientConfig.Builder = {
    configureWS(ningConfig)
    configureSSL(ningConfig.wsClientConfig.ssl)
    addCustomSettings(builder)
  }

  /** Runs `configure()` and builds the resulting [[AsyncHttpClientConfig]]. */
  def build(): AsyncHttpClientConfig = configure().build()

  /**
   * Returns a builder for the same `ningConfig`. It shares this instance's `builder` and applies
   * `modify` after the current `addCustomSettings`.
   *
   * @param modify extra customisation applied to the Ning builder after the Play settings.
   */
  def modifyUnderlying(
    modify: AsyncHttpClientConfig.Builder => AsyncHttpClientConfig.Builder): NingAsyncHttpClientConfigBuilder = {
    new NingAsyncHttpClientConfigBuilder(ningConfig) {
      override val addCustomSettings = modify compose NingAsyncHttpClientConfigBuilder.this.addCustomSettings
      // Share the outer builder so that settings applied through either instance land in the same object.
      override val builder = NingAsyncHttpClientConfigBuilder.this.builder
    }
  }

  /**
   * Copies the generic WS settings and the Ning-specific pool/thread/redirect/retry settings onto `builder`.
   *
   * Durations are handed to Ning as `Int` milliseconds. Non-finite durations become -1, and finite
   * ones are clamped to the `Int` range.
   *
   * @param ningConfig the settings to apply (shadows the constructor parameter of the same name).
   */
  def configureWS(ningConfig: NingWSClientConfig): Unit = {
    val config = ningConfig.wsClientConfig

    // A bare `.toInt` wraps around for anything longer than Int.MaxValue ms (about 24.8 days), so e.g.
    // a 30-day connection lifetime would reach Ning as a negative value. Clamp instead.
    def toMillis(duration: Duration): Int =
      if (!duration.isFinite()) -1
      else {
        val millis = duration.toMillis
        if (millis > Int.MaxValue) Int.MaxValue
        else if (millis < Int.MinValue) Int.MinValue
        else millis.toInt
      }

    builder.setConnectTimeout(toMillis(config.connectionTimeout))
      .setReadTimeout(toMillis(config.idleTimeout))
      .setRequestTimeout(toMillis(config.requestTimeout))
      .setFollowRedirect(config.followRedirects)
      .setUseProxyProperties(config.useProxyProperties)
      .setCompressionEnforced(config.compressionEnabled)

    config.userAgent foreach builder.setUserAgent

    builder.setAllowPoolingConnections(ningConfig.allowPoolingConnection)
    builder.setAllowPoolingSslConnections(ningConfig.allowSslConnectionPool)
    builder.setIOThreadMultiplier(ningConfig.ioThreadMultiplier)
    builder.setMaxConnectionsPerHost(ningConfig.maxConnectionsPerHost)
    builder.setMaxConnections(ningConfig.maxConnectionsTotal)
    builder.setConnectionTTL(toMillis(ningConfig.maxConnectionLifetime))
    builder.setPooledConnectionIdleTimeout(toMillis(ningConfig.idleConnectionInPoolTimeout))
    builder.setWebSocketTimeout(toMillis(ningConfig.webSocketIdleTimeout))
    builder.setMaxRedirects(ningConfig.maxNumberOfRedirects)
    builder.setMaxRequestRetry(ningConfig.maxRequestRetry)
    builder.setDisableUrlEncodingForBoundedRequests(ningConfig.disableUrlEncoding)
  }

  @deprecated("Use configureWS(NingWSClientConfig)", "2.4")
  def configureWS(config: WSClientConfig): Unit = configureWS(NingWSClientConfig(wsClientConfig = config))

  /**
   * Selects the TLS protocols to enable.
   *
   * Starts from `sslConfig.enabledProtocols`, or from `Protocols.recommendedProtocols` when that is unset,
   * and keeps only the entries also present in `existingProtocols`.
   *
   * @param existingProtocols protocols supported by the SSL context.
   * @param sslConfig         the SSL settings.
   * @return the filtered protocols, in configured order.
   * @throws IllegalStateException if a protocol from `Protocols.deprecatedProtocols` survives the filter
   *                               and `sslConfig.loose.allowWeakProtocols` is false.
   */
  def configureProtocols(existingProtocols: Array[String], sslConfig: SSLConfig): Array[String] = {
    val definedProtocols = sslConfig.enabledProtocols match {
      case Some(configuredProtocols) =>
        configuredProtocols.filter(existingProtocols.contains).toArray
      case None =>
        Protocols.recommendedProtocols.filter(existingProtocols.contains).toArray
    }

    if (!sslConfig.loose.allowWeakProtocols) {
      // Report the first deprecated protocol (in deprecatedProtocols' iteration order) that is enabled.
      Protocols.deprecatedProtocols.find(definedProtocols.contains).foreach { deprecatedProtocol =>
        throw new IllegalStateException(s"Weak protocol $deprecatedProtocol found in ws.ssl.protocols!")
      }
    }

    definedProtocols
  }

  /**
   * Selects the cipher suites to enable.
   *
   * Starts from `sslConfig.enabledCipherSuites`, or from `Ciphers.recommendedCiphers` when that is unset,
   * and keeps only the entries also present in `existingCiphers`.
   *
   * @param existingCiphers cipher suites supported by the SSL context.
   * @param sslConfig       the SSL settings.
   * @return the filtered cipher suites, in configured order.
   * @throws IllegalStateException if a cipher from `Ciphers.deprecatedCiphers` survives the filter
   *                               and `sslConfig.loose.allowWeakCiphers` is false.
   */
  def configureCipherSuites(existingCiphers: Array[String], sslConfig: SSLConfig): Array[String] = {
    val definedCiphers = sslConfig.enabledCipherSuites match {
      case Some(configuredCiphers) =>
        configuredCiphers.filter(existingCiphers.contains(_)).toArray
      case None =>
        Ciphers.recommendedCiphers.filter(existingCiphers.contains(_)).toArray
    }

    if (!sslConfig.loose.allowWeakCiphers) {
      // Report the first deprecated cipher (in deprecatedCiphers' iteration order) that is enabled.
      Ciphers.deprecatedCiphers.find(definedCiphers.contains(_)).foreach { deprecatedCipher =>
        throw new IllegalStateException(s"Weak cipher $deprecatedCipher found in ws.ssl.ciphers!")
      }
    }

    definedCiphers
  }

  /**
   * Configures the SSL context, protocols, cipher suites, certificate acceptance and hostname
   * verification on `builder`.
   *
   * When `sslConfig.default` is set, the JVM default `SSLContext` is used (after
   * `validateDefaultTrustManager`). Otherwise a context is built by `ConfigSSLContextBuilder`.
   *
   * @param sslConfig the SSL settings.
   */
  def configureSSL(sslConfig: SSLConfig): Unit = {
    val sslContext = if (sslConfig.default) {
      logger.info("buildSSLContext: ws.ssl.default is true, using default SSLContext")
      validateDefaultTrustManager(sslConfig)
      SSLContext.getDefault
    } else {
      val keyManagerFactory = buildKeyManagerFactory(sslConfig)
      val trustManagerFactory = buildTrustManagerFactory(sslConfig)
      new ConfigSSLContextBuilder(sslConfig, keyManagerFactory, trustManagerFactory).build()
    }

    // getDefaultSSLParameters returns a copy, so it is only read here. The filtered lists go to
    // the Ning builder; writing them back into the throw-away copy (as before) had no effect.
    val defaultParams = sslContext.getDefaultSSLParameters

    val protocols = configureProtocols(defaultParams.getProtocols, sslConfig)
    builder.setEnabledProtocols(protocols)

    val cipherSuites = configureCipherSuites(defaultParams.getCipherSuites, sslConfig)
    builder.setEnabledCipherSuites(cipherSuites)

    builder.setAcceptAnyCertificate(sslConfig.loose.acceptAnyCertificate)

    val hostnameVerifier: HostnameVerifier =
      if (!sslConfig.loose.disableHostnameVerification) {
        buildHostnameVerifier(sslConfig)
      } else {
        logger.warn("buildHostnameVerifier: disabling hostname verification")
        new DisabledComplainingHostnameVerifier
      }
    builder.setHostnameVerifier(hostnameVerifier)

    builder.setSSLContext(sslContext)
  }

  /** Creates the key manager factory wrapper for `ssl.keyManagerConfig.algorithm`. */
  def buildKeyManagerFactory(ssl: SSLConfig): KeyManagerFactoryWrapper = {
    new DefaultKeyManagerFactoryWrapper(ssl.keyManagerConfig.algorithm)
  }

  /** Creates the trust manager factory wrapper for `ssl.trustManagerConfig.algorithm`. */
  def buildTrustManagerFactory(ssl: SSLConfig): TrustManagerFactoryWrapper = {
    new DefaultTrustManagerFactoryWrapper(ssl.trustManagerConfig.algorithm)
  }

  /**
   * Instantiates `sslConfig.hostnameVerifierClass` through its no-argument constructor.
   *
   * @throws IllegalStateException wrapping any exception raised while instantiating the class.
   */
  def buildHostnameVerifier(sslConfig: SSLConfig): HostnameVerifier = {
    logger.debug("buildHostnameVerifier: enabling hostname verification using {}", sslConfig.hostnameVerifierClass)
    try {
      // Class.newInstance is deprecated: it rethrows checked constructor exceptions unwrapped.
      // Constructor.newInstance wraps them in InvocationTargetException, which is still caught below.
      sslConfig.hostnameVerifierClass.getDeclaredConstructor().newInstance()
    } catch {
      case e: Exception =>
        throw new IllegalStateException("Cannot configure hostname verifier", e)
    }
  }

  /**
   * Checks the certificates accepted by the JVM default trust manager against
   * `sslConfig.disabledKeyAlgorithms`. Each weak certificate is logged as a warning; none of them
   * makes this method fail.
   *
   * This method recreates the default trust manager from the default `TrustManagerFactory` because
   * the one inside the default `SSLContext` is not exposed.
   */
  def validateDefaultTrustManager(sslConfig: SSLConfig): Unit = {
    val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
    tmf.init(null.asInstanceOf[KeyStore]) // null KeyStore => the factory's default trust material

    val constraints = sslConfig.disabledKeyAlgorithms.map(a => AlgorithmConstraintsParser.parseAll(AlgorithmConstraintsParser.expression, a).get).toSet
    val algorithmChecker = new AlgorithmChecker(keyConstraints = constraints, signatureConstraints = Set())

    // Look for an X509 trust manager rather than blindly casting element 0: this check is purely
    // advisory, so it should not abort SSL setup with a ClassCastException.
    tmf.getTrustManagers().collectFirst { case tm: X509TrustManager => tm } match {
      case Some(trustManager) =>
        for (cert <- trustManager.getAcceptedIssuers) {
          try {
            algorithmChecker.checkKeyAlgorithms(cert)
          } catch {
            case e: CertPathValidatorException =>
              logger.warn("You are using ws.ssl.default=true and have a weak certificate in your default trust store! (You can modify ws.ssl.disabledKeyAlgorithms to remove this message.)", e)
          }
        }
      case None =>
        logger.warn("validateDefaultTrustManager: no X509TrustManager found in the default TrustManagerFactory, skipping weak certificate check")
    }
  }
}