Skip to content

Commit 96e3e4f

Browse files
committed
Customize when user info is called
Closes gh-13259
1 parent 27b370b commit 96e3e4f

File tree

4 files changed

+143
-4
lines changed

4 files changed

+143
-4
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java

+26-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
2222
import java.util.Map;
2323
import java.util.Set;
2424
import java.util.function.Function;
25+
import java.util.function.Predicate;
2526

2627
import reactor.core.publisher.Mono;
2728

@@ -33,6 +34,7 @@
3334
import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService;
3435
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
3536
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
37+
import org.springframework.security.oauth2.core.AuthorizationGrantType;
3638
import org.springframework.security.oauth2.core.OAuth2AccessToken;
3739
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
3840
import org.springframework.security.oauth2.core.OAuth2Error;
@@ -71,6 +73,8 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
7173
private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
7274
clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
7375

76+
private Predicate<OidcUserRequest> retrieveUserInfo = OidcUserRequestUtils::shouldRetrieveUserInfo;
77+
7478
/**
7579
* Returns the default {@link Converter}'s used for type conversion of claim values
7680
* for an {@link OidcUserInfo}.
@@ -123,7 +127,7 @@ public Mono<OidcUser> loadUser(OidcUserRequest userRequest) throws OAuth2Authent
123127
}
124128

125129
private Mono<OidcUserInfo> getUserInfo(OidcUserRequest userRequest) {
126-
if (!OidcUserRequestUtils.shouldRetrieveUserInfo(userRequest)) {
130+
if (!this.retrieveUserInfo.test(userRequest)) {
127131
return Mono.empty();
128132
}
129133
// @formatter:off
@@ -169,4 +173,24 @@ public final void setClaimTypeConverterFactory(
169173
this.claimTypeConverterFactory = claimTypeConverterFactory;
170174
}
171175

176+
/**
177+
* Sets the {@code Predicate} used to determine if the UserInfo Endpoint should be
178+
* called to retrieve information about the End-User (Resource Owner).
179+
* <p>
180+
* By default, the UserInfo Endpoint is called if all of the following are true:
181+
* <ul>
182+
* <li>The user info endpoint is defined on the ClientRegistration</li>
183+
* <li>The Client Registration uses the
184+
* {@link AuthorizationGrantType#AUTHORIZATION_CODE} and scopes in the access token
185+
* are defined in the {@link ClientRegistration}</li>
186+
* </ul>
187+
* @param retrieveUserInfo the function used to determine if the UserInfo Endpoint
188+
* should be called
189+
* @since 6.3
190+
*/
191+
public final void setRetrieveUserInfo(Predicate<OidcUserRequest> retrieveUserInfo) {
192+
Assert.notNull(retrieveUserInfo, "retrieveUserInfo cannot be null");
193+
this.retrieveUserInfo = retrieveUserInfo;
194+
}
195+
172196
}

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java

+30-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@
2424
import java.util.Map;
2525
import java.util.Set;
2626
import java.util.function.Function;
27+
import java.util.function.Predicate;
2728

2829
import org.springframework.core.convert.TypeDescriptor;
2930
import org.springframework.core.convert.converter.Converter;
@@ -78,6 +79,8 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
7879
private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
7980
clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
8081

82+
private Predicate<OidcUserRequest> retrieveUserInfo = this::shouldRetrieveUserInfo;
83+
8184
/**
8285
* Returns the default {@link Converter}'s used for type conversion of claim values
8386
* for an {@link OidcUserInfo}.
@@ -105,7 +108,7 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
105108
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
106109
Assert.notNull(userRequest, "userRequest cannot be null");
107110
OidcUserInfo userInfo = null;
108-
if (this.shouldRetrieveUserInfo(userRequest)) {
111+
if (this.retrieveUserInfo.test(userRequest)) {
109112
OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest);
110113
Map<String, Object> claims = getClaims(userRequest, oauth2User);
111114
userInfo = new OidcUserInfo(claims);
@@ -221,10 +224,35 @@ public final void setClaimTypeConverterFactory(
221224
* resource will be requested, otherwise it will not.
222225
* @param accessibleScopes the scope(s) that allow access to the user info resource
223226
* @since 5.2
227+
* @deprecated Use {@link #setRetrieveUserInfo(Predicate)} instead
224228
*/
229+
@Deprecated(since = "6.3", forRemoval = true)
225230
public final void setAccessibleScopes(Set<String> accessibleScopes) {
226231
Assert.notNull(accessibleScopes, "accessibleScopes cannot be null");
227232
this.accessibleScopes = accessibleScopes;
228233
}
229234

