1
1
/*
2
- * Copyright 2002-2019 the original author or authors.
2
+ * Copyright 2002-2020 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
25
25
import org .springframework .mock .web .server .MockServerWebExchange ;
26
26
import org .springframework .security .authentication .AnonymousAuthenticationToken ;
27
27
import org .springframework .security .authentication .ReactiveAuthenticationManager ;
28
- import org .springframework .security .core .Authentication ;
29
- import org .springframework .security .oauth2 .client .authentication .OAuth2AuthorizationCodeAuthenticationToken ;
30
28
import org .springframework .security .oauth2 .client .authentication .TestOAuth2AuthorizationCodeAuthenticationTokens ;
31
29
import org .springframework .security .oauth2 .client .registration .ClientRegistration ;
32
30
import org .springframework .security .oauth2 .client .registration .ReactiveClientRegistrationRepository ;
33
31
import org .springframework .security .oauth2 .client .registration .TestClientRegistrations ;
34
- import org .springframework .security .oauth2 .core .endpoint .OAuth2AuthorizationExchange ;
35
32
import org .springframework .security .oauth2 .core .endpoint .OAuth2AuthorizationRequest ;
36
- import org .springframework .security .oauth2 .core .endpoint .OAuth2AuthorizationResponse ;
37
33
import org .springframework .security .oauth2 .core .endpoint .OAuth2ParameterNames ;
38
- import org .springframework .security .oauth2 .core .endpoint .TestOAuth2AuthorizationRequests ;
39
- import org .springframework .security .oauth2 .core .endpoint .TestOAuth2AuthorizationResponses ;
40
- import org .springframework .security .web .server .authentication .ServerAuthenticationConverter ;
34
+ import org .springframework .util .CollectionUtils ;
41
35
import org .springframework .web .server .handler .DefaultWebFilterChain ;
42
36
import reactor .core .publisher .Mono ;
43
37
38
+ import java .util .Collections ;
39
+ import java .util .HashMap ;
40
+ import java .util .LinkedHashMap ;
41
+ import java .util .Map ;
42
+
44
43
import static org .assertj .core .api .Assertions .assertThatCode ;
45
44
import static org .mockito .ArgumentMatchers .any ;
46
- import static org .mockito .Mockito .*;
45
+ import static org .mockito .Mockito .times ;
46
+ import static org .mockito .Mockito .verify ;
47
+ import static org .mockito .Mockito .verifyZeroInteractions ;
48
+ import static org .mockito .Mockito .when ;
49
+ import static org .springframework .security .oauth2 .core .endpoint .TestOAuth2AuthorizationRequests .request ;
47
50
48
51
/**
49
52
* @author Rob Winch
@@ -102,7 +105,7 @@ public void filterWhenNotMatchThenAuthenticationManagerNotCalled() {
102
105
MockServerWebExchange exchange = MockServerWebExchange
103
106
.from (MockServerHttpRequest .get ("/" ));
104
107
DefaultWebFilterChain chain = new DefaultWebFilterChain (
105
- e -> e .getResponse ().setComplete ());
108
+ e -> e .getResponse ().setComplete (), Collections . emptyList () );
106
109
107
110
this .filter .filter (exchange , chain ).block ();
108
111
@@ -111,43 +114,161 @@ public void filterWhenNotMatchThenAuthenticationManagerNotCalled() {
111
114
112
115
@ Test
113
116
public void filterWhenMatchThenAuthorizedClientSaved () {
114
- OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests .request ()
115
- .redirectUri ("/authorize/registration-id" )
116
- .build ();
117
- OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses .success ()
118
- .redirectUri ("/authorize/registration-id" )
119
- .build ();
120
- OAuth2AuthorizationExchange authorizationExchange =
121
- new OAuth2AuthorizationExchange (authorizationRequest , authorizationResponse );
122
- ClientRegistration registration = TestClientRegistrations .clientRegistration ().build ();
123
- Mono <Authentication > authentication = Mono .just (
124
- new OAuth2AuthorizationCodeAuthenticationToken (registration , authorizationExchange ));
125
- OAuth2AuthorizationCodeAuthenticationToken authenticated = TestOAuth2AuthorizationCodeAuthenticationTokens
126
- .authenticated ();
127
-
128
- when (this .authenticationManager .authenticate (any ())).thenReturn (
129
- Mono .just (authenticated ));
117
+ ClientRegistration clientRegistration = TestClientRegistrations .clientRegistration ().build ();
118
+ when (this .clientRegistrationRepository .findByRegistrationId (any ()))
119
+ .thenReturn (Mono .just (clientRegistration ));
120
+
121
+ MockServerHttpRequest authorizationRequest =
122
+ createAuthorizationRequest ("/authorization/callback" );
123
+ OAuth2AuthorizationRequest oauth2AuthorizationRequest =
124
+ createOAuth2AuthorizationRequest (authorizationRequest , clientRegistration );
130
125
when (this .authorizationRequestRepository .loadAuthorizationRequest (any ()))
131
- .thenReturn (Mono .just (authorizationRequest ));
126
+ .thenReturn (Mono .just (oauth2AuthorizationRequest ));
127
+ when (this .authorizationRequestRepository .removeAuthorizationRequest (any ()))
128
+ .thenReturn (Mono .just (oauth2AuthorizationRequest ));
129
+
132
130
when (this .authorizedClientRepository .saveAuthorizedClient (any (), any (), any ()))
133
131
.thenReturn (Mono .empty ());
134
- ServerAuthenticationConverter converter = e -> authentication ;
132
+ when (this .authenticationManager .authenticate (any ()))
133
+ .thenReturn (Mono .just (TestOAuth2AuthorizationCodeAuthenticationTokens .authenticated ()));
135
134
136
- this .filter = new OAuth2AuthorizationCodeGrantWebFilter (
137
- this .authenticationManager , converter , this .authorizedClientRepository );
138
- this .filter .setAuthorizationRequestRepository (this .authorizationRequestRepository );
139
-
140
- MockServerHttpRequest request = MockServerHttpRequest
141
- .get ("/authorize/registration-id" )
142
- .queryParam (OAuth2ParameterNames .CODE , "code" )
143
- .queryParam (OAuth2ParameterNames .STATE , "state" )
144
- .build ();
145
- MockServerWebExchange exchange = MockServerWebExchange .from (request );
135
+ MockServerHttpRequest authorizationResponse = createAuthorizationResponse (authorizationRequest );
136
+ MockServerWebExchange exchange = MockServerWebExchange .from (authorizationResponse );
146
137
DefaultWebFilterChain chain = new DefaultWebFilterChain (
147
- e -> e .getResponse ().setComplete ());
138
+ e -> e .getResponse ().setComplete (), Collections . emptyList () );
148
139
149
140
this .filter .filter (exchange , chain ).block ();
150
141
151
142
verify (this .authorizedClientRepository ).saveAuthorizedClient (any (), any (AnonymousAuthenticationToken .class ), any ());
152
143
}
144
+
145
+ // gh-7966
146
+ @ Test
147
+ public void filterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed () {
148
+ ClientRegistration clientRegistration = TestClientRegistrations .clientRegistration ().build ();
149
+ when (this .clientRegistrationRepository .findByRegistrationId (any ()))
150
+ .thenReturn (Mono .just (clientRegistration ));
151
+ when (this .authorizedClientRepository .saveAuthorizedClient (any (), any (), any ()))
152
+ .thenReturn (Mono .empty ());
153
+ when (this .authenticationManager .authenticate (any ()))
154
+ .thenReturn (Mono .just (TestOAuth2AuthorizationCodeAuthenticationTokens .authenticated ()));
155
+
156
+ // 1) redirect_uri with query parameters
157
+ Map <String , String > parameters = new LinkedHashMap <>();
158
+ parameters .put ("param1" , "value1" );
159
+ parameters .put ("param2" , "value2" );
160
+ MockServerHttpRequest authorizationRequest =
161
+ createAuthorizationRequest ("/authorization/callback" , parameters );
162
+ OAuth2AuthorizationRequest oauth2AuthorizationRequest =
163
+ createOAuth2AuthorizationRequest (authorizationRequest , clientRegistration );
164
+ when (this .authorizationRequestRepository .loadAuthorizationRequest (any ()))
165
+ .thenReturn (Mono .just (oauth2AuthorizationRequest ));
166
+ when (this .authorizationRequestRepository .removeAuthorizationRequest (any ()))
167
+ .thenReturn (Mono .just (oauth2AuthorizationRequest ));
168
+
169
+ MockServerHttpRequest authorizationResponse = createAuthorizationResponse (authorizationRequest );
170
+ MockServerWebExchange exchange = MockServerWebExchange .from (authorizationResponse );
171
+ DefaultWebFilterChain chain = new DefaultWebFilterChain (
172
+ e -> e .getResponse ().setComplete (), Collections .emptyList ());
173
+
174
+ this .filter .filter (exchange , chain ).block ();
175
+ verify (this .authenticationManager , times (1 )).authenticate (any ());
176
+
177
+ // 2) redirect_uri with query parameters AND authorization response additional parameters
178
+ Map <String , String > additionalParameters = new LinkedHashMap <>();
179
+ additionalParameters .put ("auth-param1" , "value1" );
180
+ additionalParameters .put ("auth-param2" , "value2" );
181
+ authorizationResponse = createAuthorizationResponse (authorizationRequest , additionalParameters );
182
+ exchange = MockServerWebExchange .from (authorizationResponse );
183
+
184
+ this .filter .filter (exchange , chain ).block ();
185
+ verify (this .authenticationManager , times (2 )).authenticate (any ());
186
+ }
187
+
188
+ // gh-7966
189
+ @ Test
190
+ public void filterWhenAuthorizationRequestRedirectUriParametersNotMatchThenNotProcessed () {
191
+ String requestUri = "/authorization/callback" ;
192
+ Map <String , String > parameters = new LinkedHashMap <>();
193
+ parameters .put ("param1" , "value1" );
194
+ parameters .put ("param2" , "value2" );
195
+ MockServerHttpRequest authorizationRequest =
196
+ createAuthorizationRequest (requestUri , parameters );
197
+ ClientRegistration clientRegistration = TestClientRegistrations .clientRegistration ().build ();
198
+ OAuth2AuthorizationRequest oauth2AuthorizationRequest =
199
+ createOAuth2AuthorizationRequest (authorizationRequest , clientRegistration );
200
+ when (this .authorizationRequestRepository .loadAuthorizationRequest (any ()))
201
+ .thenReturn (Mono .just (oauth2AuthorizationRequest ));
202
+
203
+ // 1) Parameter value
204
+ Map <String , String > parametersNotMatch = new LinkedHashMap <>(parameters );
205
+ parametersNotMatch .put ("param2" , "value8" );
206
+ MockServerHttpRequest authorizationResponse = createAuthorizationResponse (
207
+ createAuthorizationRequest (requestUri , parametersNotMatch ));
208
+ MockServerWebExchange exchange = MockServerWebExchange .from (authorizationResponse );
209
+ DefaultWebFilterChain chain = new DefaultWebFilterChain (
210
+ e -> e .getResponse ().setComplete (), Collections .emptyList ());
211
+
212
+ this .filter .filter (exchange , chain ).block ();
213
+ verifyZeroInteractions (this .authenticationManager );
214
+
215
+ // 2) Parameter order
216
+ parametersNotMatch = new LinkedHashMap <>();
217
+ parametersNotMatch .put ("param2" , "value2" );
218
+ parametersNotMatch .put ("param1" , "value1" );
219
+ authorizationResponse = createAuthorizationResponse (
220
+ createAuthorizationRequest (requestUri , parametersNotMatch ));
221
+ exchange = MockServerWebExchange .from (authorizationResponse );
222
+
223
+ this .filter .filter (exchange , chain ).block ();
224
+ verifyZeroInteractions (this .authenticationManager );
225
+
226
+ // 3) Parameter missing
227
+ parametersNotMatch = new LinkedHashMap <>(parameters );
228
+ parametersNotMatch .remove ("param2" );
229
+ authorizationResponse = createAuthorizationResponse (
230
+ createAuthorizationRequest (requestUri , parametersNotMatch ));
231
+ exchange = MockServerWebExchange .from (authorizationResponse );
232
+
233
+ this .filter .filter (exchange , chain ).block ();
234
+ verifyZeroInteractions (this .authenticationManager );
235
+ }
236
+
237
+ private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest (
238
+ MockServerHttpRequest authorizationRequest , ClientRegistration registration ) {
239
+ Map <String , Object > attributes = new HashMap <>();
240
+ attributes .put (OAuth2ParameterNames .REGISTRATION_ID , registration .getRegistrationId ());
241
+ return request ()
242
+ .attributes (attributes )
243
+ .redirectUri (authorizationRequest .getURI ().toString ())
244
+ .build ();
245
+ }
246
+
247
+ private static MockServerHttpRequest createAuthorizationRequest (String requestUri ) {
248
+ return createAuthorizationRequest (requestUri , new LinkedHashMap <>());
249
+ }
250
+
251
+ private static MockServerHttpRequest createAuthorizationRequest (String requestUri , Map <String , String > parameters ) {
252
+ MockServerHttpRequest .BaseBuilder <?> builder = MockServerHttpRequest
253
+ .get (requestUri );
254
+ if (!CollectionUtils .isEmpty (parameters )) {
255
+ parameters .forEach (builder ::queryParam );
256
+ }
257
+ return builder .build ();
258
+ }
259
+
260
+ private static MockServerHttpRequest createAuthorizationResponse (MockServerHttpRequest authorizationRequest ) {
261
+ return createAuthorizationResponse (authorizationRequest , new LinkedHashMap <>());
262
+ }
263
+
264
+ private static MockServerHttpRequest createAuthorizationResponse (
265
+ MockServerHttpRequest authorizationRequest , Map <String , String > additionalParameters ) {
266
+ MockServerHttpRequest .BaseBuilder <?> builder = MockServerHttpRequest
267
+ .get (authorizationRequest .getURI ().toString ());
268
+ builder .queryParam (OAuth2ParameterNames .CODE , "code" );
269
+ builder .queryParam (OAuth2ParameterNames .STATE , "state" );
270
+ additionalParameters .forEach (builder ::queryParam );
271
+ builder .cookies (authorizationRequest .getCookies ());
272
+ return builder .build ();
273
+ }
153
274
}
0 commit comments