1
1
/*
2
- * Copyright 2002-2022 the original author or authors.
2
+ * Copyright 2002-2023 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
26
26
import jakarta .servlet .DispatcherType ;
27
27
import jakarta .servlet .ServletContext ;
28
28
import jakarta .servlet .ServletRegistration ;
29
+ import jakarta .servlet .http .HttpServletRequest ;
29
30
30
31
import org .springframework .beans .factory .NoSuchBeanDefinitionException ;
31
32
import org .springframework .context .ApplicationContext ;
@@ -203,11 +204,30 @@ public C requestMatchers(HttpMethod method, String... patterns) {
203
204
if (!hasDispatcherServlet (registrations )) {
204
205
return requestMatchers (RequestMatchers .antMatchersAsArray (method , patterns ));
205
206
}
206
- if (registrations .size () > 1 ) {
207
- String errorMessage = computeErrorMessage (registrations .values ());
208
- throw new IllegalArgumentException (errorMessage );
207
+ ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet (registrations );
208
+ if (dispatcherServlet != null ) {
209
+ if (registrations .size () == 1 ) {
210
+ return requestMatchers (createMvcMatchers (method , patterns ).toArray (new RequestMatcher [0 ]));
211
+ }
212
+ List <RequestMatcher > matchers = new ArrayList <>();
213
+ for (String pattern : patterns ) {
214
+ AntPathRequestMatcher ant = new AntPathRequestMatcher (pattern , (method != null ) ? method .name () : null );
215
+ MvcRequestMatcher mvc = createMvcMatchers (method , pattern ).get (0 );
216
+ matchers .add (new DispatcherServletDelegatingRequestMatcher (ant , mvc , servletContext ));
217
+ }
218
+ return requestMatchers (matchers .toArray (new RequestMatcher [0 ]));
209
219
}
210
- return requestMatchers (createMvcMatchers (method , patterns ).toArray (new RequestMatcher [0 ]));
220
+ dispatcherServlet = requireOnlyPathMappedDispatcherServlet (registrations );
221
+ if (dispatcherServlet != null ) {
222
+ String mapping = dispatcherServlet .getMappings ().iterator ().next ();
223
+ List <MvcRequestMatcher > matchers = createMvcMatchers (method , patterns );
224
+ for (MvcRequestMatcher matcher : matchers ) {
225
+ matcher .setServletPath (mapping .substring (0 , mapping .length () - 2 ));
226
+ }
227
+ return requestMatchers (matchers .toArray (new RequestMatcher [0 ]));
228
+ }
229
+ String errorMessage = computeErrorMessage (registrations .values ());
230
+ throw new IllegalArgumentException (errorMessage );
211
231
}
212
232
213
233
private Map <String , ? extends ServletRegistration > mappableServletRegistrations (ServletContext servletContext ) {
@@ -225,22 +245,66 @@ private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration>
225
245
if (registrations == null ) {
226
246
return false ;
227
247
}
228
- Class <?> dispatcherServlet = ClassUtils .resolveClassName ("org.springframework.web.servlet.DispatcherServlet" ,
229
- null );
230
248
for (ServletRegistration registration : registrations .values ()) {
231
- try {
232
- Class <?> clazz = Class .forName (registration .getClassName ());
233
- if (dispatcherServlet .isAssignableFrom (clazz )) {
234
- return true ;
235
- }
236
- }
237
- catch (ClassNotFoundException ex ) {
238
- return false ;
249
+ if (isDispatcherServlet (registration )) {
250
+ return true ;
239
251
}
240
252
}
241
253
return false ;
242
254
}
243
255
256
+ private ServletRegistration requireOneRootDispatcherServlet (
257
+ Map <String , ? extends ServletRegistration > registrations ) {
258
+ ServletRegistration rootDispatcherServlet = null ;
259
+ for (ServletRegistration registration : registrations .values ()) {
260
+ if (!isDispatcherServlet (registration )) {
261
+ continue ;
262
+ }
263
+ if (registration .getMappings ().size () > 1 ) {
264
+ return null ;
265
+ }
266
+ if (!"/" .equals (registration .getMappings ().iterator ().next ())) {
267
+ return null ;
268
+ }
269
+ rootDispatcherServlet = registration ;
270
+ }
271
+ return rootDispatcherServlet ;
272
+ }
273
+
274
+ private ServletRegistration requireOnlyPathMappedDispatcherServlet (
275
+ Map <String , ? extends ServletRegistration > registrations ) {
276
+ ServletRegistration pathDispatcherServlet = null ;
277
+ for (ServletRegistration registration : registrations .values ()) {
278
+ if (!isDispatcherServlet (registration )) {
279
+ return null ;
280
+ }
281
+ if (registration .getMappings ().size () > 1 ) {
282
+ return null ;
283
+ }
284
+ String mapping = registration .getMappings ().iterator ().next ();
285
+ if (!mapping .startsWith ("/" ) || !mapping .endsWith ("/*" )) {
286
+ return null ;
287
+ }
288
+ if (pathDispatcherServlet != null ) {
289
+ return null ;
290
+ }
291
+ pathDispatcherServlet = registration ;
292
+ }
293
+ return pathDispatcherServlet ;
294
+ }
295
+
296
+ private boolean isDispatcherServlet (ServletRegistration registration ) {
297
+ Class <?> dispatcherServlet = ClassUtils .resolveClassName ("org.springframework.web.servlet.DispatcherServlet" ,
298
+ null );
299
+ try {
300
+ Class <?> clazz = Class .forName (registration .getClassName ());
301
+ return dispatcherServlet .isAssignableFrom (clazz );
302
+ }
303
+ catch (ClassNotFoundException ex ) {
304
+ return false ;
305
+ }
306
+ }
307
+
244
308
private String computeErrorMessage (Collection <? extends ServletRegistration > registrations ) {
245
309
String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
246
310
+ "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
@@ -380,4 +444,55 @@ static List<RequestMatcher> regexMatchers(String... regexPatterns) {
380
444
381
445
}
382
446
447
+ static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {
448
+
449
+ private final AntPathRequestMatcher ant ;
450
+
451
+ private final MvcRequestMatcher mvc ;
452
+
453
+ private final ServletContext servletContext ;
454
+
455
+ DispatcherServletDelegatingRequestMatcher (AntPathRequestMatcher ant , MvcRequestMatcher mvc ,
456
+ ServletContext servletContext ) {
457
+ this .ant = ant ;
458
+ this .mvc = mvc ;
459
+ this .servletContext = servletContext ;
460
+ }
461
+
462
+ @ Override
463
+ public boolean matches (HttpServletRequest request ) {
464
+ String name = request .getHttpServletMapping ().getServletName ();
465
+ ServletRegistration registration = this .servletContext .getServletRegistration (name );
466
+ Assert .notNull (registration , "Failed to find servlet [" + name + "] in the servlet context" );
467
+ if (isDispatcherServlet (registration )) {
468
+ return this .mvc .matches (request );
469
+ }
470
+ return this .ant .matches (request );
471
+ }
472
+
473
+ @ Override
474
+ public MatchResult matcher (HttpServletRequest request ) {
475
+ String name = request .getHttpServletMapping ().getServletName ();
476
+ ServletRegistration registration = this .servletContext .getServletRegistration (name );
477
+ Assert .notNull (registration , "Failed to find servlet [" + name + "] in the servlet context" );
478
+ if (isDispatcherServlet (registration )) {
479
+ return this .mvc .matcher (request );
480
+ }
481
+ return this .ant .matcher (request );
482
+ }
483
+
484
+ private boolean isDispatcherServlet (ServletRegistration registration ) {
485
+ Class <?> dispatcherServlet = ClassUtils
486
+ .resolveClassName ("org.springframework.web.servlet.DispatcherServlet" , null );
487
+ try {
488
+ Class <?> clazz = Class .forName (registration .getClassName ());
489
+ return dispatcherServlet .isAssignableFrom (clazz );
490
+ }
491
+ catch (ClassNotFoundException ex ) {
492
+ return false ;
493
+ }
494
+ }
495
+
496
+ }
497
+
383
498
}
0 commit comments