diff --git a/pom.xml b/pom.xml index 3c2963cde..9fe0bfdd7 100644 --- a/pom.xml +++ b/pom.xml @@ -455,6 +455,11 @@ netty-transport ${netty.version} + + org.apache.httpcomponents.client5 + httpclient5 + 5.3.1 + @@ -488,5 +493,10 @@ 3.0 test + + org.mock-server + mockserver-junit-rule-no-dependencies + 5.14.0 + diff --git a/src/main/java/com/google/firebase/internal/ApacheHttp2AsyncEntityProducer.java b/src/main/java/com/google/firebase/internal/ApacheHttp2AsyncEntityProducer.java new file mode 100644 index 000000000..9bdf208c7 --- /dev/null +++ b/src/main/java/com/google/firebase/internal/ApacheHttp2AsyncEntityProducer.java @@ -0,0 +1,150 @@ +/* + * Copyright 2024 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.internal; + +import com.google.api.client.util.StreamingContent; +import com.google.common.annotations.VisibleForTesting; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.nio.AsyncEntityProducer; +import org.apache.hc.core5.http.nio.DataStreamChannel; + +public class ApacheHttp2AsyncEntityProducer implements AsyncEntityProducer { + private ByteBuffer bytebuf; + private ByteArrayOutputStream baos; + private final StreamingContent content; + private final ContentType contentType; + private final long contentLength; + private final String contentEncoding; + private final CompletableFuture writeFuture; + private final AtomicReference exception; + + public ApacheHttp2AsyncEntityProducer(StreamingContent content, ContentType contentType, + String contentEncoding, long contentLength, CompletableFuture writeFuture) { + this.content = content; + this.contentType = contentType; + this.contentEncoding = contentEncoding; + this.contentLength = contentLength; + this.writeFuture = writeFuture; + this.bytebuf = null; + + this.baos = new ByteArrayOutputStream((int) (contentLength < 0 ? 0 : contentLength)); + this.exception = new AtomicReference<>(); + } + + public ApacheHttp2AsyncEntityProducer(ApacheHttp2Request request, + CompletableFuture writeFuture) { + this( + request.getStreamingContent(), + ContentType.parse(request.getContentType()), + request.getContentEncoding(), + request.getContentLength(), + writeFuture); + } + + @Override + public boolean isRepeatable() { + return true; + } + + @Override + public String getContentType() { + return contentType != null ? contentType.toString() : null; + } + + @Override + public long getContentLength() { + return contentLength; + } + + @Override + public int available() { + return Integer.MAX_VALUE; + } + + @Override + public String getContentEncoding() { + return contentEncoding; + } + + @Override + public boolean isChunked() { + return contentLength == -1; + } + + @Override + public Set getTrailerNames() { + return null; + } + + @Override + public void produce(DataStreamChannel channel) throws IOException { + if (bytebuf == null) { + if (content != null) { + try { + content.writeTo(baos); + } catch (IOException e) { + failed(e); + throw e; + } + } + + this.bytebuf = ByteBuffer.wrap(baos.toByteArray()); + } + + if (bytebuf.hasRemaining()) { + channel.write(bytebuf); + } + + if (!bytebuf.hasRemaining()) { + channel.endStream(); + writeFuture.complete(null); + releaseResources(); + } + } + + @Override + public void failed(Exception cause) { + if (exception.compareAndSet(null, cause)) { + releaseResources(); + writeFuture.completeExceptionally(cause); + } + } + + public final Exception getException() { + return exception.get(); + } + + @Override + public void releaseResources() { + if (bytebuf != null) { + bytebuf.clear(); + } + } + + @VisibleForTesting + ByteBuffer getBytebuf() { + return bytebuf; + } +} diff --git a/src/main/java/com/google/firebase/internal/ApacheHttp2Request.java b/src/main/java/com/google/firebase/internal/ApacheHttp2Request.java new file mode 100644 index 000000000..56c2c9034 --- /dev/null +++ b/src/main/java/com/google/firebase/internal/ApacheHttp2Request.java @@ -0,0 +1,147 @@ +/* + * Copyright 2024 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.internal; + +import com.google.api.client.http.LowLevelHttpRequest; +import com.google.api.client.http.LowLevelHttpResponse; +import com.google.common.annotations.VisibleForTesting; + +import java.io.IOException; +import java.net.SocketTimeoutException; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import org.apache.hc.client5.http.ConnectTimeoutException; +import org.apache.hc.client5.http.HttpHostConnectException; +import org.apache.hc.client5.http.async.methods.SimpleHttpRequest; +import org.apache.hc.client5.http.async.methods.SimpleHttpResponse; +import org.apache.hc.client5.http.async.methods.SimpleRequestBuilder; +import org.apache.hc.client5.http.async.methods.SimpleResponseConsumer; +import org.apache.hc.client5.http.config.RequestConfig; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.core5.concurrent.FutureCallback; +import org.apache.hc.core5.http.nio.support.BasicRequestProducer; +import org.apache.hc.core5.http2.H2StreamResetException; +import org.apache.hc.core5.util.Timeout; + +final class ApacheHttp2Request extends LowLevelHttpRequest { + private final CloseableHttpAsyncClient httpAsyncClient; + private final SimpleRequestBuilder requestBuilder; + private SimpleHttpRequest request; + private final RequestConfig.Builder requestConfig; + private int writeTimeout; + private ApacheHttp2AsyncEntityProducer entityProducer; + + ApacheHttp2Request( + CloseableHttpAsyncClient httpAsyncClient, SimpleRequestBuilder requestBuilder) { + this.httpAsyncClient = httpAsyncClient; + this.requestBuilder = requestBuilder; + this.writeTimeout = 0; + + this.requestConfig = RequestConfig.custom() + .setRedirectsEnabled(false); + } + + @Override + public void addHeader(String name, String value) { + requestBuilder.addHeader(name, value); + } + + @Override + public void setTimeout(int connectionTimeout, int readTimeout) throws IOException { + requestConfig + .setConnectTimeout(Timeout.ofMilliseconds(connectionTimeout)) + .setResponseTimeout(Timeout.ofMilliseconds(readTimeout)); + } + + @Override + public void setWriteTimeout(int writeTimeout) throws IOException { + this.writeTimeout = writeTimeout; + } + + @Override + public LowLevelHttpResponse execute() throws IOException { + // Set request configs + requestBuilder.setRequestConfig(requestConfig.build()); + + // Build request + request = requestBuilder.build(); + + // Make Producer + CompletableFuture writeFuture = new CompletableFuture<>(); + entityProducer = new ApacheHttp2AsyncEntityProducer(this, writeFuture); + + // Execute + final Future responseFuture = httpAsyncClient.execute( + new BasicRequestProducer(request, entityProducer), + SimpleResponseConsumer.create(), + new FutureCallback() { + @Override + public void completed(final SimpleHttpResponse response) { + } + + @Override + public void failed(final Exception exception) { + } + + @Override + public void cancelled() { + } + }); + + // Wait for write + try { + if (writeTimeout != 0) { + writeFuture.get(writeTimeout, TimeUnit.MILLISECONDS); + } + } catch (TimeoutException e) { + throw new IOException("Write Timeout", e.getCause()); + } catch (Exception e) { + throw new IOException("Exception in write", e.getCause()); + } + + // Wait for response + try { + final SimpleHttpResponse response = responseFuture.get(); + return new ApacheHttp2Response(response); + } catch (ExecutionException e) { + if (e.getCause() instanceof ConnectTimeoutException + || e.getCause() instanceof SocketTimeoutException) { + throw new IOException("Connection Timeout", e.getCause()); + } else if (e.getCause() instanceof HttpHostConnectException) { + throw new IOException("Connection exception in request", e.getCause()); + } else if (e.getCause() instanceof H2StreamResetException) { + throw new IOException("Stream exception in request", e.getCause()); + } else { + throw new IOException("Unknown exception in request", e); + } + } catch (InterruptedException e) { + throw new IOException("Request Interrupted", e); + } catch (CancellationException e) { + throw new IOException("Request Cancelled", e); + } + } + + @VisibleForTesting + ApacheHttp2AsyncEntityProducer getEntityProducer() { + return entityProducer; + } +} diff --git a/src/main/java/com/google/firebase/internal/ApacheHttp2Response.java b/src/main/java/com/google/firebase/internal/ApacheHttp2Response.java new file mode 100644 index 000000000..4c05b0e03 --- /dev/null +++ b/src/main/java/com/google/firebase/internal/ApacheHttp2Response.java @@ -0,0 +1,100 @@ +/* + * Copyright 2024 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.internal; + +import com.google.api.client.http.LowLevelHttpResponse; +import com.google.common.annotations.VisibleForTesting; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; + +import org.apache.hc.client5.http.async.methods.SimpleHttpResponse; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.Header; + +public class ApacheHttp2Response extends LowLevelHttpResponse { + private final SimpleHttpResponse response; + private final Header[] allHeaders; + + ApacheHttp2Response(SimpleHttpResponse response) { + this.response = response; + allHeaders = response.getHeaders(); + } + + @Override + public int getStatusCode() { + return response.getCode(); + } + + @Override + public InputStream getContent() throws IOException { + return new ByteArrayInputStream(response.getBodyBytes()); + } + + @Override + public String getContentEncoding() { + Header contentEncodingHeader = response.getFirstHeader("Content-Encoding"); + return contentEncodingHeader == null ? null : contentEncodingHeader.getValue(); + } + + @Override + public long getContentLength() { + String bodyText = response.getBodyText(); + return bodyText == null ? 0 : bodyText.length(); + } + + @Override + public String getContentType() { + ContentType contentType = response.getContentType(); + return contentType == null ? null : contentType.toString(); + } + + @Override + public String getReasonPhrase() { + return response.getReasonPhrase(); + } + + @Override + public String getStatusLine() { + return response.toString(); + } + + public String getHeaderValue(String name) { + return response.getLastHeader(name).getValue(); + } + + @Override + public String getHeaderValue(int index) { + return allHeaders[index].getValue(); + } + + @Override + public int getHeaderCount() { + return allHeaders.length; + } + + @Override + public String getHeaderName(int index) { + return allHeaders[index].getName(); + } + + @VisibleForTesting + public SimpleHttpResponse getResponse() { + return response; + } +} diff --git a/src/main/java/com/google/firebase/internal/ApacheHttp2Transport.java b/src/main/java/com/google/firebase/internal/ApacheHttp2Transport.java new file mode 100644 index 000000000..9c0e413fb --- /dev/null +++ b/src/main/java/com/google/firebase/internal/ApacheHttp2Transport.java @@ -0,0 +1,125 @@ +/* + * Copyright 2024 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.internal; + +import com.google.api.client.http.HttpTransport; + +import java.io.IOException; +import java.net.ProxySelector; +import java.util.concurrent.TimeUnit; + +import org.apache.hc.client5.http.async.HttpAsyncClient; +import org.apache.hc.client5.http.async.methods.SimpleRequestBuilder; +import org.apache.hc.client5.http.config.ConnectionConfig; +import org.apache.hc.client5.http.config.TlsConfig; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.client5.http.impl.async.HttpAsyncClientBuilder; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManager; +import org.apache.hc.client5.http.impl.nio.PoolingAsyncClientConnectionManagerBuilder; +import org.apache.hc.client5.http.impl.routing.SystemDefaultRoutePlanner; +import org.apache.hc.client5.http.ssl.ClientTlsStrategyBuilder; +import org.apache.hc.core5.http.config.Http1Config; +import org.apache.hc.core5.http2.HttpVersionPolicy; +import org.apache.hc.core5.http2.config.H2Config; +import org.apache.hc.core5.io.CloseMode; +import org.apache.hc.core5.ssl.SSLContexts; + +/** + * HTTP/2 enabled async transport based on the Apache HTTP Client library + */ +public final class ApacheHttp2Transport extends HttpTransport { + + private final CloseableHttpAsyncClient httpAsyncClient; + private final boolean isMtls; + + public ApacheHttp2Transport() { + this(newDefaultHttpAsyncClient(), false); + } + + public ApacheHttp2Transport(CloseableHttpAsyncClient httpAsyncClient) { + this(httpAsyncClient, false); + } + + public ApacheHttp2Transport(CloseableHttpAsyncClient httpAsyncClient, boolean isMtls) { + this.httpAsyncClient = httpAsyncClient; + this.isMtls = isMtls; + + httpAsyncClient.start(); + } + + public static CloseableHttpAsyncClient newDefaultHttpAsyncClient() { + return defaultHttpAsyncClientBuilder().build(); + } + + public static HttpAsyncClientBuilder defaultHttpAsyncClientBuilder() { + PoolingAsyncClientConnectionManager connectionManager = + PoolingAsyncClientConnectionManagerBuilder.create() + // Set Max total connections to match google api client limits + // https://github.com/googleapis/google-http-java-client/blob/f9d4e15bd3c784b1fd3b0f3468000a91c6f79715/google-http-client-apache-v5/src/main/java/com/google/api/client/http/apache/v5/Apache5HttpTransport.java#L151 + .setMaxConnTotal(200) + // Set max connections per route to match the concurrent stream limit of the FCM backend. + .setMaxConnPerRoute(100) + .setDefaultConnectionConfig( + ConnectionConfig.custom().setTimeToLive(-1, TimeUnit.MILLISECONDS).build()) + .setDefaultTlsConfig( + TlsConfig.custom().setVersionPolicy(HttpVersionPolicy.NEGOTIATE).build()) + .setTlsStrategy(ClientTlsStrategyBuilder.create() + .setSslContext(SSLContexts.createSystemDefault()) + .build()) + .build(); + + return HttpAsyncClientBuilder.create() + // Set maxConcurrentStreams to 100 to match the concurrent stream limit of the FCM backend. + .setH2Config(H2Config.custom().setMaxConcurrentStreams(100).build()) + .setHttp1Config(Http1Config.DEFAULT) + .setConnectionManager(connectionManager) + .setRoutePlanner(new SystemDefaultRoutePlanner(ProxySelector.getDefault())) + .disableRedirectHandling() + .disableAutomaticRetries(); + } + + @Override + public boolean supportsMethod(String method) { + return true; + } + + @Override + protected ApacheHttp2Request buildRequest(String method, String url) { + SimpleRequestBuilder requestBuilder = SimpleRequestBuilder.create(method).setUri(url); + return new ApacheHttp2Request(httpAsyncClient, requestBuilder); + } + + /** + * Gracefully shuts down the connection manager and releases allocated resources. This closes all + * connections, whether they are currently used or not. + */ + @Override + public void shutdown() throws IOException { + httpAsyncClient.close(CloseMode.GRACEFUL); + } + + /** Returns the Apache HTTP client. */ + public HttpAsyncClient getHttpClient() { + return httpAsyncClient; + } + + /** Returns if the underlying HTTP client is mTLS. */ + @Override + public boolean isMtls() { + return isMtls; + } +} diff --git a/src/main/java/com/google/firebase/internal/ApiClientUtils.java b/src/main/java/com/google/firebase/internal/ApiClientUtils.java index e586bd11f..339726105 100644 --- a/src/main/java/com/google/firebase/internal/ApiClientUtils.java +++ b/src/main/java/com/google/firebase/internal/ApiClientUtils.java @@ -88,6 +88,6 @@ public static JsonFactory getDefaultJsonFactory() { } public static HttpTransport getDefaultTransport() { - return Utils.getDefaultTransport(); + return new ApacheHttp2Transport(); } } diff --git a/src/main/java/com/google/firebase/internal/MockApacheHttp2AsyncClient.java b/src/main/java/com/google/firebase/internal/MockApacheHttp2AsyncClient.java new file mode 100644 index 000000000..c859ec485 --- /dev/null +++ b/src/main/java/com/google/firebase/internal/MockApacheHttp2AsyncClient.java @@ -0,0 +1,73 @@ +/* + * Copyright 2024 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.internal; + +import java.io.IOException; +import java.util.concurrent.Future; + +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.core5.concurrent.FutureCallback; +import org.apache.hc.core5.function.Supplier; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.nio.AsyncPushConsumer; +import org.apache.hc.core5.http.nio.AsyncRequestProducer; +import org.apache.hc.core5.http.nio.AsyncResponseConsumer; +import org.apache.hc.core5.http.nio.HandlerFactory; +import org.apache.hc.core5.http.protocol.HttpContext; +import org.apache.hc.core5.io.CloseMode; +import org.apache.hc.core5.reactor.IOReactorStatus; +import org.apache.hc.core5.util.TimeValue; + +public class MockApacheHttp2AsyncClient extends CloseableHttpAsyncClient { + + @Override + public void close(CloseMode closeMode) { + } + + @Override + public void close() throws IOException { + } + + @Override + public void start() { + } + + @Override + public IOReactorStatus getStatus() { + return null; + } + + @Override + public void awaitShutdown(TimeValue waitTime) throws InterruptedException { + } + + @Override + public void initiateShutdown() { + } + + @Override + protected Future doExecute(HttpHost target, AsyncRequestProducer requestProducer, + AsyncResponseConsumer responseConsumer, + HandlerFactory pushHandlerFactory, + HttpContext context, FutureCallback callback) { + return null; + } + + @Override + public void register(String hostname, String uriPattern, Supplier supplier) { + } +} diff --git a/src/test/java/com/google/firebase/FirebaseOptionsTest.java b/src/test/java/com/google/firebase/FirebaseOptionsTest.java index 0a4e9e81a..05613e98a 100644 --- a/src/test/java/com/google/firebase/FirebaseOptionsTest.java +++ b/src/test/java/com/google/firebase/FirebaseOptionsTest.java @@ -24,12 +24,12 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; -import com.google.api.client.http.javanet.NetHttpTransport; import com.google.api.client.json.gson.GsonFactory; import com.google.auth.oauth2.AccessToken; import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.ServiceAccountCredentials; import com.google.cloud.firestore.FirestoreOptions; +import com.google.firebase.internal.ApacheHttp2Transport; import com.google.firebase.testing.ServiceAccount; import com.google.firebase.testing.TestUtils; import java.io.IOException; @@ -74,7 +74,7 @@ protected ThreadFactory getThreadFactory() { @Test public void createOptionsWithAllValuesSet() throws IOException { GsonFactory jsonFactory = new GsonFactory(); - NetHttpTransport httpTransport = new NetHttpTransport(); + ApacheHttp2Transport httpTransport = new ApacheHttp2Transport(); FirestoreOptions firestoreOptions = FirestoreOptions.newBuilder().build(); FirebaseOptions firebaseOptions = FirebaseOptions.builder() diff --git a/src/test/java/com/google/firebase/internal/ApacheHttp2TransportIT.java b/src/test/java/com/google/firebase/internal/ApacheHttp2TransportIT.java new file mode 100644 index 000000000..c8fbe5533 --- /dev/null +++ b/src/test/java/com/google/firebase/internal/ApacheHttp2TransportIT.java @@ -0,0 +1,409 @@ +/* + * Copyright 2024 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.internal; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockserver.model.Header.header; +import static org.mockserver.model.HttpForward.forward; +import static org.mockserver.model.HttpRequest.request; +import static org.mockserver.model.HttpResponse.response; + +import com.google.api.client.http.GenericUrl; +import com.google.api.client.http.HttpRequestFactory; +import com.google.api.client.http.HttpResponseException; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.http.LowLevelHttpResponse; +import com.google.api.client.json.JsonFactory; +import com.google.api.client.util.GenericData; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.common.collect.ImmutableMap; +import com.google.firebase.ErrorCode; +import com.google.firebase.FirebaseApp; +import com.google.firebase.FirebaseException; +import com.google.firebase.FirebaseOptions; +import com.google.firebase.IncomingHttpResponse; +import com.google.firebase.auth.MockGoogleCredentials; +import java.io.IOException; +import java.net.ServerSocket; +import java.net.Socket; +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.client5.http.impl.async.HttpAsyncClients; +import org.apache.hc.core5.http.EntityDetails; +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpException; +import org.apache.hc.core5.http.HttpRequest; +import org.apache.hc.core5.http.HttpRequestInterceptor; +import org.apache.hc.core5.http.protocol.HttpContext; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.mockserver.integration.ClientAndServer; +import org.mockserver.model.HttpForward.Scheme; +import org.mockserver.socket.PortFactory; + +public class ApacheHttp2TransportIT { + private static FirebaseApp app; + private static final GoogleCredentials MOCK_CREDENTIALS = new MockGoogleCredentials("test_token"); + private static final ImmutableMap payload = + ImmutableMap.of("foo", "bar"); + + // Sets a 5 second delay before server response to simulate a slow network that + // results in a read timeout. + private static final String DELAY_URL = "https://nghttp2.org/httpbin/delay/5"; + private static final String GET_URL = "https://nghttp2.org/httpbin/get"; + private static final String POST_URL = "https://nghttp2.org/httpbin/post"; + + private static ServerSocket serverSocket; + private static Socket fillerSocket; + private static int port; + + private static ClientAndServer mockProxy; + private static ClientAndServer mockServer; + + @BeforeClass + public static void setUpClass() throws IOException { + // Start server socket with a backlog queue of 1 and a automatically assigned + // port + serverSocket = new ServerSocket(0, 1); + port = serverSocket.getLocalPort(); + // Fill the backlog queue to force socket to ignore future connections + fillerSocket = new Socket(); + fillerSocket.connect(serverSocket.getLocalSocketAddress()); + } + + @AfterClass + public static void cleanUpClass() throws IOException { + if (serverSocket != null && !serverSocket.isClosed()) { + serverSocket.close(); + } + if (fillerSocket != null && !fillerSocket.isClosed()) { + fillerSocket.close(); + } + } + + @After + public void cleanup() { + if (app != null) { + app.delete(); + } + + if (mockProxy != null && mockProxy.isRunning()) { + mockProxy.close(); + } + + if (mockServer != null && mockServer.isRunning()) { + mockServer.close(); + } + + System.clearProperty("http.proxyHost"); + System.clearProperty("http.proxyPort"); + System.clearProperty("https.proxyHost"); + System.clearProperty("https.proxyPort"); + } + + @Test(timeout = 10_000L) + public void testUnauthorizedGetRequest() throws FirebaseException { + ErrorHandlingHttpClient httpClient = getHttpClient(false); + HttpRequestInfo request = HttpRequestInfo.buildGetRequest(GET_URL); + IncomingHttpResponse response = httpClient.send(request); + assertEquals(200, response.getStatusCode()); + } + + @Test(timeout = 10_000L) + public void testUnauthorizedPostRequest() throws FirebaseException { + ErrorHandlingHttpClient httpClient = getHttpClient(false); + HttpRequestInfo request = HttpRequestInfo.buildJsonPostRequest(POST_URL, payload); + GenericData body = httpClient.sendAndParse(request, GenericData.class); + assertEquals("{\"foo\":\"bar\"}", body.get("data")); + } + + @Test(timeout = 10_000L) + public void testConnectTimeoutAuthorizedGet() throws FirebaseException { + app = FirebaseApp.initializeApp(FirebaseOptions.builder() + .setCredentials(MOCK_CREDENTIALS) + .setConnectTimeout(100) + .build(), "test-app"); + ErrorHandlingHttpClient httpClient = getHttpClient(true, app); + HttpRequestInfo request = HttpRequestInfo.buildGetRequest("https://localhost:" + port); + + try { + httpClient.send(request); + fail("No exception thrown for HTTP error response"); + } catch (FirebaseException e) { + assertEquals(ErrorCode.UNKNOWN, e.getErrorCode()); + assertEquals("IO error: Connection Timeout", e.getMessage()); + assertNull(e.getHttpResponse()); + } + } + + @Test(timeout = 10_000L) + public void testConnectTimeoutAuthorizedPost() throws FirebaseException { + app = FirebaseApp.initializeApp(FirebaseOptions.builder() + .setCredentials(MOCK_CREDENTIALS) + .setConnectTimeout(100) + .build(), "test-app"); + ErrorHandlingHttpClient httpClient = getHttpClient(true, app); + HttpRequestInfo request = HttpRequestInfo.buildJsonPostRequest("https://localhost:" + port, payload); + + try { + httpClient.send(request); + fail("No exception thrown for HTTP error response"); + } catch (FirebaseException e) { + assertEquals(ErrorCode.UNKNOWN, e.getErrorCode()); + assertEquals("IO error: Connection Timeout", e.getMessage()); + assertNull(e.getHttpResponse()); + } + } + + @Test(timeout = 10_000L) + public void testReadTimeoutAuthorizedGet() throws FirebaseException { + app = FirebaseApp.initializeApp(FirebaseOptions.builder() + .setCredentials(MOCK_CREDENTIALS) + .setReadTimeout(100) + .build(), "test-app"); + ErrorHandlingHttpClient httpClient = getHttpClient(true, app); + HttpRequestInfo request = HttpRequestInfo.buildGetRequest(DELAY_URL); + + try { + httpClient.send(request); + fail("No exception thrown for HTTP error response"); + } catch (FirebaseException e) { + assertEquals(ErrorCode.UNKNOWN, e.getErrorCode()); + assertEquals("IO error: Stream exception in request", e.getMessage()); + assertNull(e.getHttpResponse()); + } + } + + @Test(timeout = 10_000L) + public void testReadTimeoutAuthorizedPost() throws FirebaseException { + app = FirebaseApp.initializeApp(FirebaseOptions.builder() + .setCredentials(MOCK_CREDENTIALS) + .setReadTimeout(100) + .build(), "test-app"); + ErrorHandlingHttpClient httpClient = getHttpClient(true, app); + HttpRequestInfo request = HttpRequestInfo.buildJsonPostRequest(DELAY_URL, payload); + + try { + httpClient.send(request); + fail("No exception thrown for HTTP error response"); + } catch (FirebaseException e) { + assertEquals(ErrorCode.UNKNOWN, e.getErrorCode()); + assertEquals("IO error: Stream exception in request", e.getMessage()); + assertNull(e.getHttpResponse()); + } + } + + @Test(timeout = 10_000L) + public void testWriteTimeoutAuthorizedGet() throws FirebaseException { + app = FirebaseApp.initializeApp(FirebaseOptions.builder() + .setCredentials(MOCK_CREDENTIALS) + .setWriteTimeout(100) + .build(), "test-app"); + ErrorHandlingHttpClient httpClient = getHttpClient(true, app); + HttpRequestInfo request = HttpRequestInfo.buildGetRequest(GET_URL); + + try { + httpClient.send(request); + fail("No exception thrown for HTTP error response"); + } catch (FirebaseException e) { + assertEquals(ErrorCode.UNKNOWN, e.getErrorCode()); + assertEquals("IO error: Write Timeout", e.getMessage()); + assertNull(e.getHttpResponse()); + } + } + + @Test(timeout = 10_000L) + public void testWriteTimeoutAuthorizedPost() throws FirebaseException { + app = FirebaseApp.initializeApp(FirebaseOptions.builder() + .setCredentials(MOCK_CREDENTIALS) + .setWriteTimeout(100) + .build(), "test-app"); + ErrorHandlingHttpClient httpClient = getHttpClient(true, app); + HttpRequestInfo request = HttpRequestInfo.buildJsonPostRequest(POST_URL, payload); + + try { + httpClient.send(request); + fail("No exception thrown for HTTP error response"); + } catch (FirebaseException e) { + assertEquals(ErrorCode.UNKNOWN, e.getErrorCode()); + assertEquals("IO error: Write Timeout", e.getMessage()); + assertNull(e.getHttpResponse()); + } + } + + @Test(timeout = 10_000L) + public void testRequestShouldNotFollowRedirects() throws IOException { + ApacheHttp2Transport transport = new ApacheHttp2Transport(); + ApacheHttp2Request request = transport.buildRequest("GET", + "https://google.com"); + LowLevelHttpResponse response = request.execute(); + + assertEquals(301, response.getStatusCode()); + assert (response instanceof ApacheHttp2Response); + assertEquals("https://www.google.com/", ((ApacheHttp2Response) response).getHeaderValue("location")); + } + + @Test(timeout = 10_000L) + public void testRequestCanSetHeaders() { + final AtomicBoolean interceptorCalled = new AtomicBoolean(false); + CloseableHttpAsyncClient client = HttpAsyncClients.custom() + .addRequestInterceptorFirst( + new HttpRequestInterceptor() { + @Override + public void process( + HttpRequest request, EntityDetails details, HttpContext context) + throws HttpException, IOException { + Header header = request.getFirstHeader("foo"); + assertNotNull("Should have found header", header); + assertEquals("bar", header.getValue()); + interceptorCalled.set(true); + throw new IOException("cancelling request"); + } + }) + .build(); + + ApacheHttp2Transport transport = new ApacheHttp2Transport(client); + ApacheHttp2Request request = transport.buildRequest("GET", "http://www.google.com"); + request.addHeader("foo", "bar"); + try { + request.execute(); + fail("should not actually make the request"); + } catch (IOException exception) { + assertEquals("Unknown exception in request", exception.getMessage()); + } + assertTrue("Expected to have called our test interceptor", interceptorCalled.get()); + } + + @Test(timeout = 10_000L) + public void testVerifyProxyIsRespected() { + try { + System.setProperty("https.proxyHost", "localhost"); + System.setProperty("https.proxyPort", "8080"); + + HttpTransport transport = new ApacheHttp2Transport(); + transport.createRequestFactory().buildGetRequest(new GenericUrl(GET_URL)).execute(); + fail("No exception thrown for HTTP error response"); + } catch (IOException e) { + assertEquals("Connection exception in request", e.getMessage()); + assertTrue(e.getCause().getMessage().contains("localhost:8080")); + } + } + + @Test(timeout = 10_000L) + public void testProxyMockHttp() throws Exception { + // Start MockServer + mockProxy = ClientAndServer.startClientAndServer(PortFactory.findFreePort()); + mockServer = ClientAndServer.startClientAndServer(PortFactory.findFreePort()); + + System.setProperty("http.proxyHost", "localhost"); + System.setProperty("http.proxyPort", mockProxy.getPort().toString()); + + // Configure proxy to receieve requests and forward them to a mock destination + // server + mockProxy + .when( + request()) + .forward( + forward() + .withHost("localhost") + .withPort(mockServer.getPort()) + .withScheme(Scheme.HTTP)); + + // Configure server to listen and respond + mockServer + .when( + request()) + .respond( + response() + .withStatusCode(200) + .withBody("Expected server response")); + + // Send a request through the proxy + app = FirebaseApp.initializeApp(FirebaseOptions.builder() + .setCredentials(MOCK_CREDENTIALS) + .setWriteTimeout(100) + .build(), "test-app"); + ErrorHandlingHttpClient httpClient = getHttpClient(true, app); + HttpRequestInfo request = HttpRequestInfo.buildGetRequest("http://www.google.com"); + IncomingHttpResponse response = httpClient.send(request); + + // Verify that the proxy received request with destination host + mockProxy.verify( + request() + .withMethod("GET") + .withPath("/") + .withHeader(header("Host", "www.google.com"))); + + // Verify the forwarded request is received by the server + mockServer.verify( + request() + .withMethod("GET") + .withPath("/")); + + // Verify response + assertEquals(200, response.getStatusCode()); + assertEquals(response.getContent(), "Expected server response"); + } + + private static ErrorHandlingHttpClient getHttpClient(boolean authorized, + FirebaseApp app) { + HttpRequestFactory requestFactory; + if (authorized) { + requestFactory = ApiClientUtils.newAuthorizedRequestFactory(app); + } else { + requestFactory = ApiClientUtils.newUnauthorizedRequestFactory(app); + } + JsonFactory jsonFactory = ApiClientUtils.getDefaultJsonFactory(); + TestHttpErrorHandler errorHandler = new TestHttpErrorHandler(); + return new ErrorHandlingHttpClient<>(requestFactory, jsonFactory, errorHandler); + } + + private static ErrorHandlingHttpClient getHttpClient(boolean authorized) { + app = FirebaseApp.initializeApp(FirebaseOptions.builder() + .setCredentials(MOCK_CREDENTIALS) + .build(), "test-app"); + return getHttpClient(authorized, app); + } + + private static class TestHttpErrorHandler implements HttpErrorHandler { + @Override + public FirebaseException handleIOException(IOException e) { + return new FirebaseException( + ErrorCode.UNKNOWN, "IO error: " + e.getMessage(), e); + } + + @Override + public FirebaseException handleHttpResponseException( + HttpResponseException e, IncomingHttpResponse response) { + return new FirebaseException( + ErrorCode.INTERNAL, "Example error message: " + e.getContent(), e, response); + } + + @Override + public FirebaseException handleParseException(IOException e, IncomingHttpResponse response) { + return new FirebaseException(ErrorCode.UNKNOWN, "Parse error", e, response); + } + } +} diff --git a/src/test/java/com/google/firebase/internal/ApacheHttp2TransportTest.java b/src/test/java/com/google/firebase/internal/ApacheHttp2TransportTest.java new file mode 100644 index 000000000..fcaf563f3 --- /dev/null +++ b/src/test/java/com/google/firebase/internal/ApacheHttp2TransportTest.java @@ -0,0 +1,379 @@ +/* + * Copyright 2024 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.internal; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import com.google.api.client.http.ByteArrayContent; +import com.google.api.client.http.GenericUrl; +import com.google.api.client.http.HttpContent; +import com.google.api.client.http.HttpMethods; +import com.google.api.client.http.HttpResponseException; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.http.InputStreamContent; +import com.google.api.client.http.LowLevelHttpResponse; +import com.google.api.client.util.ByteArrayStreamingContent; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; + +import org.apache.hc.client5.http.async.HttpAsyncClient; +import org.apache.hc.client5.http.async.methods.SimpleHttpResponse; +import org.apache.hc.client5.http.async.methods.SimpleRequestBuilder; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.client5.http.impl.async.HttpAsyncClientBuilder; +import org.apache.hc.client5.http.impl.async.HttpAsyncClients; +import org.apache.hc.core5.concurrent.FutureCallback; +import org.apache.hc.core5.http.ClassicHttpRequest; +import org.apache.hc.core5.http.ClassicHttpResponse; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.EntityDetails; +import org.apache.hc.core5.http.HttpException; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.HttpRequest; +import org.apache.hc.core5.http.HttpRequestMapper; +import org.apache.hc.core5.http.HttpResponse; +import org.apache.hc.core5.http.HttpStatus; +import org.apache.hc.core5.http.impl.bootstrap.HttpServer; +import org.apache.hc.core5.http.impl.io.HttpService; +import org.apache.hc.core5.http.io.HttpRequestHandler; +import org.apache.hc.core5.http.io.entity.ByteArrayEntity; +import org.apache.hc.core5.http.io.support.BasicHttpServerRequestHandler; +import org.apache.hc.core5.http.nio.AsyncPushConsumer; +import org.apache.hc.core5.http.nio.AsyncRequestProducer; +import org.apache.hc.core5.http.nio.AsyncResponseConsumer; +import org.apache.hc.core5.http.nio.HandlerFactory; +import org.apache.hc.core5.http.protocol.HttpContext; +import org.apache.hc.core5.http.protocol.HttpProcessor; +import org.junit.Assert; +import org.junit.Test; + +public class ApacheHttp2TransportTest { + @Test + public void testContentLengthSet() throws Exception { + SimpleRequestBuilder requestBuilder = SimpleRequestBuilder.create(HttpMethods.POST) + .setUri("http://www.google.com"); + + ApacheHttp2Request request = new ApacheHttp2Request( + new MockApacheHttp2AsyncClient() { + @SuppressWarnings("unchecked") + @Override + public Future doExecute( + final HttpHost target, + final AsyncRequestProducer requestProducer, + final AsyncResponseConsumer responseConsumer, + final HandlerFactory pushHandlerFactory, + final HttpContext context, + final FutureCallback callback) { + return (Future) CompletableFuture.completedFuture(new SimpleHttpResponse(200)); + } + }, requestBuilder); + + HttpContent content = new ByteArrayContent("text/plain", + "sample".getBytes(StandardCharsets.UTF_8)); + request.setStreamingContent(content); + request.setContentLength(content.getLength()); + request.execute(); + + assertFalse(request.getEntityProducer().isChunked()); + assertEquals(6, request.getEntityProducer().getContentLength()); + } + + @Test + public void testChunked() throws Exception { + byte[] buf = new byte[300]; + Arrays.fill(buf, (byte) ' '); + SimpleRequestBuilder requestBuilder = SimpleRequestBuilder.create(HttpMethods.POST) + .setUri("http://www.google.com"); + ApacheHttp2Request request = new ApacheHttp2Request( + new MockApacheHttp2AsyncClient() { + @SuppressWarnings("unchecked") + @Override + public Future doExecute( + final HttpHost target, + final AsyncRequestProducer requestProducer, + final AsyncResponseConsumer responseConsumer, + final HandlerFactory pushHandlerFactory, + final HttpContext context, + final FutureCallback callback) { + return (Future) CompletableFuture.completedFuture(new SimpleHttpResponse(200)); + } + }, requestBuilder); + + HttpContent content = new InputStreamContent("text/plain", new ByteArrayInputStream(buf)); + request.setStreamingContent(content); + request.execute(); + + assertTrue(request.getEntityProducer().isChunked()); + assertEquals(-1, request.getEntityProducer().getContentLength()); + } + + @Test + public void testExecute() throws Exception { + SimpleHttpResponse simpleHttpResponse = SimpleHttpResponse.create(200, new byte[] { 1, 2, 3 }); + + SimpleRequestBuilder requestBuilder = SimpleRequestBuilder.create(HttpMethods.POST) + .setUri("http://www.google.com"); + + ApacheHttp2Request request = new ApacheHttp2Request( + new MockApacheHttp2AsyncClient() { + @SuppressWarnings("unchecked") + @Override + public Future doExecute( + final HttpHost target, + final AsyncRequestProducer requestProducer, + final AsyncResponseConsumer responseConsumer, + final HandlerFactory pushHandlerFactory, + final HttpContext context, + final FutureCallback callback) { + return (Future) CompletableFuture.completedFuture(simpleHttpResponse); + } + }, requestBuilder); + LowLevelHttpResponse response = request.execute(); + assertTrue(response instanceof ApacheHttp2Response); + + // we confirm that the simple response we prepared in this test is the same as + // the content's response + assertTrue(response.getContent() instanceof ByteArrayInputStream); + assertEquals(simpleHttpResponse, ((ApacheHttp2Response) response).getResponse()); + // No need to cloase ByteArrayInputStream since close() has no effect. + } + + @Test + public void testApacheHttpTransport() { + ApacheHttp2Transport transport = new ApacheHttp2Transport(); + checkHttpTransport(transport); + assertFalse(transport.isMtls()); + } + + @Test + public void testApacheHttpTransportWithParam() { + HttpAsyncClientBuilder clientBuilder = HttpAsyncClients.custom(); + ApacheHttp2Transport transport = new ApacheHttp2Transport(clientBuilder.build(), true); + checkHttpTransport(transport); + assertTrue(transport.isMtls()); + } + + @Test + public void testNewDefaultHttpClient() { + HttpAsyncClient client = ApacheHttp2Transport.newDefaultHttpAsyncClient(); + checkHttpClient(client); + } + + @Test + public void testDefaultHttpClientBuilder() { + HttpAsyncClientBuilder clientBuilder = ApacheHttp2Transport.defaultHttpAsyncClientBuilder(); + HttpAsyncClient client = clientBuilder.build(); + checkHttpClient(client); + } + + private void checkHttpTransport(ApacheHttp2Transport transport) { + assertNotNull(transport); + HttpAsyncClient client = transport.getHttpClient(); + checkHttpClient(client); + } + + private void checkHttpClient(HttpAsyncClient client) { + assertNotNull(client); + } + + @Test + public void testRequestsWithContent() throws IOException { + // This test confirms that we can set the content on any type of request + CloseableHttpAsyncClient mockClient = new MockApacheHttp2AsyncClient() { + @SuppressWarnings("unchecked") + @Override + public Future doExecute( + final HttpHost target, + final AsyncRequestProducer requestProducer, + final AsyncResponseConsumer responseConsumer, + final HandlerFactory pushHandlerFactory, + final HttpContext context, + final FutureCallback callback) { + return (Future) CompletableFuture.completedFuture(new SimpleHttpResponse(200)); + } + }; + ApacheHttp2Transport transport = new ApacheHttp2Transport(mockClient); + + // Test GET. + execute(transport.buildRequest("GET", "http://www.test.url")); + // Test DELETE. + execute(transport.buildRequest("DELETE", "http://www.test.url")); + // Test HEAD. + execute(transport.buildRequest("HEAD", "http://www.test.url")); + // Test PATCH. + execute(transport.buildRequest("PATCH", "http://www.test.url")); + // Test PUT. + execute(transport.buildRequest("PUT", "http://www.test.url")); + // Test POST. + execute(transport.buildRequest("POST", "http://www.test.url")); + // Test PATCH. + execute(transport.buildRequest("PATCH", "http://www.test.url")); + } + + @Test + public void testNormalizedUrl() throws IOException { + final HttpRequestHandler handler = new HttpRequestHandler() { + @Override + public void handle( + ClassicHttpRequest request, ClassicHttpResponse response, HttpContext context) + throws HttpException, IOException { + // Extract the request URI and convert to bytes + byte[] responseData = request.getRequestUri().getBytes(StandardCharsets.UTF_8); + + // Set the response headers (status code and content length) + response.setCode(HttpStatus.SC_OK); + response.setHeader(HttpHeaders.CONTENT_LENGTH, String.valueOf(responseData.length)); + + // Set the response entity (body) + ByteArrayEntity entity = new ByteArrayEntity(responseData, ContentType.TEXT_PLAIN); + response.setEntity(entity); + } + }; + try (FakeServer server = new FakeServer(handler)) { + HttpTransport transport = new ApacheHttp2Transport(); + GenericUrl testUrl = new GenericUrl("http://localhost/foo//bar"); + testUrl.setPort(server.getPort()); + com.google.api.client.http.HttpResponse response = transport.createRequestFactory() + .buildGetRequest(testUrl) + .execute(); + assertEquals(200, response.getStatusCode()); + assertEquals("/foo//bar", response.parseAsString()); + } + } + + @Test + public void testReadErrorStream() throws IOException { + final HttpRequestHandler handler = new HttpRequestHandler() { + @Override + public void handle( + ClassicHttpRequest request, ClassicHttpResponse response, HttpContext context) + throws HttpException, IOException { + byte[] responseData = "Forbidden".getBytes(StandardCharsets.UTF_8); + response.setCode(HttpStatus.SC_FORBIDDEN); // 403 Forbidden + response.setHeader(HttpHeaders.CONTENT_LENGTH, String.valueOf(responseData.length)); + ByteArrayEntity entity = new ByteArrayEntity(responseData, ContentType.TEXT_PLAIN); + response.setEntity(entity); + } + }; + try (FakeServer server = new FakeServer(handler)) { + HttpTransport transport = new ApacheHttp2Transport(); + GenericUrl testUrl = new GenericUrl("http://localhost/foo//bar"); + testUrl.setPort(server.getPort()); + com.google.api.client.http.HttpRequest getRequest = transport.createRequestFactory() + .buildGetRequest(testUrl); + getRequest.setThrowExceptionOnExecuteError(false); + com.google.api.client.http.HttpResponse response = getRequest.execute(); + assertEquals(403, response.getStatusCode()); + assertEquals("Forbidden", response.parseAsString()); + } + } + + @Test + public void testReadErrorStream_withException() throws IOException { + final HttpRequestHandler handler = new HttpRequestHandler() { + @Override + public void handle( + ClassicHttpRequest request, ClassicHttpResponse response, HttpContext context) + throws HttpException, IOException { + byte[] responseData = "Forbidden".getBytes(StandardCharsets.UTF_8); + response.setCode(HttpStatus.SC_FORBIDDEN); // 403 Forbidden + response.setHeader(HttpHeaders.CONTENT_LENGTH, String.valueOf(responseData.length)); + ByteArrayEntity entity = new ByteArrayEntity(responseData, ContentType.TEXT_PLAIN); + response.setEntity(entity); + } + }; + try (FakeServer server = new FakeServer(handler)) { + HttpTransport transport = new ApacheHttp2Transport(); + GenericUrl testUrl = new GenericUrl("http://localhost/foo//bar"); + testUrl.setPort(server.getPort()); + com.google.api.client.http.HttpRequest getRequest = transport.createRequestFactory() + .buildGetRequest(testUrl); + try { + getRequest.execute(); + Assert.fail(); + } catch (HttpResponseException ex) { + assertEquals("Forbidden", ex.getContent()); + } + } + } + + private void execute(ApacheHttp2Request request) throws IOException { + byte[] bytes = "abc".getBytes(StandardCharsets.UTF_8); + request.setStreamingContent(new ByteArrayStreamingContent(bytes)); + request.setContentType("text/html"); + request.setContentLength(bytes.length); + request.execute(); + } + + private static class FakeServer implements AutoCloseable { + private final HttpServer server; + + FakeServer(final HttpRequestHandler httpHandler) throws IOException { + HttpRequestMapper mapper = new HttpRequestMapper() { + @Override + public HttpRequestHandler resolve(HttpRequest request, HttpContext context) + throws HttpException { + return httpHandler; + } + }; + server = new HttpServer( + 0, + HttpService.builder() + .withHttpProcessor( + new HttpProcessor() { + @Override + public void process( + HttpRequest request, EntityDetails entity, HttpContext context) + throws HttpException, IOException { + } + + @Override + public void process( + HttpResponse response, EntityDetails entity, HttpContext context) + throws HttpException, IOException { + } + }) + .withHttpServerRequestHandler(new BasicHttpServerRequestHandler(mapper)) + .build(), + null, + null, + null, + null, + null, + null); + server.start(); + } + + public int getPort() { + return server.getLocalPort(); + } + + @Override + public void close() { + server.initiateShutdown(); + } + } +} diff --git a/src/test/java/com/google/firebase/internal/MockApacheHttp2Response.java b/src/test/java/com/google/firebase/internal/MockApacheHttp2Response.java new file mode 100644 index 000000000..0f9c04042 --- /dev/null +++ b/src/test/java/com/google/firebase/internal/MockApacheHttp2Response.java @@ -0,0 +1,188 @@ +/* + * Copyright 2024 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.internal; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.stream.Collectors; + +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpResponse; +import org.apache.hc.core5.http.ProtocolException; +import org.apache.hc.core5.http.ProtocolVersion; + +public class MockApacheHttp2Response implements HttpResponse { + List
headers = new ArrayList<>(); + int code = 200; + + @Override + public void setVersion(ProtocolVersion version) { + } + + @Override + public ProtocolVersion getVersion() { + return null; + } + + @Override + public void addHeader(Header header) { + headers.add(header); + } + + @Override + public void addHeader(String name, Object value) { + addHeader(newHeader(name, value)); + } + + private Header newHeader(String key, Object value) { + return new Header() { + @Override + public boolean isSensitive() { + return false; + } + + @Override + public String getName() { + return key; + } + + @Override + public String getValue() { + return value.toString(); + } + }; + } + + @Override + public void setHeader(Header header) { + if (headers.contains(header)) { + int index = headers.indexOf(header); + headers.set(index, header); + } else { + addHeader(header); + } + } + + @Override + public void setHeader(String name, Object value) { + setHeader(newHeader(name, value)); + } + + @Override + public void setHeaders(Header... headers) { + for (Header header : headers) { + setHeader(header); + } + } + + @Override + public boolean removeHeader(Header header) { + if (headers.contains(header)) { + headers.remove(headers.indexOf(header)); + return true; + } + return false; + } + + @Override + public boolean removeHeaders(String name) { + int initialSize = headers.size(); + for (Header header : headers.stream().filter(h -> h.getName() == name) + .collect(Collectors.toList())) { + removeHeader(header); + } + return headers.size() < initialSize; + } + + @Override + public boolean containsHeader(String name) { + return headers.stream().anyMatch(h -> h.getName() == name); + } + + @Override + public int countHeaders(String name) { + return headers.size(); + } + + @Override + public Header getFirstHeader(String name) { + return headers.stream().findFirst().orElse(null); + } + + @Override + public Header getHeader(String name) throws ProtocolException { + return headers.stream().filter(h -> h.getName() == name).findFirst().orElse(null); + } + + @Override + public Header[] getHeaders() { + return headers.toArray(new Header[0]); + } + + @Override + public Header[] getHeaders(String name) { + return headers.stream() + .filter(h -> h.getName() == name) + .collect(Collectors.toList()) + .toArray(new Header[0]); + } + + @Override + public Header getLastHeader(String name) { + return headers.isEmpty() ? null : headers.get(headers.size() - 1); + } + + @Override + public Iterator
headerIterator() { + return headers.iterator(); + } + + @Override + public Iterator
headerIterator(String name) { + return headers.stream().filter(h -> h.getName() == name).iterator(); + } + + @Override + public int getCode() { + return this.code; + } + + @Override + public void setCode(int code) { + this.code = code; + } + + @Override + public String getReasonPhrase() { + return null; + } + + @Override + public void setReasonPhrase(String reason) { + } + + @Override + public Locale getLocale() { + return null; + } + + @Override + public void setLocale(Locale loc) { + } +} diff --git a/src/test/java/com/google/firebase/messaging/FirebaseMessagingClientImplTest.java b/src/test/java/com/google/firebase/messaging/FirebaseMessagingClientImplTest.java index 3180aea01..b75d9a3d6 100644 --- a/src/test/java/com/google/firebase/messaging/FirebaseMessagingClientImplTest.java +++ b/src/test/java/com/google/firebase/messaging/FirebaseMessagingClientImplTest.java @@ -368,7 +368,7 @@ private FirebaseMessagingClientImpl initMessagingClient( .setProjectId("test-project") .setJsonFactory(ApiClientUtils.getDefaultJsonFactory()) .setRequestFactory(transport.createRequestFactory()) - .setChildRequestFactory(Utils.getDefaultTransport().createRequestFactory()) + .setChildRequestFactory(ApiClientUtils.getDefaultTransport().createRequestFactory()) .setResponseInterceptor(interceptor) .build(); } @@ -379,7 +379,7 @@ private FirebaseMessagingClientImpl initClientWithFaultyTransport() { .setProjectId("test-project") .setJsonFactory(ApiClientUtils.getDefaultJsonFactory()) .setRequestFactory(transport.createRequestFactory()) - .setChildRequestFactory(Utils.getDefaultTransport().createRequestFactory()) + .setChildRequestFactory(ApiClientUtils.getDefaultTransport().createRequestFactory()) .build(); } @@ -405,8 +405,8 @@ private FirebaseMessagingClientImpl.Builder fullyPopulatedBuilder() { return FirebaseMessagingClientImpl.builder() .setProjectId("test-project") .setJsonFactory(ApiClientUtils.getDefaultJsonFactory()) - .setRequestFactory(Utils.getDefaultTransport().createRequestFactory()) - .setChildRequestFactory(Utils.getDefaultTransport().createRequestFactory()); + .setRequestFactory(ApiClientUtils.getDefaultTransport().createRequestFactory()) + .setChildRequestFactory(ApiClientUtils.getDefaultTransport().createRequestFactory()); } private void checkExceptionFromHttpResponse( diff --git a/src/test/java/com/google/firebase/messaging/InstanceIdClientImplTest.java b/src/test/java/com/google/firebase/messaging/InstanceIdClientImplTest.java index e7222ccb5..1a140680f 100644 --- a/src/test/java/com/google/firebase/messaging/InstanceIdClientImplTest.java +++ b/src/test/java/com/google/firebase/messaging/InstanceIdClientImplTest.java @@ -358,7 +358,7 @@ public void testRequestFactoryIsNull() { @Test(expected = NullPointerException.class) public void testJsonFactoryIsNull() { - new InstanceIdClientImpl(Utils.getDefaultTransport().createRequestFactory(), null); + new InstanceIdClientImpl(ApiClientUtils.getDefaultTransport().createRequestFactory(), null); } @Test diff --git a/src/test/java/com/google/firebase/remoteconfig/FirebaseRemoteConfigClientImplTest.java b/src/test/java/com/google/firebase/remoteconfig/FirebaseRemoteConfigClientImplTest.java index f2d5c8126..0a04809cf 100644 --- a/src/test/java/com/google/firebase/remoteconfig/FirebaseRemoteConfigClientImplTest.java +++ b/src/test/java/com/google/firebase/remoteconfig/FirebaseRemoteConfigClientImplTest.java @@ -1173,7 +1173,7 @@ private FirebaseRemoteConfigClientImpl.Builder fullyPopulatedBuilder() { return FirebaseRemoteConfigClientImpl.builder() .setProjectId("test-project") .setJsonFactory(ApiClientUtils.getDefaultJsonFactory()) - .setRequestFactory(Utils.getDefaultTransport().createRequestFactory()); + .setRequestFactory(ApiClientUtils.getDefaultTransport().createRequestFactory()); } private void checkGetRequestHeader(HttpRequest request) {