Skip to content

Commit 7543eff

Browse files
committed
Add SecurityContextHolderStrategy Java Configuration for OAuth2
Issue gh-11061
1 parent 7e38411 commit 7543eff

File tree

5 files changed

+80
-4
lines changed

5 files changed

+80
-4
lines changed

config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java

+15-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@
2323
import org.springframework.context.annotation.Import;
2424
import org.springframework.context.annotation.ImportSelector;
2525
import org.springframework.core.type.AnnotationMetadata;
26+
import org.springframework.security.core.context.SecurityContextHolderStrategy;
2627
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
2728
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
2829
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder;
@@ -75,11 +76,18 @@ static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer
7576

7677
private OAuth2AuthorizedClientManager authorizedClientManager;
7778

79+
private SecurityContextHolderStrategy securityContextHolderStrategy;
80+
7881
@Override
7982
public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
8083
OAuth2AuthorizedClientManager authorizedClientManager = getAuthorizedClientManager();
8184
if (authorizedClientManager != null) {
82-
argumentResolvers.add(new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager));
85+
OAuth2AuthorizedClientArgumentResolver resolver = new OAuth2AuthorizedClientArgumentResolver(
86+
authorizedClientManager);
87+
if (this.securityContextHolderStrategy != null) {
88+
resolver.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
89+
}
90+
argumentResolvers.add(resolver);
8391
}
8492
}
8593

@@ -110,6 +118,11 @@ void setAuthorizedClientManager(List<OAuth2AuthorizedClientManager> authorizedCl
110118
}
111119
}
112120

121+
@Autowired(required = false)
122+
void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) {
123+
this.securityContextHolderStrategy = strategy;
124+
}
125+
113126
private OAuth2AuthorizedClientManager getAuthorizedClientManager() {
114127
if (this.authorizedClientManager != null) {
115128
return this.authorizedClientManager;

config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -289,6 +289,7 @@ private OAuth2AuthorizationCodeGrantFilter createAuthorizationCodeGrantFilter(B
289289
if (this.authorizationRequestRepository != null) {
290290
authorizationCodeGrantFilter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
291291
}
292+
authorizationCodeGrantFilter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy());
292293
RequestCache requestCache = builder.getSharedObject(RequestCache.class);
293294
if (requestCache != null) {
294295
authorizationCodeGrantFilter.setRequestCache(requestCache);

config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurer.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -265,6 +265,7 @@ public void configure(H http) {
265265
BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(resolver);
266266
filter.setBearerTokenResolver(bearerTokenResolver);
267267
filter.setAuthenticationEntryPoint(this.authenticationEntryPoint);
268+
filter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy());
268269
filter = postProcess(filter);
269270
http.addFilter(filter);
270271
}

config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java

+25
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.springframework.mock.web.MockHttpServletRequest;
4343
import org.springframework.mock.web.MockHttpServletResponse;
4444
import org.springframework.security.authentication.event.AuthenticationSuccessEvent;
45+
import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
4546
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
4647
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
4748
import org.springframework.security.config.oauth2.client.CommonOAuth2Provider;
@@ -52,6 +53,8 @@
5253
import org.springframework.security.core.authority.AuthorityUtils;
5354
import org.springframework.security.core.authority.SimpleGrantedAuthority;
5455
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
56+
import org.springframework.security.core.context.SecurityContextChangedListener;
57+
import org.springframework.security.core.context.SecurityContextHolderStrategy;
5558
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
5659
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
5760
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
@@ -102,7 +105,10 @@
102105
import static org.mockito.ArgumentMatchers.anyString;
103106
import static org.mockito.BDDMockito.given;
104107
import static org.mockito.BDDMockito.then;
108+
import static org.mockito.Mockito.atLeastOnce;
105109
import static org.mockito.Mockito.mock;
110+
import static org.mockito.Mockito.verify;
111+
import static org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.setAuthentication;
106112
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
107113
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
108114
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@@ -196,6 +202,25 @@ public void oauth2Login() throws Exception {
196202
.hasToString("OAUTH2_USER");
197203
}
198204

205+
@Test
206+
public void requestWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
207+
loadConfig(OAuth2LoginConfig.class, SecurityContextChangedListenerConfig.class);
208+
OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest();
209+
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, this.request, this.response);
210+
this.request.setParameter("code", "code123");
211+
this.request.setParameter("state", authorizationRequest.getState());
212+
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
213+
Authentication authentication = this.securityContextRepository
214+
.loadContext(new HttpRequestResponseHolder(this.request, this.response)).getAuthentication();
215+
assertThat(authentication.getAuthorities()).hasSize(1);
216+
assertThat(authentication.getAuthorities()).first().isInstanceOf(OAuth2UserAuthority.class)
217+
.hasToString("OAUTH2_USER");
218+
SecurityContextHolderStrategy strategy = this.context.getBean(SecurityContextHolderStrategy.class);
219+
verify(strategy, atLeastOnce()).getDeferredContext();
220+
SecurityContextChangedListener listener = this.context.getBean(SecurityContextChangedListener.class);
221+
verify(listener).securityContextChanged(setAuthentication(OAuth2AuthenticationToken.class));
222+
}
223+
199224
@Test
200225
public void requestWhenOauth2LoginInLambdaThenAuthenticationContainsOauth2UserAuthority() throws Exception {
201226
loadConfig(OAuth2LoginInLambdaConfig.class);

config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java

+36
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import org.hamcrest.core.StringStartsWith;
5151
import org.junit.jupiter.api.Test;
5252
import org.junit.jupiter.api.extension.ExtendWith;
53+
import org.mockito.verification.VerificationMode;
5354

5455
import org.springframework.beans.factory.BeanCreationException;
5556
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
@@ -82,6 +83,7 @@
8283
import org.springframework.security.authentication.AuthenticationServiceException;
8384
import org.springframework.security.authentication.TestingAuthenticationToken;
8485
import org.springframework.security.config.annotation.ObjectPostProcessor;
86+
import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
8587
import org.springframework.security.config.annotation.method.configuration.EnableGlobalMethodSecurity;
8688
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
8789
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
@@ -92,6 +94,8 @@
9294
import org.springframework.security.core.Authentication;
9395
import org.springframework.security.core.GrantedAuthority;
9496
import org.springframework.security.core.authority.SimpleGrantedAuthority;
97+
import org.springframework.security.core.context.SecurityContextChangedListener;
98+
import org.springframework.security.core.context.SecurityContextHolderStrategy;
9599
import org.springframework.security.core.userdetails.UserDetailsService;
96100
import org.springframework.security.oauth2.client.registration.ClientRegistration;
97101
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
@@ -153,6 +157,7 @@
153157
import static org.mockito.ArgumentMatchers.anyString;
154158
import static org.mockito.ArgumentMatchers.eq;
155159
import static org.mockito.BDDMockito.given;
160+
import static org.mockito.Mockito.atLeastOnce;
156161
import static org.mockito.Mockito.mock;
157162
import static org.mockito.Mockito.never;
158163
import static org.mockito.Mockito.verify;
@@ -218,6 +223,33 @@ public void getWhenUsingDefaultsWithValidBearerTokenThenAcceptsRequest() throws
218223
// @formatter:on
219224
}
220225

