Skip to content

Commit 4c0af86

Browse files
committed
Add workaround for setting WebClientResolver
Based on spring-projects/spring-security#13274
1 parent cd9e44d commit 4c0af86

File tree

3 files changed

+304
-7
lines changed

3 files changed

+304
-7
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright 2002-2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package no.nav.pensjon.gateway.pensjonazureadappgateway;
18+
19+
import java.util.function.Function;
20+
21+
import org.springframework.security.oauth2.client.oidc.authentication.OidcIdTokenValidator;
22+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
23+
import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
24+
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
25+
import org.springframework.security.oauth2.jwt.Jwt;
26+
import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
27+
28+
/**
29+
* @author Joe Grandja
30+
* @since 5.2
31+
*/
32+
class DefaultOidcIdTokenValidatorFactory implements Function<ClientRegistration, OAuth2TokenValidator<Jwt>> {
33+
34+
@Override
35+
public OAuth2TokenValidator<Jwt> apply(ClientRegistration clientRegistration) {
36+
return new DelegatingOAuth2TokenValidator<>(new JwtTimestampValidator(),
37+
new OidcIdTokenValidator(clientRegistration));
38+
}
39+
40+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
/*
2+
* Copyright 2002-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package no.nav.pensjon.gateway.pensjonazureadappgateway;
18+
19+
import java.net.URL;
20+
import java.nio.charset.StandardCharsets;
21+
import java.time.Instant;
22+
import java.util.Collection;
23+
import java.util.Collections;
24+
import java.util.HashMap;
25+
import java.util.Map;
26+
import java.util.concurrent.ConcurrentHashMap;
27+
import java.util.function.Function;
28+
29+
import javax.crypto.spec.SecretKeySpec;
30+
31+
import org.springframework.core.convert.TypeDescriptor;
32+
import org.springframework.core.convert.converter.Converter;
33+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
34+
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
35+
import org.springframework.security.oauth2.core.OAuth2Error;
36+
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
37+
import org.springframework.security.oauth2.core.converter.ClaimConversionService;
38+
import org.springframework.security.oauth2.core.converter.ClaimTypeConverter;
39+
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
40+
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
41+
import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
42+
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
43+
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
44+
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
45+
import org.springframework.security.oauth2.jwt.Jwt;
46+
import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
47+
import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder;
48+
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
49+
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
50+
import org.springframework.util.Assert;
51+
import org.springframework.util.StringUtils;
52+
import org.springframework.web.reactive.function.client.WebClient;
53+
54+
/**
55+
* A {@link ReactiveJwtDecoderFactory factory} that provides a {@link ReactiveJwtDecoder}
56+
* used for {@link OidcIdToken} signature verification. The provided
57+
* {@link ReactiveJwtDecoder} is associated to a specific {@link ClientRegistration}.
58+
*
59+
* @author Joe Grandja
60+
* @author Rafael Dominguez
61+
* @author Mark Heckler
62+
* @author Ubaid ur Rehman
63+
* @since 5.2
64+
* @see ReactiveJwtDecoderFactory
65+
* @see ClientRegistration
66+
* @see OidcIdToken
67+
*/
68+
public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecoderFactory<ClientRegistration> {
69+
70+
private static final String MISSING_SIGNATURE_VERIFIER_ERROR_CODE = "missing_signature_verifier";
71+
72+
private static final Map<JwsAlgorithm, String> JCA_ALGORITHM_MAPPINGS;
73+
static {
74+
Map<JwsAlgorithm, String> mappings = new HashMap<JwsAlgorithm, String>();
75+
mappings.put(MacAlgorithm.HS256, "HmacSHA256");
76+
mappings.put(MacAlgorithm.HS384, "HmacSHA384");
77+
mappings.put(MacAlgorithm.HS512, "HmacSHA512");
78+
JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings);
79+
}
80+
81+
private static final ClaimTypeConverter DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter(
82+
createDefaultClaimTypeConverters());
83+
84+
private final Map<String, ReactiveJwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
85+
86+
private Function<ClientRegistration, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = new DefaultOidcIdTokenValidatorFactory();
87+
88+
private Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver = (
89+
clientRegistration) -> SignatureAlgorithm.RS256;
90+
91+
private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
92+
clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
93+
94+
private Function<ClientRegistration, WebClient> webClientResolver = (clientRegistration) -> WebClient.create();
95+
96+
/**
97+
* Returns the default {@link Converter}'s used for type conversion of claim values
98+
* for an {@link OidcIdToken}.
99+
* @return a {@link Map} of {@link Converter}'s keyed by {@link IdTokenClaimNames
100+
* claim name}
101+
*/
102+
public static Map<String, Converter<Object, ?>> createDefaultClaimTypeConverters() {
103+
Converter<Object, ?> booleanConverter = getConverter(TypeDescriptor.valueOf(Boolean.class));
104+
Converter<Object, ?> instantConverter = getConverter(TypeDescriptor.valueOf(Instant.class));
105+
Converter<Object, ?> urlConverter = getConverter(TypeDescriptor.valueOf(URL.class));
106+
Converter<Object, ?> stringConverter = getConverter(TypeDescriptor.valueOf(String.class));
107+
Converter<Object, ?> collectionStringConverter = getConverter(
108+
TypeDescriptor.collection(Collection.class, TypeDescriptor.valueOf(String.class)));
109+
Map<String, Converter<Object, ?>> converters = new HashMap<>();
110+
converters.put(IdTokenClaimNames.ISS, urlConverter);
111+
converters.put(IdTokenClaimNames.AUD, collectionStringConverter);
112+
converters.put(IdTokenClaimNames.NONCE, stringConverter);
113+
converters.put(IdTokenClaimNames.EXP, instantConverter);
114+
converters.put(IdTokenClaimNames.IAT, instantConverter);
115+
converters.put(IdTokenClaimNames.AUTH_TIME, instantConverter);
116+
converters.put(IdTokenClaimNames.AMR, collectionStringConverter);
117+
converters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter);
118+
converters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter);
119+
converters.put(StandardClaimNames.UPDATED_AT, instantConverter);
120+
return converters;
121+
}
122+
123+
private static Converter<Object, ?> getConverter(TypeDescriptor targetDescriptor) {
124+
final TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class);
125+
return (source) -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor,
126+
targetDescriptor);
127+
}
128+
129+
@Override
130+
public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) {
131+
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
132+
return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> {
133+
NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
134+
jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration));
135+
Converter<Map<String, Object>, Map<String, Object>> claimTypeConverter = this.claimTypeConverterFactory
136+
.apply(clientRegistration);
137+
if (claimTypeConverter != null) {
138+
jwtDecoder.setClaimSetConverter(claimTypeConverter);
139+
}
140+
return jwtDecoder;
141+
});
142+
}
143+
144+
private NimbusReactiveJwtDecoder buildDecoder(ClientRegistration clientRegistration) {
145+
JwsAlgorithm jwsAlgorithm = this.jwsAlgorithmResolver.apply(clientRegistration);
146+
if (jwsAlgorithm != null && SignatureAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
147+
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
148+
//
149+
// 6. If the ID Token is received via direct communication between the Client
150+
// and the Token Endpoint (which it is in this flow),
151+
// the TLS server validation MAY be used to validate the issuer in place of
152+
// checking the token signature.
153+
// The Client MUST validate the signature of all other ID Tokens according to
154+
// JWS [JWS]
155+
// using the algorithm specified in the JWT alg Header Parameter.
156+
// The Client MUST use the keys provided by the Issuer.
157+
//
158+
// 7. The alg value SHOULD be the default of RS256 or the algorithm sent by
159+
// the Client
160+
// in the id_token_signed_response_alg parameter during Registration.
161+
String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
162+
if (!StringUtils.hasText(jwkSetUri)) {
163+
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
164+
"Failed to find a Signature Verifier for Client Registration: '"
165+
+ clientRegistration.getRegistrationId()
166+
+ "'. Check to ensure you have configured the JwkSet URI.",
167+
null);
168+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
169+
}
170+
return NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm)
171+
.webClient(this.webClientResolver.apply(clientRegistration)).build();
172+
}
173+
if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
174+
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
175+
//
176+
// 8. If the JWT alg Header Parameter uses a MAC based algorithm such as
177+
// HS256, HS384, or HS512,
178+
// the octets of the UTF-8 representation of the client_secret
179+
// corresponding to the client_id contained in the aud (audience) Claim
180+
// are used as the key to validate the signature.
181+
// For MAC based algorithms, the behavior is unspecified if the aud is
182+
// multi-valued or
183+
// if an azp value is present that is different than the aud value.
184+
185+
String clientSecret = clientRegistration.getClientSecret();
186+
if (!StringUtils.hasText(clientSecret)) {
187+
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
188+
"Failed to find a Signature Verifier for Client Registration: '"
189+
+ clientRegistration.getRegistrationId()
190+
+ "'. Check to ensure you have configured the client secret.",
191+
null);
192+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
193+
}
194+
SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
195+
JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm));
196+
return NimbusReactiveJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm)
197+
.build();
198+
}
199+
OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
200+
"Failed to find a Signature Verifier for Client Registration: '"
201+
+ clientRegistration.getRegistrationId()
202+
+ "'. Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'",
203+
null);
204+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
205+
}
206+
207+
/**
208+
* Sets the factory that provides an {@link OAuth2TokenValidator}, which is used by
209+
* the {@link ReactiveJwtDecoder}. The default composes {@link JwtTimestampValidator}
210+
* and {@link OidcIdTokenValidator}.
211+
* @param jwtValidatorFactory the factory that provides an
212+
* {@link OAuth2TokenValidator}
213+
*/
214+
public void setJwtValidatorFactory(Function<ClientRegistration, OAuth2TokenValidator<Jwt>> jwtValidatorFactory) {
215+
Assert.notNull(jwtValidatorFactory, "jwtValidatorFactory cannot be null");
216+
this.jwtValidatorFactory = jwtValidatorFactory;
217+
}
218+
219+
/**
220+
* Sets the resolver that provides the expected {@link JwsAlgorithm JWS algorithm}
221+
* used for the signature or MAC on the {@link OidcIdToken ID Token}. The default
222+
* resolves to {@link SignatureAlgorithm#RS256 RS256} for all
223+
* {@link ClientRegistration clients}.
224+
* @param jwsAlgorithmResolver the resolver that provides the expected
225+
* {@link JwsAlgorithm JWS algorithm} for a specific {@link ClientRegistration client}
226+
*/
227+
public void setJwsAlgorithmResolver(Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver) {
228+
Assert.notNull(jwsAlgorithmResolver, "jwsAlgorithmResolver cannot be null");
229+
this.jwsAlgorithmResolver = jwsAlgorithmResolver;
230+
}
231+
232+
/**
233+
* Sets the factory that provides a {@link Converter} used for type conversion of
234+
* claim values for an {@link OidcIdToken}. The default is {@link ClaimTypeConverter}
235+
* for all {@link ClientRegistration clients}.
236+
* @param claimTypeConverterFactory the factory that provides a {@link Converter} used
237+
* for type conversion of claim values for a specific {@link ClientRegistration
238+
* client}
239+
*/
240+
public void setClaimTypeConverterFactory(
241+
Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory) {
242+
Assert.notNull(claimTypeConverterFactory, "claimTypeConverterFactory cannot be null");
243+
this.claimTypeConverterFactory = claimTypeConverterFactory;
244+
}
245+
246+
/**
247+
* Sets the resolver that provides the {@link WebClient} that will be used in
248+
* {@link NimbusReactiveJwtDecoder}. The default resolver provides {@link WebClient}
249+
* that is created by {@code WebClient.create()}. This is optional method if we need
250+
* to set custom web client in {@link NimbusReactiveJwtDecoder}.
251+
* @param webClientResolver a function that will provide {@link WebClient} for a
252+
* {@link ClientRegistration}
253+
*/
254+
public void setWebClientResolver(Function<ClientRegistration, WebClient> webClientResolver) {
255+
Assert.notNull(webClientResolver, "webClientResolver cannot be null");
256+
this.webClientResolver = webClientResolver;
257+
}
258+
259+
}

