Skip to content

Commit da4b626

Browse files
committed
OAuth2LoginAuthenticationWebFilter should handle OAuth2AuthorizationException
Issue gh-8609
1 parent 4c902bb commit da4b626

File tree

4 files changed

+43
-18
lines changed

4 files changed

+43
-18
lines changed

config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

+9-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
import java.util.function.Supplier;
3333

3434
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
35+
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
36+
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
3537
import reactor.core.publisher.Mono;
3638
import reactor.util.context.Context;
3739

@@ -1094,9 +1096,14 @@ public OAuth2LoginSpec authenticationConverter(ServerAuthenticationConverter aut
10941096

10951097
private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) {
10961098
if (this.authenticationConverter == null) {
1097-
ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
1098-
authenticationConverter.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
1099+
ServerOAuth2AuthorizationCodeAuthenticationTokenConverter delegate =
1100+
new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
1101+
delegate.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
1102+
ServerAuthenticationConverter authenticationConverter = exchange ->
1103+
delegate.convert(exchange).onErrorMap(OAuth2AuthorizationException.class,
1104+
e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()));
10991105
this.authenticationConverter = authenticationConverter;
1106+
return authenticationConverter;
11001107
}
11011108
return this.authenticationConverter;
11021109
}

config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

+22-2
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,10 @@
103103

104104
import static org.assertj.core.api.Assertions.assertThat;
105105
import static org.mockito.ArgumentMatchers.any;
106-
import static org.mockito.Mockito.*;
106+
import static org.mockito.Mockito.mock;
107+
import static org.mockito.Mockito.spy;
108+
import static org.mockito.Mockito.verify;
109+
import static org.mockito.Mockito.when;
107110
import static org.springframework.security.oauth2.jwt.TestJwts.jwt;
108111

109112
/**
@@ -683,7 +686,6 @@ private ReactiveJwtDecoder getJwtDecoder() {
683686
}
684687
}
685688

686-
687689
@Test
688690
public void logoutWhenUsingOidcLogoutHandlerThenRedirects() {
689691
this.spring.register(OAuth2LoginConfigWithOidcLogoutSuccessHandler.class).autowire();
@@ -739,6 +741,24 @@ ClientRegistration clientRegistration() {
739741
}
740742
}
741743

744+
// gh-8609
745+
@Test
746+
public void oauth2LoginWhenAuthenticationConverterFailsThenDefaultRedirectToLogin() {
747+
this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class).autowire();
748+
749+
WebTestClient webTestClient = WebTestClientBuilder
750+
.bindToWebFilters(this.springSecurity)
751+
.build();
752+
753+
webTestClient.get()
754+
.uri("/login/oauth2/code/google")
755+
.exchange()
756+
.expectStatus()
757+
.is3xxRedirection()
758+
.expectHeader()
759+
.valueEquals("Location", "/login?error");
760+
}
761+
742762
static class GitHubWebFilter implements WebFilter {
743763

744764
@Override

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,14 @@ public Mono<Authentication> authenticate(Authentication authentication) {
121121
.getAuthorizationExchange().getAuthorizationResponse();
122122

123123
if (authorizationResponse.statusError()) {
124-
throw new OAuth2AuthenticationException(
125-
authorizationResponse.getError(), authorizationResponse.getError().toString());
124+
return Mono.error(new OAuth2AuthenticationException(
125+
authorizationResponse.getError(), authorizationResponse.getError().toString()));
126126
}
127127

128128
if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
129129
OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
130-
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
130+
return Mono.error(new OAuth2AuthenticationException(
131+
oauth2Error, oauth2Error.toString()));
131132
}
132133

133134
OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest(
@@ -139,7 +140,7 @@ public Mono<Authentication> authenticate(Authentication authentication) {
139140
.onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()))
140141
.onErrorMap(JwtException.class, e -> {
141142
OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, e.getMessage(), null);
142-
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e);
143+
return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e);
143144
});
144145
});
145146
}
@@ -178,7 +179,7 @@ private Mono<OAuth2LoginAuthenticationToken> authenticationResult(OAuth2Authoriz
178179
INVALID_ID_TOKEN_ERROR_CODE,
179180
"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
180181
null);
181-
throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString());
182+
return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()));
182183
}
183184

184185
return createOidcToken(clientRegistration, accessTokenResponse)

web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java

+6-9
Original file line numberDiff line numberDiff line change
@@ -109,24 +109,21 @@ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
109109
.filter( matchResult -> matchResult.isMatch())
110110
.flatMap( matchResult -> this.authenticationConverter.convert(exchange))
111111
.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
112-
.flatMap( token -> authenticate(exchange, chain, token));
112+
.flatMap( token -> authenticate(exchange, chain, token))
113+
.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
114+
.onAuthenticationFailure(new WebFilterExchange(exchange, chain), e));
113115
}
114116

115-
private Mono<Void> authenticate(ServerWebExchange exchange,
116-
WebFilterChain chain, Authentication token) {
117-
WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain);
118-
117+
private Mono<Void> authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) {
119118
return this.authenticationManagerResolver.resolve(exchange)
120119
.flatMap(authenticationManager -> authenticationManager.authenticate(token))
121120
.switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass()))))
122-
.flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange))
121+
.flatMap(authentication -> onAuthenticationSuccess(authentication, new WebFilterExchange(exchange, chain)))
123122
.doOnError(AuthenticationException.class, e -> {
124123
if (logger.isDebugEnabled()) {
125124
logger.debug("Authentication failed: " + e.getMessage());
126125
}
127-
})
128-
.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
129-
.onAuthenticationFailure(webFilterExchange, e));
126+
});
130127
}
131128

132129
protected Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) {

0 commit comments

Comments
 (0)