226+
@Test
227+
public void getWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
228+
this.spring.register(RestOperationsConfig.class, DefaultConfig.class, BasicController.class, SecurityContextChangedListenerConfig.class).autowire();
229+
mockRestOperations(jwks("Default"));
230+
String token = this.token("ValidNoScopes");
231+
// @formatter:off
232+
this.mvc.perform(get("/").with(bearerToken(token)))
233+
.andExpect(status().isOk())
234+
.andExpect(content().string("ok"));
235+
// @formatter:on
236+
verifyBean(SecurityContextHolderStrategy.class, atLeastOnce()).getContext();
237+
}
238+
239+
@Test
240+
public void getWhenSecurityContextHolderStrategyThenUses() throws Exception {
241+
this.spring.register(RestOperationsConfig.class, DefaultConfig.class,
242+
SecurityContextChangedListenerConfig.class, BasicController.class).autowire();
243+
mockRestOperations(jwks("Default"));
244+
String token = this.token("ValidNoScopes");
245+
// @formatter:off
246+
this.mvc.perform(get("/").with(bearerToken(token)))
247+
.andExpect(status().isOk())
248+
.andExpect(content().string("ok"));
249+
// @formatter:on
250+
verifyBean(SecurityContextChangedListener.class, atLeastOnce()).securityContextChanged(any());
251+
}
252+
221253
@Test
222254
public void getWhenUsingDefaultsInLambdaWithValidBearerTokenThenAcceptsRequest() throws Exception {
223255
this.spring.register(RestOperationsConfig.class, DefaultInLambdaConfig.class, BasicController.class).autowire();
@@ -1435,6 +1467,10 @@ private <T> T verifyBean(Class<T> beanClass) {
14351467
return verify(this.spring.getContext().getBean(beanClass));
14361468
}
14371469

1470+
private <T> T verifyBean(Class<T> beanClass, VerificationMode mode) {
1471+
return verify(this.spring.getContext().getBean(beanClass), mode);
1472+
}
1473+
14381474
private String json(String name) throws IOException {
14391475
return resource(name + ".json");
14401476
}

0 commit comments

Comments
 (0)