Skip to content

Commit 15cc15c

Browse files
committed
Simplify OpenSamlImplementation
- Removed reflection usage - Simplified method signatures Issue gh-7711 Fixes gh-8147
1 parent 1bbbf3b commit 15cc15c

File tree

4 files changed

+91
-98
lines changed

4 files changed

+91
-98
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java

+27-27
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,22 @@
1616

1717
package org.springframework.security.saml2.provider.service.authentication;
1818

19+
import java.time.Clock;
20+
import java.time.Instant;
21+
import java.util.List;
22+
import java.util.Map;
23+
import java.util.UUID;
24+
1925
import org.joda.time.DateTime;
2026
import org.opensaml.saml.common.xml.SAMLConstants;
2127
import org.opensaml.saml.saml2.core.AuthnRequest;
2228
import org.opensaml.saml.saml2.core.Issuer;
29+
2330
import org.springframework.security.saml2.credentials.Saml2X509Credential;
2431
import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest.Builder;
2532
import org.springframework.util.Assert;
2633

27-
import java.time.Clock;
28-
import java.time.Instant;
29-
import java.util.List;
30-
import java.util.Map;
31-
import java.util.UUID;
32-
3334
import static java.nio.charset.StandardCharsets.UTF_8;
34-
import static java.util.Collections.emptyList;
3535
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDeflate;
3636
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlEncode;
3737

@@ -46,19 +46,21 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
4646
@Override
4747
@Deprecated
4848
public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
49-
return createAuthenticationRequest(request, request.getCredentials());
49+
AuthnRequest authnRequest = createAuthnRequest(request.getIssuer(),
50+
request.getDestination(), request.getAssertionConsumerServiceUrl());
51+
return this.saml.serialize(authnRequest, request.getCredentials());
5052
}
5153

5254
/**
5355
* {@inheritDoc}
5456
*/
5557
@Override
5658
public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
57-
List<Saml2X509Credential> signingCredentials = context.getRelyingPartyRegistration().getProviderDetails().isSignAuthNRequest() ?
58-
context.getRelyingPartyRegistration().getSigningCredentials() :
59-
emptyList();
59+
AuthnRequest authnRequest = createAuthnRequest(context);
60+
String xml = context.getRelyingPartyRegistration().getProviderDetails().isSignAuthNRequest() ?
61+
this.saml.serialize(authnRequest, context.getRelyingPartyRegistration().getSigningCredentials()) :
62+
this.saml.serialize(authnRequest);
6063

61-
String xml = createAuthenticationRequest(context, signingCredentials);
6264
return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context)
6365
.samlRequest(samlEncode(xml.getBytes(UTF_8)))
6466
.build();
@@ -69,7 +71,8 @@ public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2Authe
6971
*/
7072
@Override
7173
public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) {
72-
String xml = createAuthenticationRequest(context, emptyList());
74+
AuthnRequest authnRequest = createAuthnRequest(context);
75+
String xml = this.saml.serialize(authnRequest);
7376
Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
7477
String deflatedAndEncoded = samlEncode(samlDeflate(xml));
7578
result.samlRequest(deflatedAndEncoded)
@@ -91,27 +94,24 @@ public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Sa
9194
return result.build();
9295
}
9396

94-
private String createAuthenticationRequest(Saml2AuthenticationRequestContext request, List<Saml2X509Credential> credentials) {
95-
return createAuthenticationRequest(Saml2AuthenticationRequest.withAuthenticationRequestContext(request).build(), credentials);
97+
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
98+
return createAuthnRequest(context.getIssuer(),
99+
context.getDestination(), context.getAssertionConsumerServiceUrl());
96100
}
97101

