|
29 | 29 | import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
30 | 30 | import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
31 | 31 | import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
|
| 32 | +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; |
| 33 | +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; |
| 34 | +import org.springframework.security.oauth2.core.OAuth2Error; |
32 | 35 | import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
33 | 36 | import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
34 | 37 | import org.springframework.util.CollectionUtils;
|
|
41 | 44 | import java.util.Map;
|
42 | 45 |
|
43 | 46 | import static org.assertj.core.api.Assertions.assertThatCode;
|
| 47 | +import static org.assertj.core.api.Assertions.assertThatThrownBy; |
44 | 48 | import static org.mockito.ArgumentMatchers.any;
|
45 | 49 | import static org.mockito.Mockito.times;
|
46 | 50 | import static org.mockito.Mockito.verify;
|
@@ -234,6 +238,61 @@ public void filterWhenAuthorizationRequestRedirectUriParametersNotMatchThenNotPr
|
234 | 238 | verifyZeroInteractions(this.authenticationManager);
|
235 | 239 | }
|
236 | 240 |
|
| 241 | + // gh-8609 |
| 242 | + @Test |
| 243 | + public void filterWhenAuthenticationConverterThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() { |
| 244 | + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); |
| 245 | + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.empty()); |
| 246 | + |
| 247 | + MockServerHttpRequest authorizationRequest = |
| 248 | + createAuthorizationRequest("/authorization/callback"); |
| 249 | + OAuth2AuthorizationRequest oauth2AuthorizationRequest = |
| 250 | + createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); |
| 251 | + when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) |
| 252 | + .thenReturn(Mono.just(oauth2AuthorizationRequest)); |
| 253 | + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) |
| 254 | + .thenReturn(Mono.just(oauth2AuthorizationRequest)); |
| 255 | + |
| 256 | + MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); |
| 257 | + MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); |
| 258 | + DefaultWebFilterChain chain = new DefaultWebFilterChain( |
| 259 | + e -> e.getResponse().setComplete(), Collections.emptyList()); |
| 260 | + |
| 261 | + assertThatThrownBy(() -> this.filter.filter(exchange, chain).block()) |
| 262 | + .isInstanceOf(OAuth2AuthenticationException.class) |
| 263 | + .hasMessageContaining("client_registration_not_found"); |
| 264 | + verifyZeroInteractions(this.authenticationManager); |
| 265 | + } |
| 266 | + |
| 267 | + // gh-8609 |
| 268 | + @Test |
| 269 | + public void filterWhenAuthenticationManagerThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() { |
| 270 | + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); |
| 271 | + when(this.clientRegistrationRepository.findByRegistrationId(any())) |
| 272 | + .thenReturn(Mono.just(clientRegistration)); |
| 273 | + |
| 274 | + MockServerHttpRequest authorizationRequest = |
| 275 | + createAuthorizationRequest("/authorization/callback"); |
| 276 | + OAuth2AuthorizationRequest oauth2AuthorizationRequest = |
| 277 | + createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); |
| 278 | + when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) |
| 279 | + .thenReturn(Mono.just(oauth2AuthorizationRequest)); |
| 280 | + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) |
| 281 | + .thenReturn(Mono.just(oauth2AuthorizationRequest)); |
| 282 | + |
| 283 | + when(this.authenticationManager.authenticate(any())) |
| 284 | + .thenReturn(Mono.error(new OAuth2AuthorizationException(new OAuth2Error("authorization_error")))); |
| 285 | + |
| 286 | + MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); |
| 287 | + MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); |
| 288 | + DefaultWebFilterChain chain = new DefaultWebFilterChain( |
| 289 | + e -> e.getResponse().setComplete(), Collections.emptyList()); |
| 290 | + |
| 291 | + assertThatThrownBy(() -> this.filter.filter(exchange, chain).block()) |
| 292 | + .isInstanceOf(OAuth2AuthenticationException.class) |
| 293 | + .hasMessageContaining("authorization_error"); |
| 294 | + } |
| 295 | + |
237 | 296 | private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(
|
238 | 297 | MockServerHttpRequest authorizationRequest, ClientRegistration registration) {
|
239 | 298 | Map<String, Object> attributes = new HashMap<>();
|
|
0 commit comments