25
25
import org .junit .Test ;
26
26
import org .mockito .stubbing .Answer ;
27
27
import org .openqa .selenium .WebDriver ;
28
+
29
+ import org .springframework .security .core .AuthenticationException ;
30
+ import org .springframework .security .oauth2 .core .OAuth2AuthenticationException ;
31
+ import org .springframework .security .oauth2 .core .OAuth2Error ;
28
32
import org .springframework .security .web .server .WebFilterExchange ;
33
+ import org .springframework .security .web .server .authentication .RedirectServerAuthenticationFailureHandler ;
29
34
import org .springframework .security .web .server .authentication .RedirectServerAuthenticationSuccessHandler ;
35
+ import org .springframework .security .web .server .authentication .ServerAuthenticationFailureHandler ;
30
36
import org .springframework .security .web .server .authentication .ServerAuthenticationSuccessHandler ;
31
37
import reactor .core .publisher .Mono ;
32
38
97
103
98
104
/**
99
105
* @author Rob Winch
106
+ * @author Eddú Meléndez
100
107
* @since 5.1
101
108
*/
102
109
public class OAuth2LoginTests {
@@ -233,6 +240,59 @@ public void oauth2LoginWhenCustomObjectsThenUsed() {
233
240
verify (successHandler ).onAuthenticationSuccess (any (), any ());
234
241
}
235
242
243
+ @ Test
244
+ public void oauth2LoginFailsWhenCustomObjectsThenUsed () {
245
+ this .spring .register (OAuth2LoginWithSingleClientRegistrations .class ,
246
+ OAuth2LoginMockAuthenticationManagerConfig .class ).autowire ();
247
+
248
+ String redirectLocation = "/custom-redirect-location" ;
249
+ String failureRedirectLocation = "/failure-redirect-location" ;
250
+
251
+ WebTestClient webTestClient = WebTestClientBuilder
252
+ .bindToWebFilters (this .springSecurity )
253
+ .build ();
254
+
255
+ OAuth2LoginMockAuthenticationManagerConfig config = this .spring .getContext ()
256
+ .getBean (OAuth2LoginMockAuthenticationManagerConfig .class );
257
+ ServerAuthenticationConverter converter = config .authenticationConverter ;
258
+ ReactiveAuthenticationManager manager = config .manager ;
259
+ ServerWebExchangeMatcher matcher = config .matcher ;
260
+ ServerOAuth2AuthorizationRequestResolver resolver = config .resolver ;
261
+ ServerAuthenticationSuccessHandler successHandler = config .successHandler ;
262
+ ServerAuthenticationFailureHandler failureHandler = config .failureHandler ;
263
+
264
+ when (converter .convert (any ())).thenReturn (Mono .just (new TestingAuthenticationToken ("a" , "b" , "c" )));
265
+ when (manager .authenticate (any ())).thenReturn (Mono .error (new OAuth2AuthenticationException (new OAuth2Error ("error" ), "message" )));
266
+ when (matcher .matches (any ())).thenReturn (ServerWebExchangeMatcher .MatchResult .match ());
267
+ when (resolver .resolve (any ())).thenReturn (Mono .empty ());
268
+ when (successHandler .onAuthenticationSuccess (any (), any ())).thenAnswer ((Answer <Mono <Void >>) invocation -> {
269
+ WebFilterExchange webFilterExchange = invocation .getArgument (0 );
270
+ Authentication authentication = invocation .getArgument (1 );
271
+
272
+ return new RedirectServerAuthenticationSuccessHandler (redirectLocation )
273
+ .onAuthenticationSuccess (webFilterExchange , authentication );
274
+ });
275
+ when (failureHandler .onAuthenticationFailure (any (), any ())).thenAnswer ((Answer <Mono <Void >>) invocation -> {
276
+ WebFilterExchange webFilterExchange = invocation .getArgument (0 );
277
+ AuthenticationException authenticationException = invocation .getArgument (1 );
278
+
279
+ return new RedirectServerAuthenticationFailureHandler (failureRedirectLocation )
280
+ .onAuthenticationFailure (webFilterExchange , authenticationException );
281
+ });
282
+
283
+ webTestClient .get ()
284
+ .uri ("/login/oauth2/code/github" )
285
+ .exchange ()
286
+ .expectStatus ().is3xxRedirection ()
287
+ .expectHeader ().valueEquals ("Location" , failureRedirectLocation );
288
+
289
+ verify (converter ).convert (any ());
290
+ verify (manager ).authenticate (any ());
291
+ verify (matcher ).matches (any ());
292
+ verify (resolver ).resolve (any ());
293
+ verify (failureHandler ).onAuthenticationFailure (any (), any ());
294
+ }
295
+
236
296
@ Configuration
237
297
static class OAuth2LoginMockAuthenticationManagerConfig {
238
298
ReactiveAuthenticationManager manager = mock (ReactiveAuthenticationManager .class );
@@ -245,6 +305,8 @@ static class OAuth2LoginMockAuthenticationManagerConfig {
245
305
246
306
ServerAuthenticationSuccessHandler successHandler = mock (ServerAuthenticationSuccessHandler .class );
247
307
308
+ ServerAuthenticationFailureHandler failureHandler = mock (ServerAuthenticationFailureHandler .class );
309
+
248
310
@ Bean
249
311
public SecurityWebFilterChain springSecurityFilter (ServerHttpSecurity http ) {
250
312
http
@@ -256,7 +318,8 @@ public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
256
318
.authenticationManager (manager )
257
319
.authenticationMatcher (matcher )
258
320
.authorizationRequestResolver (resolver )
259
- .authenticationSuccessHandler (successHandler );
321
+ .authenticationSuccessHandler (successHandler )
322
+ .authenticationFailureHandler (failureHandler );
260
323
return http .build ();
261
324
}
262
325
}
0 commit comments