98-
private String createAuthenticationRequest(Saml2AuthenticationRequest context, List<Saml2X509Credential> credentials) {
99-
AuthnRequest auth = this.saml.buildSAMLObject(AuthnRequest.class);
102+
private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) {
103+
AuthnRequest auth = this.saml.buildSamlObject(AuthnRequest.DEFAULT_ELEMENT_NAME);
100104
auth.setID("ARQ" + UUID.randomUUID().toString().substring(1));
101105
auth.setIssueInstant(new DateTime(this.clock.millis()));
102106
auth.setForceAuthn(Boolean.FALSE);
103107
auth.setIsPassive(Boolean.FALSE);
104108
auth.setProtocolBinding(protocolBinding);
105-
Issuer issuer = this.saml.buildSAMLObject(Issuer.class);
106-
issuer.setValue(context.getIssuer());
107-
auth.setIssuer(issuer);
108-
auth.setDestination(context.getDestination());
109-
auth.setAssertionConsumerServiceURL(context.getAssertionConsumerServiceUrl());
110-
return this.saml.toXml(
111-
auth,
112-
credentials,
113-
context.getIssuer()
114-
);
109+
Issuer iss = this.saml.buildSamlObject(Issuer.DEFAULT_ELEMENT_NAME);
110+
iss.setValue(issuer);
111+
auth.setIssuer(iss);
112+
auth.setDestination(destination);
113+
auth.setAssertionConsumerServiceURL(assertionConsumerServiceUrl);
114+
return auth;
115115
}
116116

117117
/**

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlImplementation.java

+31-35
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,16 @@
1616

1717
package org.springframework.security.saml2.provider.service.authentication;
1818

19+
import java.io.ByteArrayInputStream;
20+
import java.nio.charset.Charset;
21+
import java.nio.charset.StandardCharsets;
22+
import java.util.HashMap;
23+
import java.util.LinkedHashMap;
24+
import java.util.List;
25+
import java.util.Map;
26+
import javax.xml.XMLConstants;
27+
import javax.xml.namespace.QName;
28+
1929
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
2030
import net.shibboleth.utilities.java.support.xml.BasicParserPool;
2131
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
@@ -24,13 +34,14 @@
2434
import org.opensaml.core.config.InitializationException;
2535
import org.opensaml.core.config.InitializationService;
2636
import org.opensaml.core.xml.XMLObject;
37+
import org.opensaml.core.xml.XMLObjectBuilderFactory;
2738
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
2839
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
2940
import org.opensaml.core.xml.io.MarshallerFactory;
3041
import org.opensaml.core.xml.io.MarshallingException;
3142
import org.opensaml.core.xml.io.UnmarshallerFactory;
3243
import org.opensaml.core.xml.io.UnmarshallingException;
33-
import org.opensaml.saml.common.SignableSAMLObject;
44+
import org.opensaml.saml.saml2.core.AuthnRequest;
3445
import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver;
3546
import org.opensaml.security.SecurityException;
3647
import org.opensaml.security.credential.BasicCredential;
@@ -47,35 +58,26 @@
4758
import org.opensaml.xmlsec.signature.support.SignatureConstants;
4859
import org.opensaml.xmlsec.signature.support.SignatureException;
4960
import org.opensaml.xmlsec.signature.support.SignatureSupport;
61+
import org.w3c.dom.Document;
62+
import org.w3c.dom.Element;
63+
5064
import org.springframework.security.saml2.Saml2Exception;
5165
import org.springframework.security.saml2.credentials.Saml2X509Credential;
52-
import org.springframework.security.saml2.provider.service.authentication.Saml2Utils;
5366
import org.springframework.util.Assert;
5467
import org.springframework.web.util.UriUtils;
55-
import org.w3c.dom.Document;
56-
import org.w3c.dom.Element;
57-
58-
import javax.xml.XMLConstants;
59-
import javax.xml.namespace.QName;
60-
import java.io.ByteArrayInputStream;
61-
import java.nio.charset.Charset;
62-
import java.nio.charset.StandardCharsets;
63-
import java.util.HashMap;
64-
import java.util.LinkedHashMap;
65-
import java.util.List;
66-
import java.util.Map;
6768

6869
import static java.lang.Boolean.FALSE;
6970
import static java.lang.Boolean.TRUE;
7071
import static java.util.Arrays.asList;
71-
import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
7272
import static org.springframework.util.StringUtils.hasText;
7373

7474
/**
7575
* @since 5.2
7676
*/
7777
final class OpenSamlImplementation {
7878
private static OpenSamlImplementation instance = new OpenSamlImplementation();
79+
private static XMLObjectBuilderFactory xmlObjectBuilderFactory =
80+
XMLObjectProviderRegistrySupport.getBuilderFactory();
7981

8082
private final BasicParserPool parserPool = new BasicParserPool();
8183
private final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver(
@@ -167,37 +169,31 @@ EncryptedKeyResolver getEncryptedKeyResolver() {
167169
return this.encryptedKeyResolver;
168170
}
169171

170-
<T> T buildSAMLObject(final Class<T> clazz) {
171-
try {
172-
QName defaultElementName = (QName) clazz.getDeclaredField("DEFAULT_ELEMENT_NAME").get(null);
173-
return (T) getBuilderFactory().getBuilder(defaultElementName).buildObject(defaultElementName);
174-
}
175-
catch (NoSuchFieldException | IllegalAccessException e) {
176-
throw new Saml2Exception("Could not create SAML object", e);
177-
}
172+
<T> T buildSamlObject(QName qName) {
173+
return (T) xmlObjectBuilderFactory.getBuilder(qName).buildObject(qName);
178174
}
179175

180176
XMLObject resolve(String xml) {
181177
return resolve(xml.getBytes(StandardCharsets.UTF_8));
182178
}
183179

184-
String toXml(XMLObject object, List<Saml2X509Credential> signingCredentials, String localSpEntityId) {
185-
if (object instanceof SignableSAMLObject && null != hasSigningCredential(signingCredentials)) {
186-
signXmlObject(
187-
(SignableSAMLObject) object,
188-
signingCredentials,
189-
localSpEntityId
190-
);
191-
}
180+
String serialize(XMLObject xmlObject) {
192181
final MarshallerFactory marshallerFactory = XMLObjectProviderRegistrySupport.getMarshallerFactory();
193182
try {
194-
Element element = marshallerFactory.getMarshaller(object).marshall(object);
183+
Element element = marshallerFactory.getMarshaller(xmlObject).marshall(xmlObject);
195184
return SerializeSupport.nodeToString(element);
196185
} catch (MarshallingException e) {
197186
throw new Saml2Exception(e);
198187
}
199188
}
200189

190+
String serialize(AuthnRequest authnRequest, List<Saml2X509Credential> signingCredentials) {
191+
if (hasSigningCredential(signingCredentials) != null) {
192+
signAuthnRequest(authnRequest, signingCredentials);
193+
}
194+
return serialize(authnRequest);
195+
}
196+
201197
/**
202198
* Returns query parameter after creating a Query String signature
203199
* All return values are unencoded and will need to be encoded prior to sending
@@ -306,15 +302,15 @@ private Credential getSigningCredential(List<Saml2X509Credential> signingCredent
306302
return cred;
307303
}
308304

309-
private void signXmlObject(SignableSAMLObject object, List<Saml2X509Credential> signingCredentials, String entityId) {
305+
private void signAuthnRequest(AuthnRequest authnRequest, List<Saml2X509Credential> signingCredentials) {
310306
SignatureSigningParameters parameters = new SignatureSigningParameters();
311-
Credential credential = getSigningCredential(signingCredentials, entityId);
307+
Credential credential = getSigningCredential(signingCredentials, authnRequest.getIssuer().getValue());
312308
parameters.setSigningCredential(credential);
313309
parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
314310
parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256);
315311
parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
316312
try {
317-
SignatureSupport.signObject(object, parameters);
313+
SignatureSupport.signObject(authnRequest, parameters);
318314
} catch (MarshallingException | SignatureException | SecurityException e) {
319315
throw new Saml2Exception(e);
320316
}

0 commit comments

Comments
 (0)