|
| 1 | +/* |
| 2 | + * Copyright 2002-2023 the original author or authors. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +package no.nav.pensjon.gateway.pensjonazureadappgateway; |
| 18 | + |
| 19 | +import java.net.URL; |
| 20 | +import java.nio.charset.StandardCharsets; |
| 21 | +import java.time.Instant; |
| 22 | +import java.util.Collection; |
| 23 | +import java.util.Collections; |
| 24 | +import java.util.HashMap; |
| 25 | +import java.util.Map; |
| 26 | +import java.util.concurrent.ConcurrentHashMap; |
| 27 | +import java.util.function.Function; |
| 28 | + |
| 29 | +import javax.crypto.spec.SecretKeySpec; |
| 30 | + |
| 31 | +import org.springframework.core.convert.TypeDescriptor; |
| 32 | +import org.springframework.core.convert.converter.Converter; |
| 33 | +import org.springframework.security.oauth2.client.registration.ClientRegistration; |
| 34 | +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; |
| 35 | +import org.springframework.security.oauth2.core.OAuth2Error; |
| 36 | +import org.springframework.security.oauth2.core.OAuth2TokenValidator; |
| 37 | +import org.springframework.security.oauth2.core.converter.ClaimConversionService; |
| 38 | +import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; |
| 39 | +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; |
| 40 | +import org.springframework.security.oauth2.core.oidc.OidcIdToken; |
| 41 | +import org.springframework.security.oauth2.core.oidc.StandardClaimNames; |
| 42 | +import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; |
| 43 | +import org.springframework.security.oauth2.jose.jws.MacAlgorithm; |
| 44 | +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; |
| 45 | +import org.springframework.security.oauth2.jwt.Jwt; |
| 46 | +import org.springframework.security.oauth2.jwt.JwtTimestampValidator; |
| 47 | +import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder; |
| 48 | +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; |
| 49 | +import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; |
| 50 | +import org.springframework.util.Assert; |
| 51 | +import org.springframework.util.StringUtils; |
| 52 | +import org.springframework.web.reactive.function.client.WebClient; |
| 53 | + |
| 54 | +/** |
| 55 | + * A {@link ReactiveJwtDecoderFactory factory} that provides a {@link ReactiveJwtDecoder} |
| 56 | + * used for {@link OidcIdToken} signature verification. The provided |
| 57 | + * {@link ReactiveJwtDecoder} is associated to a specific {@link ClientRegistration}. |
| 58 | + * |
| 59 | + * @author Joe Grandja |
| 60 | + * @author Rafael Dominguez |
| 61 | + * @author Mark Heckler |
| 62 | + * @author Ubaid ur Rehman |
| 63 | + * @since 5.2 |
| 64 | + * @see ReactiveJwtDecoderFactory |
| 65 | + * @see ClientRegistration |
| 66 | + * @see OidcIdToken |
| 67 | + */ |
| 68 | +public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecoderFactory<ClientRegistration> { |
| 69 | + |
| 70 | + private static final String MISSING_SIGNATURE_VERIFIER_ERROR_CODE = "missing_signature_verifier"; |
| 71 | + |
| 72 | + private static final Map<JwsAlgorithm, String> JCA_ALGORITHM_MAPPINGS; |
| 73 | + static { |
| 74 | + Map<JwsAlgorithm, String> mappings = new HashMap<JwsAlgorithm, String>(); |
| 75 | + mappings.put(MacAlgorithm.HS256, "HmacSHA256"); |
| 76 | + mappings.put(MacAlgorithm.HS384, "HmacSHA384"); |
| 77 | + mappings.put(MacAlgorithm.HS512, "HmacSHA512"); |
| 78 | + JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings); |
| 79 | + } |
| 80 | + |
| 81 | + private static final ClaimTypeConverter DEFAULT_CLAIM_TYPE_CONVERTER = new ClaimTypeConverter( |
| 82 | + createDefaultClaimTypeConverters()); |
| 83 | + |
| 84 | + private final Map<String, ReactiveJwtDecoder> jwtDecoders = new ConcurrentHashMap<>(); |
| 85 | + |
| 86 | + private Function<ClientRegistration, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = new DefaultOidcIdTokenValidatorFactory(); |
| 87 | + |
| 88 | + private Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver = ( |
| 89 | + clientRegistration) -> SignatureAlgorithm.RS256; |
| 90 | + |
| 91 | + private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = ( |
| 92 | + clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER; |
| 93 | + |
| 94 | + private Function<ClientRegistration, WebClient> webClientResolver = (clientRegistration) -> WebClient.create(); |
| 95 | + |
| 96 | + /** |
| 97 | + * Returns the default {@link Converter}'s used for type conversion of claim values |
| 98 | + * for an {@link OidcIdToken}. |
| 99 | + * @return a {@link Map} of {@link Converter}'s keyed by {@link IdTokenClaimNames |
| 100 | + * claim name} |
| 101 | + */ |
| 102 | + public static Map<String, Converter<Object, ?>> createDefaultClaimTypeConverters() { |
| 103 | + Converter<Object, ?> booleanConverter = getConverter(TypeDescriptor.valueOf(Boolean.class)); |
| 104 | + Converter<Object, ?> instantConverter = getConverter(TypeDescriptor.valueOf(Instant.class)); |
| 105 | + Converter<Object, ?> urlConverter = getConverter(TypeDescriptor.valueOf(URL.class)); |
| 106 | + Converter<Object, ?> stringConverter = getConverter(TypeDescriptor.valueOf(String.class)); |
| 107 | + Converter<Object, ?> collectionStringConverter = getConverter( |
| 108 | + TypeDescriptor.collection(Collection.class, TypeDescriptor.valueOf(String.class))); |
| 109 | + Map<String, Converter<Object, ?>> converters = new HashMap<>(); |
| 110 | + converters.put(IdTokenClaimNames.ISS, urlConverter); |
| 111 | + converters.put(IdTokenClaimNames.AUD, collectionStringConverter); |
| 112 | + converters.put(IdTokenClaimNames.NONCE, stringConverter); |
| 113 | + converters.put(IdTokenClaimNames.EXP, instantConverter); |
| 114 | + converters.put(IdTokenClaimNames.IAT, instantConverter); |
| 115 | + converters.put(IdTokenClaimNames.AUTH_TIME, instantConverter); |
| 116 | + converters.put(IdTokenClaimNames.AMR, collectionStringConverter); |
| 117 | + converters.put(StandardClaimNames.EMAIL_VERIFIED, booleanConverter); |
| 118 | + converters.put(StandardClaimNames.PHONE_NUMBER_VERIFIED, booleanConverter); |
| 119 | + converters.put(StandardClaimNames.UPDATED_AT, instantConverter); |
| 120 | + return converters; |
| 121 | + } |
| 122 | + |
| 123 | + private static Converter<Object, ?> getConverter(TypeDescriptor targetDescriptor) { |
| 124 | + final TypeDescriptor sourceDescriptor = TypeDescriptor.valueOf(Object.class); |
| 125 | + return (source) -> ClaimConversionService.getSharedInstance().convert(source, sourceDescriptor, |
| 126 | + targetDescriptor); |
| 127 | + } |
| 128 | + |
| 129 | + @Override |
| 130 | + public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) { |
| 131 | + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); |
| 132 | + return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), (key) -> { |
| 133 | + NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration); |
| 134 | + jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(clientRegistration)); |
| 135 | + Converter<Map<String, Object>, Map<String, Object>> claimTypeConverter = this.claimTypeConverterFactory |
| 136 | + .apply(clientRegistration); |
| 137 | + if (claimTypeConverter != null) { |
| 138 | + jwtDecoder.setClaimSetConverter(claimTypeConverter); |
| 139 | + } |
| 140 | + return jwtDecoder; |
| 141 | + }); |
| 142 | + } |
| 143 | + |
| 144 | + private NimbusReactiveJwtDecoder buildDecoder(ClientRegistration clientRegistration) { |
| 145 | + JwsAlgorithm jwsAlgorithm = this.jwsAlgorithmResolver.apply(clientRegistration); |
| 146 | + if (jwsAlgorithm != null && SignatureAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { |
| 147 | + // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation |
| 148 | + // |
| 149 | + // 6. If the ID Token is received via direct communication between the Client |
| 150 | + // and the Token Endpoint (which it is in this flow), |
| 151 | + // the TLS server validation MAY be used to validate the issuer in place of |
| 152 | + // checking the token signature. |
| 153 | + // The Client MUST validate the signature of all other ID Tokens according to |
| 154 | + // JWS [JWS] |
| 155 | + // using the algorithm specified in the JWT alg Header Parameter. |
| 156 | + // The Client MUST use the keys provided by the Issuer. |
| 157 | + // |
| 158 | + // 7. The alg value SHOULD be the default of RS256 or the algorithm sent by |
| 159 | + // the Client |
| 160 | + // in the id_token_signed_response_alg parameter during Registration. |
| 161 | + String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri(); |
| 162 | + if (!StringUtils.hasText(jwkSetUri)) { |
| 163 | + OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE, |
| 164 | + "Failed to find a Signature Verifier for Client Registration: '" |
| 165 | + + clientRegistration.getRegistrationId() |
| 166 | + + "'. Check to ensure you have configured the JwkSet URI.", |
| 167 | + null); |
| 168 | + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); |
| 169 | + } |
| 170 | + return NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm) |
| 171 | + .webClient(this.webClientResolver.apply(clientRegistration)).build(); |
| 172 | + } |
| 173 | + if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { |
| 174 | + // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation |
| 175 | + // |
| 176 | + // 8. If the JWT alg Header Parameter uses a MAC based algorithm such as |
| 177 | + // HS256, HS384, or HS512, |
| 178 | + // the octets of the UTF-8 representation of the client_secret |
| 179 | + // corresponding to the client_id contained in the aud (audience) Claim |
| 180 | + // are used as the key to validate the signature. |
| 181 | + // For MAC based algorithms, the behavior is unspecified if the aud is |
| 182 | + // multi-valued or |
| 183 | + // if an azp value is present that is different than the aud value. |
| 184 | + |
| 185 | + String clientSecret = clientRegistration.getClientSecret(); |
| 186 | + if (!StringUtils.hasText(clientSecret)) { |
| 187 | + OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE, |
| 188 | + "Failed to find a Signature Verifier for Client Registration: '" |
| 189 | + + clientRegistration.getRegistrationId() |
| 190 | + + "'. Check to ensure you have configured the client secret.", |
| 191 | + null); |
| 192 | + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); |
| 193 | + } |
| 194 | + SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8), |
| 195 | + JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm)); |
| 196 | + return NimbusReactiveJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm) |
| 197 | + .build(); |
| 198 | + } |
| 199 | + OAuth2Error oauth2Error = new OAuth2Error(MISSING_SIGNATURE_VERIFIER_ERROR_CODE, |
| 200 | + "Failed to find a Signature Verifier for Client Registration: '" |
| 201 | + + clientRegistration.getRegistrationId() |
| 202 | + + "'. Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'", |
| 203 | + null); |
| 204 | + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); |
| 205 | + } |
| 206 | + |
| 207 | + /** |
| 208 | + * Sets the factory that provides an {@link OAuth2TokenValidator}, which is used by |
| 209 | + * the {@link ReactiveJwtDecoder}. The default composes {@link JwtTimestampValidator} |
| 210 | + * and {@link OidcIdTokenValidator}. |
| 211 | + * @param jwtValidatorFactory the factory that provides an |
| 212 | + * {@link OAuth2TokenValidator} |
| 213 | + */ |
| 214 | + public void setJwtValidatorFactory(Function<ClientRegistration, OAuth2TokenValidator<Jwt>> jwtValidatorFactory) { |
| 215 | + Assert.notNull(jwtValidatorFactory, "jwtValidatorFactory cannot be null"); |
| 216 | + this.jwtValidatorFactory = jwtValidatorFactory; |
| 217 | + } |
| 218 | + |
| 219 | + /** |
| 220 | + * Sets the resolver that provides the expected {@link JwsAlgorithm JWS algorithm} |
| 221 | + * used for the signature or MAC on the {@link OidcIdToken ID Token}. The default |
| 222 | + * resolves to {@link SignatureAlgorithm#RS256 RS256} for all |
| 223 | + * {@link ClientRegistration clients}. |
| 224 | + * @param jwsAlgorithmResolver the resolver that provides the expected |
| 225 | + * {@link JwsAlgorithm JWS algorithm} for a specific {@link ClientRegistration client} |
| 226 | + */ |
| 227 | + public void setJwsAlgorithmResolver(Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver) { |
| 228 | + Assert.notNull(jwsAlgorithmResolver, "jwsAlgorithmResolver cannot be null"); |
| 229 | + this.jwsAlgorithmResolver = jwsAlgorithmResolver; |
| 230 | + } |
| 231 | + |
| 232 | + /** |
| 233 | + * Sets the factory that provides a {@link Converter} used for type conversion of |
| 234 | + * claim values for an {@link OidcIdToken}. The default is {@link ClaimTypeConverter} |
| 235 | + * for all {@link ClientRegistration clients}. |
| 236 | + * @param claimTypeConverterFactory the factory that provides a {@link Converter} used |
| 237 | + * for type conversion of claim values for a specific {@link ClientRegistration |
| 238 | + * client} |
| 239 | + */ |
| 240 | + public void setClaimTypeConverterFactory( |
| 241 | + Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory) { |
| 242 | + Assert.notNull(claimTypeConverterFactory, "claimTypeConverterFactory cannot be null"); |
| 243 | + this.claimTypeConverterFactory = claimTypeConverterFactory; |
| 244 | + } |
| 245 | + |
| 246 | + /** |
| 247 | + * Sets the resolver that provides the {@link WebClient} that will be used in |
| 248 | + * {@link NimbusReactiveJwtDecoder}. The default resolver provides {@link WebClient} |
| 249 | + * that is created by {@code WebClient.create()}. This is optional method if we need |
| 250 | + * to set custom web client in {@link NimbusReactiveJwtDecoder}. |
| 251 | + * @param webClientResolver a function that will provide {@link WebClient} for a |
| 252 | + * {@link ClientRegistration} |
| 253 | + */ |
| 254 | + public void setWebClientResolver(Function<ClientRegistration, WebClient> webClientResolver) { |
| 255 | + Assert.notNull(webClientResolver, "webClientResolver cannot be null"); |
| 256 | + this.webClientResolver = webClientResolver; |
| 257 | + } |
| 258 | + |
| 259 | +} |
0 commit comments