src/main/kotlin/no/nav/pensjon/gateway/pensjonazureadappgateway/SecurityConfiguration.kt

+5-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package no.nav.pensjon.gateway.pensjonazureadappgateway
22

33
import org.slf4j.Logger
4-
import org.slf4j.LoggerFactory
54
import org.slf4j.LoggerFactory.getLogger
65
import org.springframework.beans.factory.annotation.Value
76
import org.springframework.context.annotation.Bean
@@ -26,7 +25,6 @@ import org.springframework.security.oauth2.core.user.OAuth2User
2625
import org.springframework.security.oauth2.jwt.*
2726
import org.springframework.security.web.server.SecurityWebFilterChain
2827
import org.springframework.web.reactive.function.client.*
29-
import reactor.core.publisher.Mono
3028
import reactor.netty.http.client.HttpClient
3129
import reactor.netty.transport.ProxyProvider.Proxy.HTTP
3230
import java.net.URI
@@ -59,7 +57,7 @@ class SecurityConfiguration {
5957
): ReactiveAuthenticationManager = DelegatingReactiveAuthenticationManager(
6058
OidcAuthorizationCodeReactiveAuthenticationManager(
6159
client, oidcUserService
62-
), OAuth2LoginReactiveAuthenticationManager(
60+
).also { it.setJwtDecoderFactory(ReactiveOidcIdTokenDecoderFactory().apply { setWebClientResolver { webClientProxy() } }) }, OAuth2LoginReactiveAuthenticationManager(
6361
client, oauth2UserService
6462
)
6563
)
@@ -118,12 +116,12 @@ class SecurityConfiguration {
118116
}
119117
)
120118
).filter { request, next ->
121-
logger.info("Proxied request to {}", request.url());
122-
next.exchange(request);
119+
logger.info("Proxied request to {}", request.url())
120+
next.exchange(request)
123121
}.build()
124122
}
125123
?: WebClient.builder().filter { request, next ->
126-
logger.info("Non-proxied request to {}", request.url());
127-
next.exchange(request);
124+
logger.info("Non-proxied request to {}", request.url())
125+
next.exchange(request)
128126
}.build()
129127
}

0 commit comments

Comments
 (0)