Skip to content

Commit fa15c97

Browse files
committed
Merge branch '5.8.x' into 6.0.x
Closes gh-14084
2 parents a690232 + ffd12ee commit fa15c97

File tree

4 files changed

+292
-17
lines changed

4 files changed

+292
-17
lines changed

config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java

+130-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
2626
import jakarta.servlet.DispatcherType;
2727
import jakarta.servlet.ServletContext;
2828
import jakarta.servlet.ServletRegistration;
29+
import jakarta.servlet.http.HttpServletRequest;
2930

3031
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
3132
import org.springframework.context.ApplicationContext;
@@ -203,11 +204,30 @@ public C requestMatchers(HttpMethod method, String... patterns) {
203204
if (!hasDispatcherServlet(registrations)) {
204205
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
205206
}
206-
if (registrations.size() > 1) {
207-
String errorMessage = computeErrorMessage(registrations.values());
208-
throw new IllegalArgumentException(errorMessage);
207+
ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
208+
if (dispatcherServlet != null) {
209+
if (registrations.size() == 1) {
210+
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
211+
}
212+
List<RequestMatcher> matchers = new ArrayList<>();
213+
for (String pattern : patterns) {
214+
AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null);
215+
MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0);
216+
matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext));
217+
}
218+
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
209219
}
210-
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
220+
dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations);
221+
if (dispatcherServlet != null) {
222+
String mapping = dispatcherServlet.getMappings().iterator().next();
223+
List<MvcRequestMatcher> matchers = createMvcMatchers(method, patterns);
224+
for (MvcRequestMatcher matcher : matchers) {
225+
matcher.setServletPath(mapping.substring(0, mapping.length() - 2));
226+
}
227+
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
228+
}
229+
String errorMessage = computeErrorMessage(registrations.values());
230+
throw new IllegalArgumentException(errorMessage);
211231
}
212232

213233
private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
@@ -225,22 +245,66 @@ private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration>
225245
if (registrations == null) {
226246
return false;
227247
}
228-
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
229-
null);
230248
for (ServletRegistration registration : registrations.values()) {
231-
try {
232-
Class<?> clazz = Class.forName(registration.getClassName());
233-
if (dispatcherServlet.isAssignableFrom(clazz)) {
234-
return true;
235-
}
236-
}
237-
catch (ClassNotFoundException ex) {
238-
return false;
249+
if (isDispatcherServlet(registration)) {
250+
return true;
239251
}
240252
}
241253
return false;
242254
}
243255

256+
private ServletRegistration requireOneRootDispatcherServlet(
257+
Map<String, ? extends ServletRegistration> registrations) {
258+
ServletRegistration rootDispatcherServlet = null;
259+
for (ServletRegistration registration : registrations.values()) {
260+
if (!isDispatcherServlet(registration)) {
261+
continue;
262+
}
263+
if (registration.getMappings().size() > 1) {
264+
return null;
265+
}
266+
if (!"/".equals(registration.getMappings().iterator().next())) {
267+
return null;
268+
}
269+
rootDispatcherServlet = registration;
270+
}
271+
return rootDispatcherServlet;
272+
}
273+
274+
private ServletRegistration requireOnlyPathMappedDispatcherServlet(
275+
Map<String, ? extends ServletRegistration> registrations) {
276+
ServletRegistration pathDispatcherServlet = null;
277+
for (ServletRegistration registration : registrations.values()) {
278+
if (!isDispatcherServlet(registration)) {
279+
return null;
280+
}
281+
if (registration.getMappings().size() > 1) {
282+
return null;
283+
}
284+
String mapping = registration.getMappings().iterator().next();
285+
if (!mapping.startsWith("/") || !mapping.endsWith("/*")) {
286+
return null;
287+
}
288+
if (pathDispatcherServlet != null) {
289+
return null;
290+
}
291+
pathDispatcherServlet = registration;
292+
}
293+
return pathDispatcherServlet;
294+
}
295+
296+
private boolean isDispatcherServlet(ServletRegistration registration) {
297+
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
298+
null);
299+
try {
300+
Class<?> clazz = Class.forName(registration.getClassName());
301+
return dispatcherServlet.isAssignableFrom(clazz);
302+
}
303+
catch (ClassNotFoundException ex) {
304+
return false;
305+
}
306+
}
307+
244308
private String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {
245309
String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
246310
+ "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
@@ -380,4 +444,55 @@ static List<RequestMatcher> regexMatchers(String... regexPatterns) {
380444

381445
}
382446

447+
static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {
448+
449+
private final AntPathRequestMatcher ant;
450+
451+
private final MvcRequestMatcher mvc;
452+
453+
private final ServletContext servletContext;
454+
455+
DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc,
456+
ServletContext servletContext) {
457+
this.ant = ant;
458+
this.mvc = mvc;
459+
this.servletContext = servletContext;
460+
}
461+
462+
@Override
463+
public boolean matches(HttpServletRequest request) {
464+
String name = request.getHttpServletMapping().getServletName();
465+
ServletRegistration registration = this.servletContext.getServletRegistration(name);
466+
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
467+
if (isDispatcherServlet(registration)) {
468+
return this.mvc.matches(request);
469+
}
470+
return this.ant.matches(request);
471+
}
472+
473+
@Override
474+
public MatchResult matcher(HttpServletRequest request) {
475+
String name = request.getHttpServletMapping().getServletName();
476+
ServletRegistration registration = this.servletContext.getServletRegistration(name);
477+
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
478+
if (isDispatcherServlet(registration)) {
479+
return this.mvc.matcher(request);
480+
}
481+
return this.ant.matcher(request);
482+
}
483+
484+
private boolean isDispatcherServlet(ServletRegistration registration) {
485+
Class<?> dispatcherServlet = ClassUtils
486+
.resolveClassName("org.springframework.web.servlet.DispatcherServlet", null);
487+
try {
488+
Class<?> clazz = Class.forName(registration.getClassName());
489+
return dispatcherServlet.isAssignableFrom(clazz);
490+
}
491+
catch (ClassNotFoundException ex) {
492+
return false;
493+
}
494+
}
495+
496+
}
497+
383498
}

config/src/test/java/org/springframework/security/config/MockServletContext.java

+5
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ public ServletRegistration.Dynamic addServlet(@NonNull String servletName, Class
5555
return this.registrations;
5656
}
5757

58+
@Override
59+
public ServletRegistration getServletRegistration(String servletName) {
60+
return this.registrations.get(servletName);
61+
}
62+
5863
private static class MockServletRegistration implements ServletRegistration.Dynamic {
5964

6065
private final String name;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright 2002-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.config;
18+
19+
import jakarta.servlet.http.HttpServletRequest;
20+
import jakarta.servlet.http.MappingMatch;
21+
22+
import org.springframework.mock.web.MockHttpServletMapping;
23+
24+
public final class TestMockHttpServletMappings {
25+
26+
private TestMockHttpServletMappings() {
27+
28+
}
29+
30+
public static MockHttpServletMapping extension(HttpServletRequest request, String extension) {
31+
String uri = request.getRequestURI();
32+
String matchValue = uri.substring(0, uri.lastIndexOf(extension));
33+
return new MockHttpServletMapping(matchValue, "*" + extension, "extension", MappingMatch.EXTENSION);
34+
}
35+
36+
public static MockHttpServletMapping path(HttpServletRequest request, String path) {
37+
String uri = request.getRequestURI();
38+
String matchValue = uri.substring(path.length());
39+
return new MockHttpServletMapping(matchValue, path + "/*", "path", MappingMatch.PATH);
40+
}
41+
42+
public static MockHttpServletMapping defaultMapping() {
43+
return new MockHttpServletMapping("", "/", "default", MappingMatch.DEFAULT);
44+
}
45+
46+
}

0 commit comments

Comments
 (0)