235+
/**
236+
* Sets the {@code Predicate} used to determine if the UserInfo Endpoint should be
237+
* called to retrieve information about the End-User (Resource Owner).
238+
* <p>
239+
* By default, the UserInfo Endpoint is called if all of the following are true:
240+
* <ul>
241+
* <li>The user info endpoint is defined on the ClientRegistration</li>
242+
* <li>The Client Registration uses the
243+
* {@link AuthorizationGrantType#AUTHORIZATION_CODE}</li>
244+
* <li>The access token contains one or more scopes allowed to access the UserInfo
245+
* Endpoint ({@link OidcScopes#PROFILE profile}, {@link OidcScopes#EMAIL email},
246+
* {@link OidcScopes#ADDRESS address} or {@link OidcScopes#PHONE phone}) or the access
247+
* token scopes are empty</li>
248+
* </ul>
249+
* @param retrieveUserInfo the function used to determine if the UserInfo Endpoint
250+
* should be called
251+
* @since 6.3
252+
*/
253+
public final void setRetrieveUserInfo(Predicate<OidcUserRequest> retrieveUserInfo) {
254+
Assert.notNull(retrieveUserInfo, "retrieveUserInfo cannot be null");
255+
this.retrieveUserInfo = retrieveUserInfo;
256+
}
257+
230258
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java

+52
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.Iterator;
2525
import java.util.Map;
2626
import java.util.function.Function;
27+
import java.util.function.Predicate;
2728

2829
import okhttp3.mockwebserver.MockResponse;
2930
import okhttp3.mockwebserver.MockWebServer;
@@ -107,6 +108,15 @@ public void setClaimTypeConverterFactoryWhenNullThenThrowIllegalArgumentExceptio
107108
assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setClaimTypeConverterFactory(null));
108109
}
109110

111+
@Test
112+
public void setRetrieveUserInfoWhenNullThenThrowIllegalArgumentException() {
113+
// @formatter:off
114+
assertThatIllegalArgumentException()
115+
.isThrownBy(() -> this.userService.setRetrieveUserInfo(null))
116+
.withMessage("retrieveUserInfo cannot be null");
117+
// @formatter:on
118+
}
119+
110120
@Test
111121
public void loadUserWhenUserInfoUriNullThenUserInfoNotRetrieved() {
112122
this.registration.userInfoUri(null);
@@ -183,6 +193,48 @@ public void loadUserWhenCustomClaimTypeConverterFactorySetThenApplied() {
183193
verify(customClaimTypeConverterFactory).apply(same(userRequest.getClientRegistration()));
184194
}
185195

196+
@Test
197+
public void loadUserWhenTokenScopesIsEmptyThenUserInfoNotRetrieved() {
198+
// @formatter:off
199+
OAuth2AccessToken accessToken = new OAuth2AccessToken(
200+
this.accessToken.getTokenType(),
201+
this.accessToken.getTokenValue(),
202+
this.accessToken.getIssuedAt(),
203+
this.accessToken.getExpiresAt(),
204+
Collections.emptySet());
205+
// @formatter:on
206+
OidcUserRequest userRequest = new OidcUserRequest(this.registration.build(), accessToken, this.idToken);
207+
OidcUser oidcUser = this.userService.loadUser(userRequest).block();
208+
assertThat(oidcUser).isNotNull();
209+
assertThat(oidcUser.getUserInfo()).isNull();
210+
}
211+
212+
@Test
213+
public void loadUserWhenCustomRetrieveUserInfoSetThenUsed() {
214+
Map<String, Object> attributes = new HashMap<>();
215+
attributes.put(StandardClaimNames.SUB, "subject");
216+
attributes.put("user", "steve");
217+
OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), attributes,
218+
"user");
219+
given(this.oauth2UserService.loadUser(any())).willReturn(Mono.just(oauth2User));
220+
Predicate<OidcUserRequest> customRetrieveUserInfo = mock(Predicate.class);
221+
this.userService.setRetrieveUserInfo(customRetrieveUserInfo);
222+
given(customRetrieveUserInfo.test(any(OidcUserRequest.class))).willReturn(true);
223+
// @formatter:off
224+
OAuth2AccessToken accessToken = new OAuth2AccessToken(
225+
this.accessToken.getTokenType(),
226+
this.accessToken.getTokenValue(),
227+
this.accessToken.getIssuedAt(),
228+
this.accessToken.getExpiresAt(),
229+
Collections.emptySet());
230+
// @formatter:on
231+
OidcUserRequest userRequest = new OidcUserRequest(this.registration.build(), accessToken, this.idToken);
232+
OidcUser oidcUser = this.userService.loadUser(userRequest).block();
233+
assertThat(oidcUser).isNotNull();
234+
assertThat(oidcUser.getUserInfo()).isNotNull();
235+
verify(customRetrieveUserInfo).test(userRequest);
236+
}
237+
186238
@Test
187239
public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() {
188240
OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService();

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java

+35
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Map;
2424
import java.util.concurrent.TimeUnit;
2525
import java.util.function.Function;
26+
import java.util.function.Predicate;
2627

2728
import okhttp3.mockwebserver.MockResponse;
2829
import okhttp3.mockwebserver.MockWebServer;
@@ -58,6 +59,7 @@
5859
import static org.assertj.core.api.Assertions.assertThat;
5960
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
6061
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
62+
import static org.mockito.ArgumentMatchers.any;
6163
import static org.mockito.ArgumentMatchers.same;
6264
import static org.mockito.BDDMockito.given;
6365
import static org.mockito.Mockito.mock;
@@ -129,6 +131,15 @@ public void setAccessibleScopesWhenEmptyThenSet() {
129131
this.userService.setAccessibleScopes(Collections.emptySet());
130132
}
131133

134+
@Test
135+
public void setRetrieveUserInfoWhenNullThenThrowIllegalArgumentException() {
136+
// @formatter:off
137+
assertThatIllegalArgumentException()
138+
.isThrownBy(() -> this.userService.setRetrieveUserInfo(null))
139+
.withMessage("retrieveUserInfo cannot be null");
140+
// @formatter:on
141+
}
142+
132143
@Test
133144
public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
134145
assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
@@ -218,6 +229,30 @@ public void loadUserWhenStandardScopesAuthorizedThenUserInfoEndpointRequested()
218229
assertThat(user.getUserInfo()).isNotNull();
219230
}
220231

232+
@Test
233+
public void loadUserWhenCustomRetrieveUserInfoSetThenUsed() {
234+
// @formatter:off
235+
String userInfoResponse = "{\n"
236+
+ " \"sub\": \"subject1\",\n"
237+
+ " \"name\": \"first last\",\n"
238+
+ " \"given_name\": \"first\",\n"
239+
+ " \"family_name\": \"last\",\n"
240+
+ " \"preferred_username\": \"user1\",\n"
241+
+ " \"email\": \"user1@example.com\"\n"
242+
+ "}\n";
243+
// @formatter:on
244+
this.server.enqueue(jsonResponse(userInfoResponse));
245+
String userInfoUri = this.server.url("/user").toString();
246+
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
247+
this.accessToken = TestOAuth2AccessTokens.noScopes();
248+
Predicate<OidcUserRequest> customRetrieveUserInfo = mock(Predicate.class);
249+
given(customRetrieveUserInfo.test(any(OidcUserRequest.class))).willReturn(true);
250+
this.userService.setRetrieveUserInfo(customRetrieveUserInfo);
251+
OidcUser user = this.userService
252+
.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
253+
assertThat(user.getUserInfo()).isNotNull();
254+
}
255+
221256
@Test
222257
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
223258
// @formatter:off

0 commit comments

Comments
 (0)