diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt index 4f480d4a334d..edb15a56919f 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt @@ -200,14 +200,15 @@ open class OkHttpClient internal constructor( @get:JvmName("socketFactory") val socketFactory: SocketFactory = builder.socketFactory - private val sslSocketFactoryOrNull: SSLSocketFactory? + private val sslInitializedFields: Lazy? @get:JvmName("sslSocketFactory") val sslSocketFactory: SSLSocketFactory - get() = sslSocketFactoryOrNull ?: throw IllegalStateException("CLEARTEXT-only client") + get() = sslInitializedFields?.value?.sslSocketFactory ?: throw IllegalStateException("CLEARTEXT-only client") @get:JvmName("x509TrustManager") val x509TrustManager: X509TrustManager? + get() = sslInitializedFields?.value?.x509TrustManager @get:JvmName("connectionSpecs") val connectionSpecs: List = @@ -219,11 +220,18 @@ open class OkHttpClient internal constructor( @get:JvmName("hostnameVerifier") val hostnameVerifier: HostnameVerifier = builder.hostnameVerifier + private lateinit var _certificatePinner: CertificatePinner + @get:JvmName("certificatePinner") - val certificatePinner: CertificatePinner + val certificatePinner: CertificatePinner by lazy { + certificateChainCleaner?.let { + _certificatePinner.withCertificateChainCleaner(it) + } ?: _certificatePinner + } @get:JvmName("certificateChainCleaner") val certificateChainCleaner: CertificateChainCleaner? + get() = sslInitializedFields?.value?.certificateChainCleaner /** * Default call timeout (in milliseconds). By default there is no timeout for complete calls, but @@ -284,24 +292,24 @@ open class OkHttpClient internal constructor( init { if (connectionSpecs.none { it.isTls }) { - this.sslSocketFactoryOrNull = null - this.certificateChainCleaner = null - this.x509TrustManager = null - this.certificatePinner = CertificatePinner.DEFAULT - } else if (builder.sslSocketFactoryOrNull != null) { - this.sslSocketFactoryOrNull = builder.sslSocketFactoryOrNull - this.certificateChainCleaner = builder.certificateChainCleaner!! - this.x509TrustManager = builder.x509TrustManagerOrNull!! - this.certificatePinner = - builder.certificatePinner - .withCertificateChainCleaner(certificateChainCleaner!!) + this.sslInitializedFields = null + this._certificatePinner = CertificatePinner.DEFAULT + } else if (builder.sslInitializedFields != null) { + this.sslInitializedFields = builder.sslInitializedFields + this._certificatePinner = builder.certificatePinner } else { - this.x509TrustManager = Platform.get().platformTrustManager() - this.sslSocketFactoryOrNull = Platform.get().newSslSocketFactory(x509TrustManager!!) - this.certificateChainCleaner = CertificateChainCleaner.get(x509TrustManager!!) - this.certificatePinner = - builder.certificatePinner - .withCertificateChainCleaner(certificateChainCleaner!!) + this.sslInitializedFields = + lazy { + val platform = Platform.get() + val trustManager = platform.platformTrustManager() + val certificateChainCleaner = CertificateChainCleaner.get(trustManager) + SSLInitializedFields( + trustManager, + platform.newSslSocketFactory(trustManager), + certificateChainCleaner, + ) + } + this._certificatePinner = builder.certificatePinner } verifyClientState() @@ -337,6 +345,12 @@ open class OkHttpClient internal constructor( ) } + internal data class SSLInitializedFields( + val x509TrustManager: X509TrustManager, + val sslSocketFactory: SSLSocketFactory, + val certificateChainCleaner: CertificateChainCleaner, + ) + private fun verifyClientState() { check(null !in (interceptors as List)) { "Null interceptor: $interceptors" @@ -346,14 +360,10 @@ open class OkHttpClient internal constructor( } if (connectionSpecs.none { it.isTls }) { - check(sslSocketFactoryOrNull == null) - check(certificateChainCleaner == null) - check(x509TrustManager == null) + check(sslInitializedFields == null) { "ssl initialized for plaintext client" } check(certificatePinner == CertificatePinner.DEFAULT) } else { - checkNotNull(sslSocketFactoryOrNull) { "sslSocketFactory == null" } - checkNotNull(certificateChainCleaner) { "certificateChainCleaner == null" } - checkNotNull(x509TrustManager) { "x509TrustManager == null" } + checkNotNull(sslInitializedFields) { "ssl not initialized for client" } } } @@ -609,13 +619,11 @@ open class OkHttpClient internal constructor( internal var proxySelector: ProxySelector? = null internal var proxyAuthenticator: Authenticator = Authenticator.NONE internal var socketFactory: SocketFactory = SocketFactory.getDefault() - internal var sslSocketFactoryOrNull: SSLSocketFactory? = null - internal var x509TrustManagerOrNull: X509TrustManager? = null + internal var sslInitializedFields: Lazy? = null internal var connectionSpecs: List = DEFAULT_CONNECTION_SPECS internal var protocols: List = DEFAULT_PROTOCOLS internal var hostnameVerifier: HostnameVerifier = OkHostnameVerifier internal var certificatePinner: CertificatePinner = CertificatePinner.DEFAULT - internal var certificateChainCleaner: CertificateChainCleaner? = null internal var callTimeout = 0 internal var connectTimeout = 10_000 internal var readTimeout = 10_000 @@ -644,13 +652,11 @@ open class OkHttpClient internal constructor( this.proxySelector = okHttpClient.proxySelector this.proxyAuthenticator = okHttpClient.proxyAuthenticator this.socketFactory = okHttpClient.socketFactory - this.sslSocketFactoryOrNull = okHttpClient.sslSocketFactoryOrNull - this.x509TrustManagerOrNull = okHttpClient.x509TrustManager + this.sslInitializedFields = okHttpClient.sslInitializedFields this.connectionSpecs = okHttpClient.connectionSpecs this.protocols = okHttpClient.protocols this.hostnameVerifier = okHttpClient.hostnameVerifier - this.certificatePinner = okHttpClient.certificatePinner - this.certificateChainCleaner = okHttpClient.certificateChainCleaner + this.certificatePinner = okHttpClient._certificatePinner this.callTimeout = okHttpClient.callTimeoutMillis this.connectTimeout = okHttpClient.connectTimeoutMillis this.readTimeout = okHttpClient.readTimeoutMillis @@ -913,18 +919,25 @@ open class OkHttpClient internal constructor( ) fun sslSocketFactory(sslSocketFactory: SSLSocketFactory) = apply { - if (sslSocketFactory != this.sslSocketFactoryOrNull) { + if (sslSocketFactory != sslInitializedFields?.value?.sslSocketFactory) { this.routeDatabase = null } - this.sslSocketFactoryOrNull = sslSocketFactory - this.x509TrustManagerOrNull = - Platform.get().trustManager(sslSocketFactory) ?: throw IllegalStateException( - "Unable to extract the trust manager on ${Platform.get()}, " + + val platform = Platform.get() + val trustManager = + platform.trustManager(sslSocketFactory) ?: throw IllegalStateException( + "Unable to extract the trust manager on $platform, " + "sslSocketFactory is ${sslSocketFactory.javaClass}", ) - this.certificateChainCleaner = - Platform.get().buildCertificateChainCleaner(x509TrustManagerOrNull!!) + // Expensive copy assuming SSL already initialized + sslInitializedFields = + lazyOf( + SSLInitializedFields( + trustManager, + sslSocketFactory = sslSocketFactory, + certificateChainCleaner = platform.buildCertificateChainCleaner(trustManager), + ), + ) } /** @@ -976,13 +989,21 @@ open class OkHttpClient internal constructor( sslSocketFactory: SSLSocketFactory, trustManager: X509TrustManager, ) = apply { - if (sslSocketFactory != this.sslSocketFactoryOrNull || trustManager != this.x509TrustManagerOrNull) { + val existingSsl = sslInitializedFields?.value + + if (sslSocketFactory != existingSsl?.sslSocketFactory || trustManager != existingSsl?.x509TrustManager) { this.routeDatabase = null } - this.sslSocketFactoryOrNull = sslSocketFactory - this.certificateChainCleaner = CertificateChainCleaner.get(trustManager) - this.x509TrustManagerOrNull = trustManager + // Expensive copy assuming SSL already initialized + sslInitializedFields = + lazyOf( + SSLInitializedFields( + trustManager, + sslSocketFactory = sslSocketFactory, + certificateChainCleaner = CertificateChainCleaner.get(trustManager), + ), + ) } fun connectionSpecs(connectionSpecs: List) = @@ -1078,11 +1099,13 @@ open class OkHttpClient internal constructor( */ fun certificatePinner(certificatePinner: CertificatePinner) = apply { - if (certificatePinner != this.certificatePinner) { + val cleanCertificatePinner = CertificatePinner(certificatePinner.pins) + + if (cleanCertificatePinner != this.certificatePinner) { this.routeDatabase = null } - this.certificatePinner = certificatePinner + this.certificatePinner = cleanCertificatePinner } /** diff --git a/okhttp/src/jvmTest/kotlin/okhttp3/OkHttpClientConstructionTest.kt b/okhttp/src/jvmTest/kotlin/okhttp3/OkHttpClientConstructionTest.kt new file mode 100644 index 000000000000..cf73e13935d8 --- /dev/null +++ b/okhttp/src/jvmTest/kotlin/okhttp3/OkHttpClientConstructionTest.kt @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3 + +import java.security.NoSuchAlgorithmException +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLSocketFactory +import javax.net.ssl.X509TrustManager +import okhttp3.HttpUrl.Companion.toHttpUrl +import okhttp3.internal.platform.Platform +import okhttp3.testing.PlatformRule +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import org.junit.jupiter.api.extension.RegisterExtension + +class OkHttpClientConstructionTest { + @RegisterExtension + var platform = PlatformRule() + + @Test fun constructionDoesntTriggerPlatformOrSSL() { + Platform.resetForTests(platform = ExplosivePlatform { TODO("Avoid call") }) + + val client = OkHttpClient() + + assertNotNull(client.toString()) + + client.newCall(Request("https://example.org/robots.txt".toHttpUrl())) + } + + @Test fun cloneDoesntTriggerPlatformOrSSL() { + Platform.resetForTests(platform = ExplosivePlatform { TODO("Avoid call") }) + + val client = OkHttpClient() + + val client2 = client.newBuilder().build() + assertNotNull(client2.toString()) + } + + @Test fun triggersOnExecute() { + Platform.resetForTests(platform = ExplosivePlatform { throw NoSuchAlgorithmException() }) + + val client = OkHttpClient() + + val call = client.newCall(Request("https://example.org/robots.txt".toHttpUrl())) + + assertThrows { + call.execute() + } + } + + class ExplosivePlatform(private val explode: () -> Nothing) : Platform() { + override fun newSSLContext(): SSLContext { + explode() + } + + override fun newSslSocketFactory(trustManager: X509TrustManager): SSLSocketFactory { + explode() + } + + override fun platformTrustManager(): X509TrustManager { + explode() + } + } +} diff --git a/okhttp/src/jvmTest/kotlin/okhttp3/OkHttpClientTest.kt b/okhttp/src/jvmTest/kotlin/okhttp3/OkHttpClientTest.kt index c29f58b69197..63109df7b4ad 100644 --- a/okhttp/src/jvmTest/kotlin/okhttp3/OkHttpClientTest.kt +++ b/okhttp/src/jvmTest/kotlin/okhttp3/OkHttpClientTest.kt @@ -363,6 +363,15 @@ class OkHttpClientTest { .routeDatabase, ) + // identical CertificatePinner + assertSame( + client.routeDatabase, + client.newBuilder() + .certificatePinner(CertificatePinner.Builder().build()) + .build() + .routeDatabase, + ) + // logically different scope of client for route db assertNotSame( client.routeDatabase, @@ -423,7 +432,11 @@ class OkHttpClientTest { assertNotSame( client.routeDatabase, client.newBuilder() - .certificatePinner(CertificatePinner.Builder().build()) + .certificatePinner( + CertificatePinner.Builder() + .add("san.com", "sha1/afwiKY3RxoMmLkuRW1l7QsPZTJPwDS2pdDROQjXw8ig=") + .build(), + ) .build() .routeDatabase, )