From d6d40aa44f21b02bb55f04b41fc34442480e96fa Mon Sep 17 00:00:00 2001 From: Steve Riesenberg <5248162+sjohnr@users.noreply.github.com> Date: Wed, 28 Feb 2024 17:03:35 -0600 Subject: [PATCH 1/2] Add servlet support for OAuth 2.0 Token Exchange Grant Issue gh-5199 --- ...xchangeOAuth2AuthorizedClientProvider.java | 176 +++++++ ...faultTokenExchangeTokenResponseClient.java | 126 +++++ .../endpoint/TokenExchangeGrantRequest.java | 78 +++ ...enExchangeGrantRequestEntityConverter.java | 79 +++ ...geOAuth2AuthorizedClientProviderTests.java | 384 ++++++++++++++ ...TokenExchangeTokenResponseClientTests.java | 491 ++++++++++++++++++ ...hangeGrantRequestEntityConverterTests.java | 308 +++++++++++ .../TokenExchangeGrantRequestTests.java | 91 ++++ .../oauth2/core/AuthorizationGrantType.java | 6 + .../core/endpoint/OAuth2ParameterNames.java | 36 ++ 10 files changed, 1775 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProvider.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClient.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequest.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverter.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProviderTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClientTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverterTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..256ced675ab --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProvider.java @@ -0,0 +1,176 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.function.Function; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.client.endpoint.DefaultTokenExchangeTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; + +/** + * An implementation of an {@link OAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#TOKEN_EXCHANGE token-exchange} grant. + * + * @author Steve Riesenberg + * @since 6.3 + * @see OAuth2AuthorizedClientProvider + * @see DefaultTokenExchangeTokenResponseClient + */ +public final class TokenExchangeOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + + private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultTokenExchangeTokenResponseClient(); + + private Function subjectTokenResolver = this::resolveSubjectToken; + + private Function actorTokenResolver = (context) -> null; + + private Duration clockSkew = Duration.ofSeconds(60); + + private Clock clock = Clock.systemUTC(); + + /** + * Attempt to authorize (or re-authorize) the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * {@code context}. Returns {@code null} if authorization (or re-authorization) is not + * supported, e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() + * authorization grant type} is not {@link AuthorizationGrantType#TOKEN_EXCHANGE + * token-exchange} OR the {@link OAuth2AuthorizedClient#getAccessToken() access token} + * is not expired. + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not + * supported + */ + @Override + @Nullable + public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + ClientRegistration clientRegistration = context.getClientRegistration(); + if (!AuthorizationGrantType.TOKEN_EXCHANGE.equals(clientRegistration.getAuthorizationGrantType())) { + return null; + } + OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); + if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { + // If client is already authorized but access token is NOT expired than no + // need for re-authorization + return null; + } + OAuth2Token subjectToken = this.subjectTokenResolver.apply(context); + if (subjectToken == null) { + return null; + } + + OAuth2Token actorToken = this.actorTokenResolver.apply(context); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, subjectToken, + actorToken); + OAuth2AccessTokenResponse tokenResponse = getTokenResponse(clientRegistration, grantRequest); + + return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), + tokenResponse.getAccessToken()); + } + + private OAuth2Token resolveSubjectToken(OAuth2AuthorizationContext context) { + if (context.getPrincipal().getPrincipal() instanceof OAuth2Token accessToken) { + return accessToken; + } + return null; + } + + private OAuth2AccessTokenResponse getTokenResponse(ClientRegistration clientRegistration, + TokenExchangeGrantRequest tokenExchangeGrantRequest) { + try { + return this.accessTokenResponseClient.getTokenResponse(tokenExchangeGrantRequest); + } + catch (OAuth2AuthorizationException ex) { + throw new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), ex); + } + } + + private boolean hasTokenExpired(OAuth2Token token) { + return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew)); + } + + /** + * Sets the client used when requesting an access token credential at the Token + * Endpoint for the {@code token-exchange} grant. + * @param accessTokenResponseClient the client used when requesting an access token + * credential at the Token Endpoint for the {@code token-exchange} grant + */ + public void setAccessTokenResponseClient( + OAuth2AccessTokenResponseClient accessTokenResponseClient) { + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + } + + /** + * Sets the resolver used for resolving the {@link OAuth2Token subject token}. + * @param subjectTokenResolver the resolver used for resolving the {@link OAuth2Token + * subject token} + */ + public void setSubjectTokenResolver(Function subjectTokenResolver) { + Assert.notNull(subjectTokenResolver, "subjectTokenResolver cannot be null"); + this.subjectTokenResolver = subjectTokenResolver; + } + + /** + * Sets the resolver used for resolving the {@link OAuth2Token actor token}. + * @param actorTokenResolver the resolver used for resolving the {@link OAuth2Token + * actor token} + */ + public void setActorTokenResolver(Function actorTokenResolver) { + Assert.notNull(actorTokenResolver, "actorTokenResolver cannot be null"); + this.actorTokenResolver = actorTokenResolver; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is + * 60 seconds. + * + *

+ * An access token is considered expired if + * {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time + * {@code clock#instant()}. + * @param clockSkew the maximum acceptable clock skew + */ + public void setClockSkew(Duration clockSkew) { + Assert.notNull(clockSkew, "clockSkew cannot be null"); + Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } + + /** + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access + * token expiry. + * @param clock the clock + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "clock cannot be null"); + this.clock = clock; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClient.java new file mode 100644 index 00000000000..787e72ad877 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClient.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.util.Arrays; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; + +/** + * The default implementation of an {@link OAuth2AccessTokenResponseClient} for the + * {@link AuthorizationGrantType#TOKEN_EXCHANGE token-exchange} grant. This implementation + * uses a {@link RestOperations} when requesting an access token credential at the + * Authorization Server's Token Endpoint. + * + * @author Steve Riesenberg + * @since 6.3 + * @see OAuth2AccessTokenResponseClient + * @see TokenExchangeGrantRequest + * @see OAuth2AccessTokenResponse + * @see Section + * 2.1 Request + * @see Section + * 2.2 Response + */ +public final class DefaultTokenExchangeTokenResponseClient + implements OAuth2AccessTokenResponseClient { + + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; + + private Converter> requestEntityConverter = new ClientAuthenticationMethodValidatingRequestEntityConverter<>( + new TokenExchangeGrantRequestEntityConverter()); + + private RestOperations restOperations; + + public DefaultTokenExchangeTokenResponseClient() { + RestTemplate restTemplate = new RestTemplate( + Arrays.asList(new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); + restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); + this.restOperations = restTemplate; + } + + @Override + public OAuth2AccessTokenResponse getTokenResponse(TokenExchangeGrantRequest grantRequest) { + Assert.notNull(grantRequest, "grantRequest cannot be null"); + RequestEntity requestEntity = this.requestEntityConverter.convert(grantRequest); + ResponseEntity responseEntity = getResponse(requestEntity); + + return responseEntity.getBody(); + } + + private ResponseEntity getResponse(RequestEntity request) { + try { + return this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + } + catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + + ex.getMessage(), + null); + throw new OAuth2AuthorizationException(oauth2Error, ex); + } + } + + /** + * Sets the {@link Converter} used for converting the + * {@link TokenExchangeGrantRequest} to a {@link RequestEntity} representation of the + * OAuth 2.0 Access Token Request. + * @param requestEntityConverter the {@link Converter} used for converting to a + * {@link RequestEntity} representation of the Access Token Request + */ + public void setRequestEntityConverter( + Converter> requestEntityConverter) { + Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); + this.requestEntityConverter = requestEntityConverter; + } + + /** + * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token + * Response. + * + *

+ * NOTE: At a minimum, the supplied {@code restOperations} must be configured + * with the following: + *

    + *
  1. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and + * {@link OAuth2AccessTokenResponseHttpMessageConverter}
  2. + *
  3. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
  4. + *
+ * @param restOperations the {@link RestOperations} used when requesting the Access + * Token Response + */ + public void setRestOperations(RestOperations restOperations) { + Assert.notNull(restOperations, "restOperations cannot be null"); + this.restOperations = restOperations; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequest.java new file mode 100644 index 00000000000..0a026a56724 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.util.Assert; + +/** + * A Token Exchange Grant request that holds the {@link OAuth2Token subject token} and + * optional {@link OAuth2Token actor token}. + * + * @author Steve Riesenberg + * @since 6.3 + * @see AbstractOAuth2AuthorizationGrantRequest + * @see ClientRegistration + * @see OAuth2Token + * @see Section + * 1.1 Delegation vs. Impersonation Semantics + * @see Section + * 2.1 Request + * @see Section + * 2.2 Response + */ +public class TokenExchangeGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { + + private final OAuth2Token subjectToken; + + private final OAuth2Token actorToken; + + /** + * Constructs a {@code TokenExchangeGrantRequest} using the provided parameters. + * @param clientRegistration the client registration + * @param subjectToken the subject token + * @param actorToken the actor token + */ + public TokenExchangeGrantRequest(ClientRegistration clientRegistration, OAuth2Token subjectToken, + OAuth2Token actorToken) { + super(AuthorizationGrantType.TOKEN_EXCHANGE, clientRegistration); + Assert.isTrue(AuthorizationGrantType.TOKEN_EXCHANGE.equals(clientRegistration.getAuthorizationGrantType()), + "clientRegistration.authorizationGrantType must be AuthorizationGrantType.TOKEN_EXCHANGE"); + Assert.notNull(subjectToken, "subjectToken cannot be null"); + this.subjectToken = subjectToken; + this.actorToken = actorToken; + } + + /** + * Returns the {@link OAuth2Token subject token}. + * @return the {@link OAuth2Token subject token} + */ + public OAuth2Token getSubjectToken() { + return this.subjectToken; + } + + /** + * Returns the {@link OAuth2Token actor token}. + * @return the {@link OAuth2Token actor token} + */ + public OAuth2Token getActorToken() { + return this.actorToken; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverter.java new file mode 100644 index 00000000000..c8f72e4adb4 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverter.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * An implementation of an {@link AbstractOAuth2AuthorizationGrantRequestEntityConverter} + * that converts the provided {@link TokenExchangeGrantRequest} to a {@link RequestEntity} + * representation of an OAuth 2.0 Access Token Request for the Token Exchange Grant. + * + * @author Steve Riesenberg + * @since 6.3 + * @see AbstractOAuth2AuthorizationGrantRequestEntityConverter + * @see TokenExchangeGrantRequest + * @see RequestEntity + * @see Section + * 1.1 Delegation vs. Impersonation Semantics + */ +public class TokenExchangeGrantRequestEntityConverter + extends AbstractOAuth2AuthorizationGrantRequestEntityConverter { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + @Override + protected MultiValueMap createParameters(TokenExchangeGrantRequest grantRequest) { + ClientRegistration clientRegistration = grantRequest.getClientRegistration(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue()); + parameters.add(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE); + OAuth2Token subjectToken = grantRequest.getSubjectToken(); + parameters.add(OAuth2ParameterNames.SUBJECT_TOKEN, subjectToken.getTokenValue()); + parameters.add(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, tokenType(subjectToken)); + OAuth2Token actorToken = grantRequest.getActorToken(); + if (actorToken != null) { + parameters.add(OAuth2ParameterNames.ACTOR_TOKEN, actorToken.getTokenValue()); + parameters.add(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, tokenType(actorToken)); + } + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + parameters.add(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) { + parameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + parameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + } + return parameters; + } + + private static String tokenType(OAuth2Token token) { + return (token instanceof Jwt) ? JWT_TOKEN_TYPE_VALUE : ACCESS_TOKEN_TYPE_VALUE; + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..8cf3b0fdf0f --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,384 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.function.Function; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +/** + * Tests for {@link TokenExchangeOAuth2AuthorizedClientProvider}. + * + * @author Steve Riesenberg + */ +public class TokenExchangeOAuth2AuthorizedClientProviderTests { + + private TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider; + + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + + private ClientRegistration clientRegistration; + + private OAuth2Token subjectToken; + + private OAuth2Token actorToken; + + private Authentication principal; + + @BeforeEach + public void setUp() { + this.authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider(); + this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); + this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + // @formatter:off + this.clientRegistration = ClientRegistration.withRegistrationId("token-exchange") + .clientId("client-id") + .clientSecret("client-secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .tokenUri("https://example.com/oauth2/token") + .build(); + // @formatter:on + this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.actorToken = TestOAuth2AccessTokens.noScopes(); + this.principal = new TestingAuthenticationToken(this.subjectToken, this.subjectToken); + } + + @Test + public void setAccessTokenResponseClientWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .withMessage("accessTokenResponseClient cannot be null"); + // @formatter:on + } + + @Test + public void setSubjectTokenResolverWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setSubjectTokenResolver(null)) + .withMessage("subjectTokenResolver cannot be null"); + // @formatter:on + } + + @Test + public void setActorTokenResolverWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setActorTokenResolver(null)) + .withMessage("actorTokenResolver cannot be null"); + // @formatter:on + } + + @Test + public void setClockSkewWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + // @formatter:on + } + + @Test + public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .withMessage("clockSkew must be >= 0"); + // @formatter:on + } + + @Test + public void setClockWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClock(null)) + .withMessage("clock cannot be null"); + // @formatter:on + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .withMessage("context cannot be null"); + // @formatter:on + } + + @Test + public void authorizeWhenNotTokenExchangeThenUnableToAuthorize() { + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + verifyNoInteractions(this.accessTokenResponseClient); + } + + @Test + public void authorizeWhenTokenExchangeAndTokenNotExpiredThenNotReauthorized() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.scopes("read", "write")); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + verifyNoInteractions(this.accessTokenResponseClient); + } + + @Test + public void authorizeWhenInvalidRequestThenThrowClientAuthorizationException() { + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willThrow(new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST))); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + + // @formatter:off + assertThatExceptionOfType(ClientAuthorizationException.class) + .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .withMessageContaining(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenTokenExchangeAndTokenExpiredThenReauthorized() { + Instant now = Instant.now(); + Instant issuedAt = now.minus(Duration.ofMinutes(60)); + Instant expiresAt = now.minus(Duration.ofMinutes(30)); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", + issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), accessToken); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + assertThat(reauthorizedClient).isNotNull(); + assertThat(reauthorizedClient).isNotEqualTo(authorizedClient); + assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenTokenExchangeAndTokenNotExpiredButClockSkewForcesExpiryThenReauthorized() { + Instant now = Instant.now(); + Instant issuedAt = now.minus(Duration.ofMinutes(60)); + Instant expiresAt = now.plus(Duration.ofMinutes(1)); + OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), expiresInOneMinAccessToken); + // Shorten the lifespan of the access token by 90 seconds, which will ultimately + // force it to expire on the client + this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + assertThat(reauthorizedClient).isNotNull(); + assertThat(reauthorizedClient).isNotEqualTo(authorizedClient); + assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenDoesNotResolveThenUnableToAuthorize() { + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(new TestingAuthenticationToken("user", "password")) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + verifyNoInteractions(this.accessTokenResponseClient); + } + + @Test + public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenResolvesThenAuthorized() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenCustomSubjectTokenResolverSetThenCalled() { + Function subjectTokenResolver = mock(Function.class); + given(subjectTokenResolver.apply(any(OAuth2AuthorizationContext.class))).willReturn(this.subjectToken); + this.authorizedClientProvider.setSubjectTokenResolver(subjectTokenResolver); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(accessTokenResponse); + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password"); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + verify(subjectTokenResolver).apply(authorizationContext); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenCustomActorTokenResolverSetThenCalled() { + Function actorTokenResolver = mock(Function.class); + given(actorTokenResolver.apply(any(OAuth2AuthorizationContext.class))).willReturn(this.actorToken); + this.authorizedClientProvider.setActorTokenResolver(actorTokenResolver); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + verify(actorTokenResolver).apply(authorizationContext); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isEqualTo(this.actorToken); + } + + @Test + public void authorizeWhenClockSetThenCalled() { + Clock clock = mock(Clock.class); + given(clock.instant()).willReturn(Instant.now()); + this.authorizedClientProvider.setClock(clock); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + verify(clock).instant(); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClientTests.java new file mode 100644 index 00000000000..9624c6465cf --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClientTests.java @@ -0,0 +1,491 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.io.IOException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.time.Instant; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.util.StringUtils; +import org.springframework.web.client.RestOperations; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link DefaultJwtBearerTokenResponseClient}. + * + * @author Steve Riesenberg + */ +public class DefaultTokenExchangeTokenResponseClientTests { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + private DefaultTokenExchangeTokenResponseClient tokenResponseClient; + + private ClientRegistration.Builder clientRegistration; + + private OAuth2Token subjectToken; + + private OAuth2Token actorToken; + + private MockWebServer server; + + @BeforeEach + public void setUp() throws IOException { + this.tokenResponseClient = new DefaultTokenExchangeTokenResponseClient(); + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + // @formatter:off + this.clientRegistration = TestClientRegistrations.clientCredentials() + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .tokenUri(tokenUri) + .scope("read", "write"); + // @formatter:on + this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.actorToken = null; + } + + @AfterEach + public void cleanUp() throws IOException { + this.server.shutdown(); + } + + @Test + public void setRequestEntityConverterWhenConverterIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)) + .withMessage("requestEntityConverter cannot be null"); + // @formatter:on + } + + @Test + public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setRestOperations(null)) + .withMessage("restOperations cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) + .withMessage("grantRequest cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)) + .isEqualTo(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.subjectToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)) + .isEqualTo(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, JWT_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenActorTokenIsNotNullThenActorParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.actorToken = TestOAuth2AccessTokens.noScopes(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)) + .isEqualTo(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.ACTOR_TOKEN, this.actorToken.getTokenValue()), + param(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenActorTokenIsJwtThenActorTokenTypeIsJwt() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.actorToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)) + .isEqualTo(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.ACTOR_TOKEN, this.actorToken.getTokenValue()), + param(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, JWT_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .havingRootCause().withMessage("tokenType cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse.getAccessToken().getScopes()).isEmpty(); + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"invalid_grant\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT)) + .withMessageContaining("[invalid_grant]"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + this.server.enqueue(new MockResponse().setResponseCode(500)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenCustomClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(new ClientAuthenticationMethod("basic")) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenCustomRequestEntityConverterSetThenCalled() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Converter> requestEntityConverter = spy( + TokenExchangeGrantRequestEntityConverter.class); + this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(requestEntityConverter).convert(grantRequest); + } + + @Test + public void getTokenResponseWhenCustomRestOperationsSetThenCalled() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + RestOperations restOperations = mock(RestOperations.class); + given(restOperations.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) + .willReturn(new ResponseEntity<>(HttpStatus.OK)); + this.tokenResponseClient.setRestOperations(restOperations); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(restOperations).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); + } + + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverterTests.java new file mode 100644 index 00000000000..8a77a66dfb6 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverterTests.java @@ -0,0 +1,308 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.InOrder; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link TokenExchangeGrantRequestEntityConverter}. + * + * @author Steve Riesenberg + */ +public class TokenExchangeGrantRequestEntityConverterTests { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + private TokenExchangeGrantRequestEntityConverter converter; + + private OAuth2Token subjectToken; + + private OAuth2Token actorToken; + + @BeforeEach + public void setUp() { + this.converter = new TokenExchangeGrantRequestEntityConverter(); + this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.actorToken = null; + } + + @Test + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.converter.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.converter.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.converter.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.converter.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void convertWhenHeadersConverterSetThenCalled() { + Converter headersConverter1 = mock(Converter.class); + this.converter.setHeadersConverter(headersConverter1); + Converter headersConverter2 = mock(Converter.class); + this.converter.addHeadersConverter(headersConverter2); + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .build(); + // @formatter:on + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.converter.convert(grantRequest); + InOrder inOrder = inOrder(headersConverter1, headersConverter2); + inOrder.verify(headersConverter1).convert(grantRequest); + inOrder.verify(headersConverter2).convert(grantRequest); + } + + @Test + public void convertWhenParametersConverterSetThenCalled() { + Converter> parametersConverter1 = mock( + Converter.class); + this.converter.setParametersConverter(parametersConverter1); + Converter> parametersConverter2 = mock( + Converter.class); + this.converter.addParametersConverter(parametersConverter2); + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .build(); + // @formatter:on + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.converter.convert(grantRequest); + InOrder inOrder = inOrder(parametersConverter1, parametersConverter2); + inOrder.verify(parametersConverter1).convert(any(TokenExchangeGrantRequest.class)); + inOrder.verify(parametersConverter2).convert(any(TokenExchangeGrantRequest.class)); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenGrantRequestValidThenConverts() { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .build(); + // @formatter:on + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + RequestEntity requestEntity = this.converter.convert(grantRequest); + assertThat(requestEntity).isNotNull(); + assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); + assertThat(requestEntity.getUrl().toASCIIString()) + .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); + HttpHeaders headers = requestEntity.getHeaders(); + assertThat(headers.getAccept()) + .contains(MediaType.valueOf(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8")); + assertThat(headers.getContentType()) + .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters).isNotNull(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE)) + .isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN)) + .isEqualTo(this.subjectToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE)).isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)) + .isEqualTo(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenClientAuthenticationMethodIsClientSecretPostThenClientIdAndSecretParametersPresent() { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .scope("read", "write") + .build(); + // @formatter:on + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + RequestEntity requestEntity = this.converter.convert(grantRequest); + assertThat(requestEntity).isNotNull(); + assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); + assertThat(requestEntity.getUrl().toASCIIString()) + .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); + HttpHeaders headers = requestEntity.getHeaders(); + assertThat(headers.getAccept()) + .contains(MediaType.valueOf(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8")); + assertThat(headers.getContentType()) + .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isNull(); + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters).isNotNull(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE)) + .isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN)) + .isEqualTo(this.subjectToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE)).isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)) + .isEqualTo(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + assertThat(formParameters.getFirst(OAuth2ParameterNames.CLIENT_ID)).isEqualTo(clientRegistration.getClientId()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.CLIENT_SECRET)) + .isEqualTo(clientRegistration.getClientSecret()); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenActorTokenIsNotNullThenActorTokenParametersPresent() { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .build(); + // @formatter:on + this.actorToken = TestOAuth2AccessTokens.noScopes(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + RequestEntity requestEntity = this.converter.convert(grantRequest); + assertThat(requestEntity).isNotNull(); + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters).isNotNull(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE)) + .isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN)) + .isEqualTo(this.subjectToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE)).isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.ACTOR_TOKEN)) + .isEqualTo(this.actorToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.ACTOR_TOKEN_TYPE)).isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)) + .isEqualTo(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .build(); + // @formatter:on + this.subjectToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + RequestEntity requestEntity = this.converter.convert(grantRequest); + assertThat(requestEntity).isNotNull(); + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters).isNotNull(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE)) + .isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN)) + .isEqualTo(this.subjectToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE)).isEqualTo(JWT_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)) + .isEqualTo(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenActorTokenIsJwtThenActorTokenTypeIsJwt() { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .build(); + // @formatter:on + this.actorToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + RequestEntity requestEntity = this.converter.convert(grantRequest); + assertThat(requestEntity).isNotNull(); + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters).isNotNull(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE)) + .isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN)) + .isEqualTo(this.subjectToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE)).isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.ACTOR_TOKEN)) + .isEqualTo(this.actorToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.ACTOR_TOKEN_TYPE)).isEqualTo(JWT_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)) + .isEqualTo(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestTests.java new file mode 100644 index 00000000000..f8b740eea97 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestTests.java @@ -0,0 +1,91 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.junit.jupiter.api.Test; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link TokenExchangeGrantRequest}. + * + * @author Steve Riesenberg + */ +public class TokenExchangeGrantRequestTests { + + private final ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .build(); + + private final OAuth2Token subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + + private final OAuth2Token actorToken = TestOAuth2AccessTokens.noScopes(); + + @Test + public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new TokenExchangeGrantRequest(null, this.subjectToken, this.actorToken)) + .withMessage("clientRegistration cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenSubjectTokenIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new TokenExchangeGrantRequest(this.clientRegistration, null, this.actorToken)) + .withMessage("subjectToken cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenActorTokenIsNullThenCreated() { + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration, + this.subjectToken, null); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE); + assertThat(grantRequest.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(grantRequest.getSubjectToken()).isSameAs(this.subjectToken); + } + + @Test + public void constructorWhenClientRegistrationInvalidGrantTypeThenThrowIllegalArgumentException() { + ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new TokenExchangeGrantRequest(registration, this.subjectToken, this.actorToken)) + .withMessage("clientRegistration.authorizationGrantType must be AuthorizationGrantType.TOKEN_EXCHANGE"); + // @formatter:on + } + + @Test + public void constructorWhenValidParametersProvidedThenCreated() { + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration, + this.subjectToken, this.actorToken); + assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE); + assertThat(grantRequest.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(grantRequest.getSubjectToken()).isSameAs(this.subjectToken); + } + +} diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java index 07b16e2b741..e1321bd7595 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java @@ -69,6 +69,12 @@ public final class AuthorizationGrantType implements Serializable { public static final AuthorizationGrantType DEVICE_CODE = new AuthorizationGrantType( "urn:ietf:params:oauth:grant-type:device_code"); + /** + * @since 6.3 + */ + public static final AuthorizationGrantType TOKEN_EXCHANGE = new AuthorizationGrantType( + "urn:ietf:params:oauth:grant-type:token-exchange"); + private final String value; /** diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java index d387b482d94..d482d7c1968 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java @@ -182,6 +182,42 @@ public final class OAuth2ParameterNames { */ public static final String INTERVAL = "interval"; + /** + * {@code requested_token_type} - used in Token Exchange Access Token Request. + * @since 6.3 + */ + public static final String REQUESTED_TOKEN_TYPE = "requested_token_type"; + + /** + * {@code issued_token_type} - used in Token Exchange Access Token Response. + * @since 6.3 + */ + private static final String ISSUED_TOKEN_TYPE = "issued_token_type"; + + /** + * {@code subject_token} - used in Token Exchange Access Token Request. + * @since 6.3 + */ + public static final String SUBJECT_TOKEN = "subject_token"; + + /** + * {@code subject_token_type} - used in Token Exchange Access Token Request. + * @since 6.3 + */ + public static final String SUBJECT_TOKEN_TYPE = "subject_token_type"; + + /** + * {@code actor_token} - used in Token Exchange Access Token Request. + * @since 6.3 + */ + public static final String ACTOR_TOKEN = "actor_token"; + + /** + * {@code actor_token_type} - used in Token Exchange Access Token Request. + * @since 6.3 + */ + public static final String ACTOR_TOKEN_TYPE = "actor_token_type"; + private OAuth2ParameterNames() { } From 13e70c365474660582bf050c2c739b26cfd2d4cc Mon Sep 17 00:00:00 2001 From: Steve Riesenberg <5248162+sjohnr@users.noreply.github.com> Date: Wed, 28 Feb 2024 17:40:06 -0600 Subject: [PATCH 2/2] Add reactive support for OAuth 2.0 Token Exchange Grant Issue gh-5199 --- ...eactiveOAuth2AuthorizedClientProvider.java | 166 +++++ ...ctiveTokenExchangeTokenResponseClient.java | 83 +++ ...veOAuth2AuthorizedClientProviderTests.java | 389 +++++++++++ ...TokenExchangeTokenResponseClientTests.java | 654 ++++++++++++++++++ 4 files changed, 1292 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProvider.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClient.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProviderTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClientTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..43e0607d2ea --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProvider.java @@ -0,0 +1,166 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.function.Function; + +import reactor.core.publisher.Mono; + +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest; +import org.springframework.security.oauth2.client.endpoint.WebClientReactiveTokenExchangeTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.util.Assert; + +/** + * An implementation of an {@link ReactiveOAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#TOKEN_EXCHANGE token-exchange} grant. + * + * @author Steve Riesenberg + * @since 6.3 + * @see ReactiveOAuth2AuthorizedClientProvider + * @see WebClientReactiveTokenExchangeTokenResponseClient + */ +public final class TokenExchangeReactiveOAuth2AuthorizedClientProvider + implements ReactiveOAuth2AuthorizedClientProvider { + + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient = new WebClientReactiveTokenExchangeTokenResponseClient(); + + private Function> subjectTokenResolver = this::resolveSubjectToken; + + private Function> actorTokenResolver = (context) -> Mono.empty(); + + private Duration clockSkew = Duration.ofSeconds(60); + + private Clock clock = Clock.systemUTC(); + + /** + * Attempt to authorize (or re-authorize) the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * {@code context}. Returns an empty {@code Mono} if authorization (or + * re-authorization) is not supported, e.g. the client's + * {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} is + * not {@link AuthorizationGrantType#TOKEN_EXCHANGE token-exchange} OR the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if + * authorization is not supported + */ + @Override + public Mono authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + ClientRegistration clientRegistration = context.getClientRegistration(); + if (!AuthorizationGrantType.TOKEN_EXCHANGE.equals(clientRegistration.getAuthorizationGrantType())) { + return Mono.empty(); + } + OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); + if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { + // If client is already authorized but access token is NOT expired than no + // need for re-authorization + return Mono.empty(); + } + + return this.subjectTokenResolver.apply(context) + .flatMap((subjectToken) -> this.actorTokenResolver.apply(context) + .map((actorToken) -> new TokenExchangeGrantRequest(clientRegistration, subjectToken, actorToken)) + .defaultIfEmpty(new TokenExchangeGrantRequest(clientRegistration, subjectToken, null))) + .flatMap(this.accessTokenResponseClient::getTokenResponse) + .onErrorMap(OAuth2AuthorizationException.class, + (ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), ex)) + .map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), + tokenResponse.getAccessToken())); + } + + private Mono resolveSubjectToken(OAuth2AuthorizationContext context) { + // @formatter:off + return Mono.just(context) + .map((ctx) -> ctx.getPrincipal().getPrincipal()) + .filter((principal) -> principal instanceof OAuth2Token) + .cast(OAuth2Token.class); + // @formatter:on + } + + private boolean hasTokenExpired(OAuth2Token token) { + return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew)); + } + + /** + * Sets the client used when requesting an access token credential at the Token + * Endpoint for the {@code token-exchange} grant. + * @param accessTokenResponseClient the client used when requesting an access token + * credential at the Token Endpoint for the {@code token-exchange} grant + */ + public void setAccessTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient) { + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + } + + /** + * Sets the resolver used for resolving the {@link OAuth2Token subject token}. + * @param subjectTokenResolver the resolver used for resolving the {@link OAuth2Token + * subject token} + */ + public void setSubjectTokenResolver(Function> subjectTokenResolver) { + Assert.notNull(subjectTokenResolver, "subjectTokenResolver cannot be null"); + this.subjectTokenResolver = subjectTokenResolver; + } + + /** + * Sets the resolver used for resolving the {@link OAuth2Token actor token}. + * @param actorTokenResolver the resolver used for resolving the {@link OAuth2Token + * actor token} + */ + public void setActorTokenResolver(Function> actorTokenResolver) { + Assert.notNull(actorTokenResolver, "actorTokenResolver cannot be null"); + this.actorTokenResolver = actorTokenResolver; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is + * 60 seconds. + * + *

+ * An access token is considered expired if + * {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time + * {@code clock#instant()}. + * @param clockSkew the maximum acceptable clock skew + */ + public void setClockSkew(Duration clockSkew) { + Assert.notNull(clockSkew, "clockSkew cannot be null"); + Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } + + /** + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access + * token expiry. + * @param clock the clock + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "clock cannot be null"); + this.clock = clock; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClient.java new file mode 100644 index 00000000000..abc9ad751b8 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClient.java @@ -0,0 +1,83 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.util.Set; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.WebClient; + +/** + * The default implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} for + * the {@link AuthorizationGrantType#TOKEN_EXCHANGE token-exchange} grant. This + * implementation uses {@link WebClient} when requesting an access token credential at the + * Authorization Server's Token Endpoint. + * + * @author Steve Riesenberg + * @since 6.3 + * @see ReactiveOAuth2AccessTokenResponseClient + * @see TokenExchangeGrantRequest + * @see OAuth2AccessToken + * @see Section + * 2.1 Request + * @see Section + * 2.2 Response + */ +public final class WebClientReactiveTokenExchangeTokenResponseClient + extends AbstractWebClientReactiveOAuth2AccessTokenResponseClient { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + @Override + ClientRegistration clientRegistration(TokenExchangeGrantRequest grantRequest) { + return grantRequest.getClientRegistration(); + } + + @Override + Set scopes(TokenExchangeGrantRequest grantRequest) { + return grantRequest.getClientRegistration().getScopes(); + } + + @Override + BodyInserters.FormInserter populateTokenRequestBody(TokenExchangeGrantRequest grantRequest, + BodyInserters.FormInserter body) { + super.populateTokenRequestBody(grantRequest, body); + body.with(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE); + OAuth2Token subjectToken = grantRequest.getSubjectToken(); + body.with(OAuth2ParameterNames.SUBJECT_TOKEN, subjectToken.getTokenValue()); + body.with(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, tokenType(subjectToken)); + OAuth2Token actorToken = grantRequest.getActorToken(); + if (actorToken != null) { + body.with(OAuth2ParameterNames.ACTOR_TOKEN, actorToken.getTokenValue()); + body.with(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, tokenType(actorToken)); + } + return body; + } + + private static String tokenType(OAuth2Token token) { + return (token instanceof Jwt) ? JWT_TOKEN_TYPE_VALUE : ACCESS_TOKEN_TYPE_VALUE; + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..2b7250911f7 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeReactiveOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,389 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.function.Function; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +/** + * Tests for {@link TokenExchangeReactiveOAuth2AuthorizedClientProvider}. + * + * @author Steve Riesenberg + */ +public class TokenExchangeReactiveOAuth2AuthorizedClientProviderTests { + + private TokenExchangeReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + + private ReactiveOAuth2AccessTokenResponseClient accessTokenResponseClient; + + private ClientRegistration clientRegistration; + + private OAuth2Token subjectToken; + + private OAuth2Token actorToken; + + private Authentication principal; + + @BeforeEach + public void setUp() { + this.authorizedClientProvider = new TokenExchangeReactiveOAuth2AuthorizedClientProvider(); + this.accessTokenResponseClient = mock(ReactiveOAuth2AccessTokenResponseClient.class); + this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + // @formatter:off + this.clientRegistration = ClientRegistration.withRegistrationId("token-exchange") + .clientId("client-id") + .clientSecret("client-secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .tokenUri("https://example.com/oauth2/token") + .build(); + // @formatter:on + this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.actorToken = TestOAuth2AccessTokens.noScopes(); + this.principal = new TestingAuthenticationToken(this.subjectToken, this.subjectToken); + } + + @Test + public void setAccessTokenResponseClientWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .withMessage("accessTokenResponseClient cannot be null"); + // @formatter:on + } + + @Test + public void setSubjectTokenResolverWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setSubjectTokenResolver(null)) + .withMessage("subjectTokenResolver cannot be null"); + // @formatter:on + } + + @Test + public void setActorTokenResolverWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setActorTokenResolver(null)) + .withMessage("actorTokenResolver cannot be null"); + // @formatter:on + } + + @Test + public void setClockSkewWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + // @formatter:on + } + + @Test + public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .withMessage("clockSkew must be >= 0"); + // @formatter:on + } + + @Test + public void setClockWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClock(null)) + .withMessage("clock cannot be null"); + // @formatter:on + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) + .withMessage("context cannot be null"); + // @formatter:on + } + + @Test + public void authorizeWhenNotTokenExchangeThenUnableToAuthorize() { + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + verifyNoInteractions(this.accessTokenResponseClient); + } + + @Test + public void authorizeWhenTokenExchangeAndTokenNotExpiredThenNotReauthorized() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.scopes("read", "write")); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + verifyNoInteractions(this.accessTokenResponseClient); + } + + @Test + public void authorizeWhenInvalidRequestThenThrowClientAuthorizationException() { + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))).willReturn( + Mono.error(new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST)))); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + + // @formatter:off + assertThatExceptionOfType(ClientAuthorizationException.class) + .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST)) + .withMessageContaining("[invalid_request]"); + // @formatter:on + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenTokenExchangeAndTokenExpiredThenReauthorized() { + Instant now = Instant.now(); + Instant issuedAt = now.minus(Duration.ofMinutes(60)); + Instant expiresAt = now.minus(Duration.ofMinutes(30)); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", + issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), accessToken); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext) + .block(); + assertThat(reauthorizedClient).isNotNull(); + assertThat(reauthorizedClient).isNotEqualTo(authorizedClient); + assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenTokenExchangeAndTokenNotExpiredButClockSkewForcesExpiryThenReauthorized() { + Instant now = Instant.now(); + Instant issuedAt = now.minus(Duration.ofMinutes(60)); + Instant expiresAt = now.plus(Duration.ofMinutes(1)); + OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), expiresInOneMinAccessToken); + // Shorten the lifespan of the access token by 90 seconds, which will ultimately + // force it to expire on the client + this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext) + .block(); + assertThat(reauthorizedClient).isNotNull(); + assertThat(reauthorizedClient).isNotEqualTo(authorizedClient); + assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenDoesNotResolveThenUnableToAuthorize() { + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(new TestingAuthenticationToken("user", "password")) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + verifyNoInteractions(this.accessTokenResponseClient); + } + + @Test + public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenResolvesThenAuthorized() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenCustomSubjectTokenResolverSetThenCalled() { + Function> subjectTokenResolver = mock(Function.class); + given(subjectTokenResolver.apply(any(OAuth2AuthorizationContext.class))) + .willReturn(Mono.just(this.subjectToken)); + this.authorizedClientProvider.setSubjectTokenResolver(subjectTokenResolver); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password"); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + verify(subjectTokenResolver).apply(authorizationContext); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenCustomActorTokenResolverSetThenCalled() { + Function> actorTokenResolver = mock(Function.class); + given(actorTokenResolver.apply(any(OAuth2AuthorizationContext.class))).willReturn(Mono.just(this.actorToken)); + this.authorizedClientProvider.setActorTokenResolver(actorTokenResolver); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(Mono.just(accessTokenResponse)); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + verify(actorTokenResolver).apply(authorizationContext); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isEqualTo(this.actorToken); + } + + @Test + public void authorizeWhenClockSetThenCalled() { + Clock clock = mock(Clock.class); + given(clock.instant()).willReturn(Instant.now()); + this.authorizedClientProvider.setClock(clock); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); + verify(clock).instant(); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClientTests.java new file mode 100644 index 00000000000..c9bbfcf5178 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveTokenExchangeTokenResponseClientTests.java @@ -0,0 +1,654 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.io.IOException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Collections; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ReactiveHttpInputMessage; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.reactive.function.BodyExtractor; +import org.springframework.web.reactive.function.client.WebClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link WebClientReactiveTokenExchangeTokenResponseClient}. + * + * @author Steve Riesenberg + */ +public class WebClientReactiveTokenExchangeTokenResponseClientTests { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + private WebClientReactiveTokenExchangeTokenResponseClient tokenResponseClient; + + private ClientRegistration.Builder clientRegistration; + + private OAuth2Token subjectToken; + + private OAuth2Token actorToken; + + private MockWebServer server; + + @BeforeEach + public void setUp() throws IOException { + this.tokenResponseClient = new WebClientReactiveTokenExchangeTokenResponseClient(); + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + // @formatter:off + this.clientRegistration = TestClientRegistrations.clientCredentials() + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .tokenUri(tokenUri) + .scope("read", "write"); + // @formatter:on + this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.actorToken = null; + } + + @AfterEach + public void cleanUp() throws IOException { + this.server.shutdown(); + } + + @Test + public void setWebClientWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setWebClient(null)) + .withMessage("webClient cannot be null"); + // @formatter:on + } + + @Test + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void setBodyExtractorWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setBodyExtractor(null)) + .withMessage("bodyExtractor cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenGrantRequestIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) + .withMessage("grantRequest cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.subjectToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, JWT_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenActorTokenIsNotNullThenActorParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.actorToken = TestOAuth2AccessTokens.noScopes(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.ACTOR_TOKEN, this.actorToken.getTokenValue()), + param(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenActorTokenIsJwtThenActorTokenTypeIsJwt() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.actorToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); + assertThat(accessTokenResponse).isNotNull(); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.ACTOR_TOKEN, this.actorToken.getTokenValue()), + param(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, JWT_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response") + .havingRootCause().withMessage("Unsupported token_type: not-bearer"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest).block(); + assertThat(accessTokenResponse).isNotNull(); + assertThat(accessTokenResponse.getAccessToken().getScopes()).isEmpty(); + } + + @Test + public void getTokenResponseWhenInvalidResponseThenThrowOAuth2AuthorizationException() { + this.server.enqueue(new MockResponse().setResponseCode(301)); + TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .withMessage("[invalid_token_response] Empty OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"server_error\", \"error_description\": \"A server error occurred\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(500)); + TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR)) + .withMessage("[server_error] A server error occurred"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"invalid_grant\", \"error_description\": \"Invalid grant\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + TokenExchangeGrantRequest request = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(request).block()) + .satisfies((ex) -> assertThat(ex.getError().getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT)) + .withMessage("[invalid_grant] Invalid grant"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenCustomClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(new ClientAuthenticationMethod("basic")) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest).block()) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest).block()) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.addHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenHeadersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + Converter headersConverter = mock(Converter.class); + HttpHeaders headers = new HttpHeaders(); + headers.put("custom-header-name", Collections.singletonList("custom-header-value")); + given(headersConverter.convert(grantRequest)).willReturn(headers); + this.tokenResponseClient.setHeadersConverter(headersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + verify(headersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(recordedRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterSetThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + Converter> parametersConverter = mock(Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.setParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("custom-parameter-name=custom-parameter-value"); + } + + @Test + public void getTokenResponseWhenParametersConverterAddedThenCalled() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + Converter> parametersConverter = mock(Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(grantRequest)).willReturn(parameters); + this.tokenResponseClient.addParametersConverter(parametersConverter); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + verify(parametersConverter).convert(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")), + param("custom-parameter-name", "custom-parameter-value") + ); + // @formatter:on + } + + @Test + public void getTokenResponseWhenBodyExtractorSetThenCalled() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + BodyExtractor, ReactiveHttpInputMessage> bodyExtractor = mock( + BodyExtractor.class); + OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(bodyExtractor.extract(any(ReactiveHttpInputMessage.class), any(BodyExtractor.Context.class))) + .willReturn(Mono.just(response)); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.tokenResponseClient.setBodyExtractor(bodyExtractor); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + verify(bodyExtractor).extract(any(ReactiveHttpInputMessage.class), any(BodyExtractor.Context.class)); + } + + @Test + public void getTokenResponseWhenWebClientSetThenCalled() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + WebClient customClient = mock(WebClient.class); + given(customClient.post()).willReturn(WebClient.builder().build().post()); + this.tokenResponseClient.setWebClient(customClient); + ClientRegistration clientRegistration = this.clientRegistration.build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest).block(); + verify(customClient).post(); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); + } + + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + +}