Skip to content

Commit 8acdb82

Browse files
committed
OAuth2AuthorizationCodeGrantWebFilter matches on query parameters
Fixes gh-7966
1 parent 5ce0ce3 commit 8acdb82

File tree

3 files changed

+202
-65
lines changed

3 files changed

+202
-65
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java

+40-19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -37,13 +37,20 @@
3737
import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
3838
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
3939
import org.springframework.util.Assert;
40-
import org.springframework.util.MultiValueMap;
4140
import org.springframework.web.server.ServerWebExchange;
4241
import org.springframework.web.server.WebFilter;
4342
import org.springframework.web.server.WebFilterChain;
43+
import org.springframework.web.util.UriComponents;
4444
import org.springframework.web.util.UriComponentsBuilder;
4545
import reactor.core.publisher.Mono;
4646

47+
import java.net.URI;
48+
import java.util.LinkedHashSet;
49+
import java.util.List;
50+
import java.util.Map;
51+
import java.util.Objects;
52+
import java.util.Set;
53+
4754
/**
4855
* A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
4956
* which handles the processing of the OAuth 2.0 Authorization Response.
@@ -165,10 +172,10 @@ private void updateDefaultAuthenticationConverter() {
165172
@Override
166173
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
167174
return this.requiresAuthenticationMatcher.matches(exchange)
168-
.filter( matchResult -> matchResult.isMatch())
169-
.flatMap( matchResult -> this.authenticationConverter.convert(exchange))
175+
.filter(ServerWebExchangeMatcher.MatchResult::isMatch)
176+
.flatMap(matchResult -> this.authenticationConverter.convert(exchange))
170177
.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
171-
.flatMap( token -> authenticate(exchange, chain, token));
178+
.flatMap(token -> authenticate(exchange, chain, token));
172179
}
173180

174181
private Mono<Void> authenticate(ServerWebExchange exchange,
@@ -198,20 +205,34 @@ private Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFil
198205
}
199206

200207
private Mono<ServerWebExchangeMatcher.MatchResult> matchesAuthorizationResponse(ServerWebExchange exchange) {
201-
return this.authorizationRequestRepository.loadAuthorizationRequest(exchange)
202-
.flatMap(authorizationRequest -> {
203-
String requestUrl = UriComponentsBuilder.fromUri(exchange.getRequest().getURI())
204-
.query(null)
205-
.build()
206-
.toUriString();
207-
MultiValueMap<String, String> queryParams = exchange.getRequest().getQueryParams();
208-
if (requestUrl.equals(authorizationRequest.getRedirectUri()) &&
209-
OAuth2AuthorizationResponseUtils.isAuthorizationResponse(queryParams)) {
210-
return ServerWebExchangeMatcher.MatchResult.match();
211-
}
212-
return ServerWebExchangeMatcher.MatchResult.notMatch();
213-
})
214-
.filter(ServerWebExchangeMatcher.MatchResult::isMatch)
208+
return Mono.just(exchange)
209+
.filter(exch -> OAuth2AuthorizationResponseUtils.isAuthorizationResponse(exch.getRequest().getQueryParams()))
210+
.flatMap(exch -> this.authorizationRequestRepository.loadAuthorizationRequest(exchange)
211+
.flatMap(authorizationRequest ->
212+
matchesRedirectUri(exch.getRequest().getURI(), authorizationRequest.getRedirectUri())))
215213
.switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch());
216214
}
215+
216+
private static Mono<ServerWebExchangeMatcher.MatchResult> matchesRedirectUri(
217+
URI authorizationResponseUri, String authorizationRequestRedirectUri) {
218+
UriComponents requestUri = UriComponentsBuilder.fromUri(authorizationResponseUri).build();
219+
UriComponents redirectUri = UriComponentsBuilder.fromUriString(authorizationRequestRedirectUri).build();
220+
Set<Map.Entry<String, List<String>>> requestUriParameters =
221+
new LinkedHashSet<>(requestUri.getQueryParams().entrySet());
222+
Set<Map.Entry<String, List<String>>> redirectUriParameters =
223+
new LinkedHashSet<>(redirectUri.getQueryParams().entrySet());
224+
// Remove the additional request parameters (if any) from the authorization response (request)
225+
// before doing an exact comparison with the authorizationRequest.getRedirectUri() parameters (if any)
226+
requestUriParameters.retainAll(redirectUriParameters);
227+
228+
if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) &&
229+
Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) &&
230+
Objects.equals(requestUri.getHost(), redirectUri.getHost()) &&
231+
Objects.equals(requestUri.getPort(), redirectUri.getPort()) &&
232+
Objects.equals(requestUri.getPath(), redirectUri.getPath()) &&
233+
Objects.equals(requestUriParameters.toString(), redirectUriParameters.toString())) {
234+
return ServerWebExchangeMatcher.MatchResult.match();
235+
}
236+
return ServerWebExchangeMatcher.MatchResult.notMatch();
237+
}
217238
}

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java

+2-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -28,7 +28,6 @@
2828
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
2929
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
3030
import org.springframework.util.Assert;
31-
import org.springframework.util.MultiValueMap;
3231
import org.springframework.web.server.ServerWebExchange;
3332
import org.springframework.web.util.UriComponentsBuilder;
3433
import reactor.core.publisher.Mono;
@@ -103,14 +102,10 @@ private Mono<OAuth2AuthorizationCodeAuthenticationToken> authenticationRequest(S
103102
}
104103

105104
private static OAuth2AuthorizationResponse convertResponse(ServerWebExchange exchange) {
106-
MultiValueMap<String, String> queryParams = exchange.getRequest()
107-
.getQueryParams();
108105
String redirectUri = UriComponentsBuilder.fromUri(exchange.getRequest().getURI())
109-
.query(null)
110106
.build()
111107
.toUriString();
112-
113108
return OAuth2AuthorizationResponseUtils
114-
.convert(queryParams, redirectUri);
109+
.convert(exchange.getRequest().getQueryParams(), redirectUri);
115110
}
116111
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java

+160-39
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -25,25 +25,28 @@
2525
import org.springframework.mock.web.server.MockServerWebExchange;
2626
import org.springframework.security.authentication.AnonymousAuthenticationToken;
2727
import org.springframework.security.authentication.ReactiveAuthenticationManager;
28-
import org.springframework.security.core.Authentication;
29-
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
3028
import org.springframework.security.oauth2.client.authentication.TestOAuth2AuthorizationCodeAuthenticationTokens;
3129
import org.springframework.security.oauth2.client.registration.ClientRegistration;
3230
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
3331
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
34-
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
3532
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
36-
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
3733
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
38-
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
39-
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
40-
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
34+
import org.springframework.util.CollectionUtils;
4135
import org.springframework.web.server.handler.DefaultWebFilterChain;
4236
import reactor.core.publisher.Mono;
4337

38+
import java.util.Collections;
39+
import java.util.HashMap;
40+
import java.util.LinkedHashMap;
41+
import java.util.Map;
42+
4443
import static org.assertj.core.api.Assertions.assertThatCode;
4544
import static org.mockito.ArgumentMatchers.any;
46-
import static org.mockito.Mockito.*;
45+
import static org.mockito.Mockito.times;
46+
import static org.mockito.Mockito.verify;
47+
import static org.mockito.Mockito.verifyZeroInteractions;
48+
import static org.mockito.Mockito.when;
49+
import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
4750

4851
/**
4952
* @author Rob Winch
@@ -102,7 +105,7 @@ public void filterWhenNotMatchThenAuthenticationManagerNotCalled() {
102105
MockServerWebExchange exchange = MockServerWebExchange
103106
.from(MockServerHttpRequest.get("/"));
104107
DefaultWebFilterChain chain = new DefaultWebFilterChain(
105-
e -> e.getResponse().setComplete());
108+
e -> e.getResponse().setComplete(), Collections.emptyList());
106109

107110
this.filter.filter(exchange, chain).block();
108111

@@ -111,43 +114,161 @@ public void filterWhenNotMatchThenAuthenticationManagerNotCalled() {
111114

112115
@Test
113116
public void filterWhenMatchThenAuthorizedClientSaved() {
114-
OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
115-
.redirectUri("/authorize/registration-id")
116-
.build();
117-
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success()
118-
.redirectUri("/authorize/registration-id")
119-
.build();
120-
OAuth2AuthorizationExchange authorizationExchange =
121-
new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse);
122-
ClientRegistration registration = TestClientRegistrations.clientRegistration().build();
123-
Mono<Authentication> authentication = Mono.just(
124-
new OAuth2AuthorizationCodeAuthenticationToken(registration, authorizationExchange));
125-
OAuth2AuthorizationCodeAuthenticationToken authenticated = TestOAuth2AuthorizationCodeAuthenticationTokens
126-
.authenticated();
127-
128-
when(this.authenticationManager.authenticate(any())).thenReturn(
129-
Mono.just(authenticated));
117+
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
118+
when(this.clientRegistrationRepository.findByRegistrationId(any()))
119+
.thenReturn(Mono.just(clientRegistration));
120+
121+
MockServerHttpRequest authorizationRequest =
122+
createAuthorizationRequest("/authorization/callback");
123+
OAuth2AuthorizationRequest oauth2AuthorizationRequest =
124+
createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
130125
when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
131-
.thenReturn(Mono.just(authorizationRequest));
126+
.thenReturn(Mono.just(oauth2AuthorizationRequest));
127+
when(this.authorizationRequestRepository.removeAuthorizationRequest(any()))
128+
.thenReturn(Mono.just(oauth2AuthorizationRequest));
129+
132130
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any()))
133131
.thenReturn(Mono.empty());
134-
ServerAuthenticationConverter converter = e -> authentication;
132+
when(this.authenticationManager.authenticate(any()))
133+
.thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated()));
135134

136-
this.filter = new OAuth2AuthorizationCodeGrantWebFilter(
137-
this.authenticationManager, converter, this.authorizedClientRepository);
138-
this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
139-
140-
MockServerHttpRequest request = MockServerHttpRequest
141-
.get("/authorize/registration-id")
142-
.queryParam(OAuth2ParameterNames.CODE, "code")
143-
.queryParam(OAuth2ParameterNames.STATE, "state")
144-
.build();
145-
MockServerWebExchange exchange = MockServerWebExchange.from(request);
135+
MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
136+
MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
146137
DefaultWebFilterChain chain = new DefaultWebFilterChain(
147-
e -> e.getResponse().setComplete());
138+
e -> e.getResponse().setComplete(), Collections.emptyList());
148139

149140
this.filter.filter(exchange, chain).block();
150141

151142
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(AnonymousAuthenticationToken.class), any());
152143
}
144+
145+
// gh-7966
146+
@Test
147+
public void filterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed() {
148+
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
149+
when(this.clientRegistrationRepository.findByRegistrationId(any()))
150+
.thenReturn(Mono.just(clientRegistration));
151+
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any()))
152+
.thenReturn(Mono.empty());
153+
when(this.authenticationManager.authenticate(any()))
154+
.thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated()));
155+
156+
// 1) redirect_uri with query parameters
157+
Map<String, String> parameters = new LinkedHashMap<>();
158+
parameters.put("param1", "value1");
159+
parameters.put("param2", "value2");
160+
MockServerHttpRequest authorizationRequest =
161+
createAuthorizationRequest("/authorization/callback", parameters);
162+
OAuth2AuthorizationRequest oauth2AuthorizationRequest =
163+
createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
164+
when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
165+
.thenReturn(Mono.just(oauth2AuthorizationRequest));
166+
when(this.authorizationRequestRepository.removeAuthorizationRequest(any()))
167+
.thenReturn(Mono.just(oauth2AuthorizationRequest));
168+
169+
MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
170+
MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
171+
DefaultWebFilterChain chain = new DefaultWebFilterChain(
172+
e -> e.getResponse().setComplete(), Collections.emptyList());
173+
174+
this.filter.filter(exchange, chain).block();
175+
verify(this.authenticationManager, times(1)).authenticate(any());
176+
177+
// 2) redirect_uri with query parameters AND authorization response additional parameters
178+
Map<String, String> additionalParameters = new LinkedHashMap<>();
179+
additionalParameters.put("auth-param1", "value1");
180+
additionalParameters.put("auth-param2", "value2");
181+
authorizationResponse = createAuthorizationResponse(authorizationRequest, additionalParameters);
182+
exchange = MockServerWebExchange.from(authorizationResponse);
183+
184+
this.filter.filter(exchange, chain).block();
185+
verify(this.authenticationManager, times(2)).authenticate(any());
186+
}
187+
188+
// gh-7966
189+
@Test
190+
public void filterWhenAuthorizationRequestRedirectUriParametersNotMatchThenNotProcessed() {
191+
String requestUri = "/authorization/callback";
192+
Map<String, String> parameters = new LinkedHashMap<>();
193+
parameters.put("param1", "value1");
194+
parameters.put("param2", "value2");
195+
MockServerHttpRequest authorizationRequest =
196+
createAuthorizationRequest(requestUri, parameters);
197+
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
198+
OAuth2AuthorizationRequest oauth2AuthorizationRequest =
199+
createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
200+
when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
201+
.thenReturn(Mono.just(oauth2AuthorizationRequest));
202+
203+
// 1) Parameter value
204+
Map<String, String> parametersNotMatch = new LinkedHashMap<>(parameters);
205+
parametersNotMatch.put("param2", "value8");
206+
MockServerHttpRequest authorizationResponse = createAuthorizationResponse(
207+
createAuthorizationRequest(requestUri, parametersNotMatch));
208+
MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
209+
DefaultWebFilterChain chain = new DefaultWebFilterChain(
210+
e -> e.getResponse().setComplete(), Collections.emptyList());
211+
212+
this.filter.filter(exchange, chain).block();
213+
verifyZeroInteractions(this.authenticationManager);
214+
215+
// 2) Parameter order
216+
parametersNotMatch = new LinkedHashMap<>();
217+
parametersNotMatch.put("param2", "value2");
218+
parametersNotMatch.put("param1", "value1");
219+
authorizationResponse = createAuthorizationResponse(
220+
createAuthorizationRequest(requestUri, parametersNotMatch));
221+
exchange = MockServerWebExchange.from(authorizationResponse);
222+
223+
this.filter.filter(exchange, chain).block();
224+
verifyZeroInteractions(this.authenticationManager);
225+
226+
// 3) Parameter missing
227+
parametersNotMatch = new LinkedHashMap<>(parameters);
228+
parametersNotMatch.remove("param2");
229+
authorizationResponse = createAuthorizationResponse(
230+
createAuthorizationRequest(requestUri, parametersNotMatch));
231+
exchange = MockServerWebExchange.from(authorizationResponse);
232+
233+
this.filter.filter(exchange, chain).block();
234+
verifyZeroInteractions(this.authenticationManager);
235+
}
236+
237+
private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(
238+
MockServerHttpRequest authorizationRequest, ClientRegistration registration) {
239+
Map<String, Object> attributes = new HashMap<>();
240+
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
241+
return request()
242+
.attributes(attributes)
243+
.redirectUri(authorizationRequest.getURI().toString())
244+
.build();
245+
}
246+
247+
private static MockServerHttpRequest createAuthorizationRequest(String requestUri) {
248+
return createAuthorizationRequest(requestUri, new LinkedHashMap<>());
249+
}
250+
251+
private static MockServerHttpRequest createAuthorizationRequest(String requestUri, Map<String, String> parameters) {
252+
MockServerHttpRequest.BaseBuilder<?> builder = MockServerHttpRequest
253+
.get(requestUri);
254+
if (!CollectionUtils.isEmpty(parameters)) {
255+
parameters.forEach(builder::queryParam);
256+
}
257+
return builder.build();
258+
}
259+
260+
private static MockServerHttpRequest createAuthorizationResponse(MockServerHttpRequest authorizationRequest) {
261+
return createAuthorizationResponse(authorizationRequest, new LinkedHashMap<>());
262+
}
263+
264+
private static MockServerHttpRequest createAuthorizationResponse(
265+
MockServerHttpRequest authorizationRequest, Map<String, String> additionalParameters) {
266+
MockServerHttpRequest.BaseBuilder<?> builder = MockServerHttpRequest
267+
.get(authorizationRequest.getURI().toString());
268+
builder.queryParam(OAuth2ParameterNames.CODE, "code");
269+
builder.queryParam(OAuth2ParameterNames.STATE, "state");
270+
additionalParameters.forEach(builder::queryParam);
271+
builder.cookies(authorizationRequest.getCookies());
272+
return builder.build();
273+
}
153274
}

0 commit comments

Comments
 (0)