1
1
/*
2
- * Copyright 2012-2017 the original author or authors.
2
+ * Copyright 2012-2020 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
24
24
import java .util .HashSet ;
25
25
import java .util .List ;
26
26
import java .util .Set ;
27
+ import java .util .function .Predicate ;
27
28
28
29
/**
29
30
* <p>
59
60
* Rejects URLs that contain a URL encoded percent. See
60
61
* {@link #setAllowUrlEncodedPercent(boolean)}
61
62
* </li>
63
+ * <li>
64
+ * Rejects hosts that are not allowed. See
65
+ * {@link #setAllowedHostnames(Predicate)}
66
+ * </li>
62
67
* </ul>
63
68
*
64
69
* @see DefaultHttpFirewall
65
70
* @author Rob Winch
71
+ * @author Eddú Meléndez
66
72
* @since 5.0.1
67
73
*/
68
74
public class StrictHttpFirewall implements HttpFirewall {
@@ -82,6 +88,8 @@ public class StrictHttpFirewall implements HttpFirewall {
82
88
83
89
private Set <String > decodedUrlBlacklist = new HashSet <String >();
84
90
91
+ private Predicate <String > allowedHostnames = hostname -> true ;
92
+
85
93
public StrictHttpFirewall () {
86
94
urlBlacklistsAddAll (FORBIDDEN_SEMICOLON );
87
95
urlBlacklistsAddAll (FORBIDDEN_FORWARDSLASH );
@@ -230,6 +238,13 @@ public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) {
230
238
}
231
239
}
232
240
241
+ public void setAllowedHostnames (Predicate <String > allowedHostnames ) {
242
+ if (allowedHostnames == null ) {
243
+ throw new IllegalArgumentException ("allowedHostnames cannot be null" );
244
+ }
245
+ this .allowedHostnames = allowedHostnames ;
246
+ }
247
+
233
248
private void urlBlacklistsAddAll (Collection <String > values ) {
234
249
this .encodedUrlBlacklist .addAll (values );
235
250
this .decodedUrlBlacklist .addAll (values );
@@ -243,6 +258,7 @@ private void urlBlacklistsRemoveAll(Collection<String> values) {
243
258
@ Override
244
259
public FirewalledRequest getFirewalledRequest (HttpServletRequest request ) throws RequestRejectedException {
245
260
rejectedBlacklistedUrls (request );
261
+ rejectedUntrustedHosts (request );
246
262
247
263
if (!isNormalized (request )) {
248
264
throw new RequestRejectedException ("The request was rejected because the URL was not normalized." );
@@ -272,6 +288,13 @@ private void rejectedBlacklistedUrls(HttpServletRequest request) {
272
288
}
273
289
}
274
290
291
+ private void rejectedUntrustedHosts (HttpServletRequest request ) {
292
+ String serverName = request .getServerName ();
293
+ if (serverName != null && !this .allowedHostnames .test (serverName )) {
294
+ throw new RequestRejectedException ("The request was rejected because the domain " + serverName + " is untrusted." );
295
+ }
296
+ }
297
+
275
298
@ Override
276
299
public HttpServletResponse getFirewalledResponse (HttpServletResponse response ) {
277
300
return new FirewalledResponse (response );
0 commit comments