1
1
/*
2
- * Copyright 2012-2017 the original author or authors.
2
+ * Copyright 2012-2020 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
26
26
import java .util .HashSet ;
27
27
import java .util .List ;
28
28
import java .util .Set ;
29
+ import java .util .function .Predicate ;
29
30
30
31
/**
31
32
* <p>
66
67
* Rejects URLs that contain a URL encoded percent. See
67
68
* {@link #setAllowUrlEncodedPercent(boolean)}
68
69
* </li>
70
+ * <li>
71
+ * Rejects hosts that are not allowed. See
72
+ * {@link #setAllowedHostnames(Predicate)}
73
+ * </li>
69
74
* </ul>
70
75
*
71
76
* @see DefaultHttpFirewall
72
77
* @author Rob Winch
78
+ * @author Eddú Meléndez
73
79
* @since 4.2.4
74
80
*/
75
81
public class StrictHttpFirewall implements HttpFirewall {
@@ -96,6 +102,8 @@ public class StrictHttpFirewall implements HttpFirewall {
96
102
97
103
private Set <String > allowedHttpMethods = createDefaultAllowedHttpMethods ();
98
104
105
+ private Predicate <String > allowedHostnames = hostname -> true ;
106
+
99
107
public StrictHttpFirewall () {
100
108
urlBlacklistsAddAll (FORBIDDEN_SEMICOLON );
101
109
urlBlacklistsAddAll (FORBIDDEN_FORWARDSLASH );
@@ -277,6 +285,13 @@ public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) {
277
285
}
278
286
}
279
287
288
+ public void setAllowedHostnames (Predicate <String > allowedHostnames ) {
289
+ if (allowedHostnames == null ) {
290
+ throw new IllegalArgumentException ("allowedHostnames cannot be null" );
291
+ }
292
+ this .allowedHostnames = allowedHostnames ;
293
+ }
294
+
280
295
private void urlBlacklistsAddAll (Collection <String > values ) {
281
296
this .encodedUrlBlacklist .addAll (values );
282
297
this .decodedUrlBlacklist .addAll (values );
@@ -291,6 +306,7 @@ private void urlBlacklistsRemoveAll(Collection<String> values) {
291
306
public FirewalledRequest getFirewalledRequest (HttpServletRequest request ) throws RequestRejectedException {
292
307
rejectForbiddenHttpMethod (request );
293
308
rejectedBlacklistedUrls (request );
309
+ rejectedUntrustedHosts (request );
294
310
295
311
if (!isNormalized (request )) {
296
312
throw new RequestRejectedException ("The request was rejected because the URL was not normalized." );
@@ -332,6 +348,13 @@ private void rejectedBlacklistedUrls(HttpServletRequest request) {
332
348
}
333
349
}
334
350
351
+ private void rejectedUntrustedHosts (HttpServletRequest request ) {
352
+ String serverName = request .getServerName ();
353
+ if (serverName != null && !this .allowedHostnames .test (serverName )) {
354
+ throw new RequestRejectedException ("The request was rejected because the domain " + serverName + " is untrusted." );
355
+ }
356
+ }
357
+
335
358
@ Override
336
359
public HttpServletResponse getFirewalledResponse (HttpServletResponse response ) {
337
360
return new FirewalledResponse (response );
0 commit comments