package play.api.libs.ws.ssl.debug
import play.api.libs.ws.ssl._
import java.security.AccessController
import scala.util.control.NonFatal
/**
 * Reconfigures the JSSE internal SSL debug output (the `javax.net.debug` options) at runtime.
 *
 * Besides setting the `javax.net.debug` system property, this reflectively overwrites every field
 * (in the classes returned by `findClasses`) that `isValidField` accepts as holding the internal
 * JSSE `Debug` type, and finally the `Debug` class's own `args` field.
 * NOTE(review): presumably needed because JSSE classes capture their debug settings when first
 * initialised, so changing the property alone is not enough — confirm against the targeted JDKs.
 */
object FixInternalDebugLogging {

  private val logger = org.slf4j.LoggerFactory.getLogger("play.api.libs.ws.ssl.debug.FixInternalDebugLogging")

  /**
   * Privileged action that installs `newOptions` as the internal JSSE debug configuration.
   *
   * @param newOptions the `javax.net.debug` option string; `null` or empty disables debugging.
   */
  class MonkeyPatchInternalSslDebugAction(val newOptions: String) extends FixLoggingAction {

    val logger: org.slf4j.Logger = org.slf4j.LoggerFactory.getLogger("play.api.libs.ws.ssl.debug.FixInternalDebugLogging.MonkeyPatchInternalSslDebugAction")

    /** Class-file resource used as the starting point for `findClasses`, chosen by `foldRuntime`. */
    val initialResource: String = foldRuntime(
      older = "/javax/net/ssl/SSLContext.class",
      newer = "/sun/security/ssl/Debug.class"
    )

    /** Fully-qualified name of the internal JSSE `Debug` class, chosen by `foldRuntime`. */
    val debugClassName: String = foldRuntime(
      older = "com.sun.net.ssl.internal.ssl.Debug",
      newer = "sun.security.ssl.Debug"
    )

    /** Accepts classes from either the older or the newer internal JSSE package. */
    def isValidClass(className: String): Boolean =
      className.startsWith("com.sun.net.ssl.internal.ssl") || className.startsWith("sun.security.ssl")

    /** True when `newOptions` actually requests debug output (non-null and non-empty). */
    def isUsingDebug: Boolean = (newOptions != null) && !newOptions.isEmpty

    /**
     * Sets `javax.net.debug`, then swaps the `Debug` instance held by every matching field:
     * a fresh `Debug` when debugging is requested, `null` otherwise.
     *
     * @throws IllegalStateException if no field of the debug type was found to patch.
     *         Reflection failures (e.g. missing debug class or `args` field) propagate as thrown.
     */
    def run(): Unit = {
      // System.setProperty rejects null values; a null option string means "no options",
      // exactly as isUsingDebug already treats null and "" alike.
      val options = if (newOptions == null) "" else newOptions
      System.setProperty("javax.net.debug", options)

      val debugType: Class[_] = Thread.currentThread().getContextClassLoader.loadClass(debugClassName)
      logger.debug(s"run: debugType = $debugType")

      // null is the value monkeyPatchField receives when debugging is switched off,
      // so only build a Debug instance when it will actually be installed.
      val debugValue: AnyRef =
        if (isUsingDebug) debugType.getDeclaredConstructor().newInstance().asInstanceOf[AnyRef]
        else null

      val fieldsToPatch = for {
        debugClass <- findClasses.toList
        debugField <- debugClass.getDeclaredFields.toList
        if isValidField(debugField, debugType)
      } yield debugField

      // Fail loudly if nothing matched, e.g. because the JSSE class layout changed.
      if (fieldsToPatch.isEmpty) {
        throw new IllegalStateException("No debug classes found!")
      }

      fieldsToPatch.foreach { debugField =>
        logger.debug(s"run: patching ${debugField.getDeclaringClass} with $debugValue")
        monkeyPatchField(debugField, debugValue)
      }

      // Also replace Debug's own `args` field with the new option string.
      // NOTE(review): presumably this covers debug users not reachable through the fields above — confirm.
      val argsField = debugType.getDeclaredField("args")
      monkeyPatchField(argsField, options)
    }
  }

  /**
   * Applies `newOptions` as the internal JSSE debug configuration, running the patch inside
   * `AccessController.doPrivileged`.
   *
   * @param newOptions the `javax.net.debug` option string; `null` or empty disables debugging.
   * @throws IllegalStateException wrapping any non-fatal failure raised while patching.
   */
  def apply(newOptions: String): Unit = {
    logger.trace(s"apply: newOptions = ${newOptions}")
    try {
      val action = new MonkeyPatchInternalSslDebugAction(newOptions)
      AccessController.doPrivileged(action)
    } catch {
      case NonFatal(e) =>
        throw new IllegalStateException("InternalDebug configuration error", e)
    }
  }
}