1 //
2 //  ========================================================================
3 //  Copyright (c) 1995-2014 Mort Bay Consulting Pty. Ltd.
4 //  ------------------------------------------------------------------------
5 //  All rights reserved. This program and the accompanying materials
6 //  are made available under the terms of the Eclipse Public License v1.0
7 //  and Apache License v2.0 which accompanies this distribution.
8 //
9 //      The Eclipse Public License is available at
10 //      http://www.eclipse.org/legal/epl-v10.html
11 //
12 //      The Apache License v2.0 is available at
13 //      http://www.opensource.org/licenses/apache2.0.php
14 //
15 //  You may elect to redistribute this code under either of these licenses.
16 //  ========================================================================
17 //
18 
19 package org.eclipse.jetty.servlets;
20 
21 import java.io.IOException;
22 import java.io.Serializable;
23 import java.util.ArrayList;
24 import java.util.Iterator;
25 import java.util.List;
26 import java.util.Queue;
27 import java.util.concurrent.ConcurrentHashMap;
28 import java.util.concurrent.ConcurrentLinkedQueue;
29 import java.util.concurrent.CopyOnWriteArrayList;
30 import java.util.concurrent.Semaphore;
31 import java.util.concurrent.TimeUnit;
32 import java.util.regex.Matcher;
33 import java.util.regex.Pattern;
34 import javax.servlet.Filter;
35 import javax.servlet.FilterChain;
36 import javax.servlet.FilterConfig;
37 import javax.servlet.ServletContext;
38 import javax.servlet.ServletException;
39 import javax.servlet.ServletRequest;
40 import javax.servlet.ServletResponse;
41 import javax.servlet.http.HttpServletRequest;
42 import javax.servlet.http.HttpServletResponse;
43 import javax.servlet.http.HttpSession;
44 import javax.servlet.http.HttpSessionActivationListener;
45 import javax.servlet.http.HttpSessionBindingEvent;
46 import javax.servlet.http.HttpSessionBindingListener;
47 import javax.servlet.http.HttpSessionEvent;
48 
49 import org.eclipse.jetty.continuation.Continuation;
50 import org.eclipse.jetty.continuation.ContinuationListener;
51 import org.eclipse.jetty.continuation.ContinuationSupport;
52 import org.eclipse.jetty.server.handler.ContextHandler;
53 import org.eclipse.jetty.util.log.Log;
54 import org.eclipse.jetty.util.log.Logger;
55 import org.eclipse.jetty.util.thread.Timeout;
56 
57 /**
58  * Denial of Service filter
59  * <p/>
60  * <p>
61  * This filter is useful for limiting
62  * exposure to abuse from request flooding, whether malicious, or as a result of
63  * a misconfigured client.
64  * <p>
65  * The filter keeps track of the number of requests from a connection per
66  * second. If a limit is exceeded, the request is either rejected, delayed, or
67  * throttled.
68  * <p>
69  * When a request is throttled, it is placed in a priority queue. Priority is
70  * given first to authenticated users and users with an HttpSession, then
71  * connections which can be identified by their IP addresses. Connections with
72  * no way to identify them are given lowest priority.
73  * <p>
 * The {@link #extractUserId(ServletRequest request)} method should be
 * overridden in order to uniquely identify authenticated users.
76  * <p>
77  * The following init parameters control the behavior of the filter:<dl>
78  * <p/>
79  * <dt>maxRequestsPerSec</dt>
80  * <dd>the maximum number of requests from a connection per
81  * second. Requests in excess of this are first delayed,
82  * then throttled.</dd>
83  * <p/>
84  * <dt>delayMs</dt>
 * <dd>is the delay, in milliseconds, given to all requests over the rate limit,
 * before they are considered at all. -1 means reject the request,
 * 0 means no delay (throttle immediately); any other value is the delay.</dd>
88  * <p/>
89  * <dt>maxWaitMs</dt>
 * <dd>how long, in milliseconds, to wait (blocking) for the throttle semaphore.</dd>
91  * <p/>
92  * <dt>throttledRequests</dt>
93  * <dd>is the number of requests over the rate limit able to be
94  * considered at once.</dd>
95  * <p/>
96  * <dt>throttleMs</dt>
 * <dd>how long, in milliseconds, a throttled request waits asynchronously (suspended) for the semaphore.</dd>
98  * <p/>
99  * <dt>maxRequestMs</dt>
100  * <dd>how long to allow this request to run.</dd>
101  * <p/>
102  * <dt>maxIdleTrackerMs</dt>
103  * <dd>how long to keep track of request rates for a connection,
104  * before deciding that the user has gone away, and discarding it</dd>
105  * <p/>
106  * <dt>insertHeaders</dt>
 * <dd>if true, insert the DoSFilter headers into the response. Defaults to true.</dd>
108  * <p/>
109  * <dt>trackSessions</dt>
110  * <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd>
111  * <p/>
112  * <dt>remotePort</dt>
113  * <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd>
114  * <p/>
115  * <dt>ipWhitelist</dt>
 * <dd>a comma-separated list of IP addresses or CIDR ranges (e.g. 192.168.0.0/16) that will not be rate limited</dd>
117  * <p/>
 * <dt>managedAttr</dt>
 * <dd>if set to true, then this filter is set as a {@link ServletContext} attribute with the
 * filter name as the attribute name. This allows mechanisms external to the context (e.g. JMX via {@link ContextHandler#MANAGED_ATTRIBUTES}) to
 * manage the configuration of the filter.</dd>
 * <p/>
 * <dt>enabled</dt>
 * <dd>if false, all requests are passed straight through without rate limiting. Defaults to true.</dd>
 * </dl>
123  * </p>
124  */
125 public class DoSFilter implements Filter
126 {
127     private static final Logger LOG = Log.getLogger(DoSFilter.class);
128 
129     private static final String IPv4_GROUP = "(\\d{1,3})";
130     private static final Pattern IPv4_PATTERN = Pattern.compile(IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP);
131     private static final String IPv6_GROUP = "(\\p{XDigit}{1,4})";
132     private static final Pattern IPv6_PATTERN = Pattern.compile(IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP);
133     private static final Pattern CIDR_PATTERN = Pattern.compile("([^/]+)/(\\d+)");
134 
135     private static final String __TRACKER = "DoSFilter.Tracker";
136     private static final String __THROTTLED = "DoSFilter.Throttled";
137 
138     private static final int __DEFAULT_MAX_REQUESTS_PER_SEC = 25;
139     private static final int __DEFAULT_DELAY_MS = 100;
140     private static final int __DEFAULT_THROTTLE = 5;
141     private static final int __DEFAULT_MAX_WAIT_MS = 50;
142     private static final long __DEFAULT_THROTTLE_MS = 30000L;
143     private static final long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM = 30000L;
144     private static final long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM = 30000L;
145 
146     static final String MANAGED_ATTR_INIT_PARAM = "managedAttr";
147     static final String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec";
148     static final String DELAY_MS_INIT_PARAM = "delayMs";
149     static final String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests";
150     static final String MAX_WAIT_INIT_PARAM = "maxWaitMs";
151     static final String THROTTLE_MS_INIT_PARAM = "throttleMs";
152     static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs";
153     static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs";
154     static final String INSERT_HEADERS_INIT_PARAM = "insertHeaders";
155     static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions";
156     static final String REMOTE_PORT_INIT_PARAM = "remotePort";
157     static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist";
158     static final String ENABLED_INIT_PARAM = "enabled";
159 
160     private static final int USER_AUTH = 2;
161     private static final int USER_SESSION = 2;
162     private static final int USER_IP = 1;
163     private static final int USER_UNKNOWN = 0;
164 
165     private ServletContext _context;
166     private volatile long _delayMs;
167     private volatile long _throttleMs;
168     private volatile long _maxWaitMs;
169     private volatile long _maxRequestMs;
170     private volatile long _maxIdleTrackerMs;
171     private volatile boolean _insertHeaders;
172     private volatile boolean _trackSessions;
173     private volatile boolean _remotePort;
174     private volatile boolean _enabled;
175     private Semaphore _passes;
176     private volatile int _throttledRequests;
177     private volatile int _maxRequestsPerSec;
178     private Queue<Continuation>[] _queue;
179     private ContinuationListener[] _listeners;
180     private final ConcurrentHashMap<String, RateTracker> _rateTrackers = new ConcurrentHashMap<String, RateTracker>();
181     private final List<String> _whitelist = new CopyOnWriteArrayList<String>();
182     private final Timeout _requestTimeoutQ = new Timeout();
183     private final Timeout _trackerTimeoutQ = new Timeout();
184     private Thread _timerThread;
185     private volatile boolean _running;
186 
init(FilterConfig filterConfig)187     public void init(FilterConfig filterConfig)
188     {
189         _context = filterConfig.getServletContext();
190 
191         _queue = new Queue[getMaxPriority() + 1];
192         _listeners = new ContinuationListener[getMaxPriority() + 1];
193         for (int p = 0; p < _queue.length; p++)
194         {
195             _queue[p] = new ConcurrentLinkedQueue<Continuation>();
196 
197             final int priority = p;
198             _listeners[p] = new ContinuationListener()
199             {
200                 public void onComplete(Continuation continuation)
201                 {
202                 }
203 
204                 public void onTimeout(Continuation continuation)
205                 {
206                     _queue[priority].remove(continuation);
207                 }
208             };
209         }
210 
211         _rateTrackers.clear();
212 
213         int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC;
214         String parameter = filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM);
215         if (parameter != null)
216             maxRequests = Integer.parseInt(parameter);
217         setMaxRequestsPerSec(maxRequests);
218 
219         long delay = __DEFAULT_DELAY_MS;
220         parameter = filterConfig.getInitParameter(DELAY_MS_INIT_PARAM);
221         if (parameter != null)
222             delay = Long.parseLong(parameter);
223         setDelayMs(delay);
224 
225         int throttledRequests = __DEFAULT_THROTTLE;
226         parameter = filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM);
227         if (parameter != null)
228             throttledRequests = Integer.parseInt(parameter);
229         setThrottledRequests(throttledRequests);
230 
231         long maxWait = __DEFAULT_MAX_WAIT_MS;
232         parameter = filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM);
233         if (parameter != null)
234             maxWait = Long.parseLong(parameter);
235         setMaxWaitMs(maxWait);
236 
237         long throttle = __DEFAULT_THROTTLE_MS;
238         parameter = filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM);
239         if (parameter != null)
240             throttle = Long.parseLong(parameter);
241         setThrottleMs(throttle);
242 
243         long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM;
244         parameter = filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM);
245         if (parameter != null)
246             maxRequestMs = Long.parseLong(parameter);
247         setMaxRequestMs(maxRequestMs);
248 
249         long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM;
250         parameter = filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM);
251         if (parameter != null)
252             maxIdleTrackerMs = Long.parseLong(parameter);
253         setMaxIdleTrackerMs(maxIdleTrackerMs);
254 
255         String whiteList = "";
256         parameter = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM);
257         if (parameter != null)
258             whiteList = parameter;
259         setWhitelist(whiteList);
260 
261         parameter = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM);
262         setInsertHeaders(parameter == null || Boolean.parseBoolean(parameter));
263 
264         parameter = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM);
265         setTrackSessions(parameter == null || Boolean.parseBoolean(parameter));
266 
267         parameter = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM);
268         setRemotePort(parameter != null && Boolean.parseBoolean(parameter));
269 
270         parameter = filterConfig.getInitParameter(ENABLED_INIT_PARAM);
271         setEnabled(parameter == null || Boolean.parseBoolean(parameter));
272 
273         _requestTimeoutQ.setNow();
274         _requestTimeoutQ.setDuration(_maxRequestMs);
275 
276         _trackerTimeoutQ.setNow();
277         _trackerTimeoutQ.setDuration(_maxIdleTrackerMs);
278 
279         _running = true;
280         _timerThread = (new Thread()
281         {
282             public void run()
283             {
284                 try
285                 {
286                     while (_running)
287                     {
288                         long now = _requestTimeoutQ.setNow();
289                         _requestTimeoutQ.tick();
290                         _trackerTimeoutQ.setNow(now);
291                         _trackerTimeoutQ.tick();
292                         try
293                         {
294                             Thread.sleep(100);
295                         }
296                         catch (InterruptedException e)
297                         {
298                             LOG.ignore(e);
299                         }
300                     }
301                 }
302                 finally
303                 {
304                     LOG.debug("DoSFilter timer exited");
305                 }
306             }
307         });
308         _timerThread.start();
309 
310         if (_context != null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM)))
311             _context.setAttribute(filterConfig.getFilterName(), this);
312     }
313 
doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)314     public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException
315     {
316         doFilter((HttpServletRequest)request, (HttpServletResponse)response, filterChain);
317     }
318 
doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)319     protected void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException
320     {
321         if (!isEnabled())
322         {
323             filterChain.doFilter(request, response);
324             return;
325         }
326 
327         final long now = _requestTimeoutQ.getNow();
328 
329         // Look for the rate tracker for this request
330         RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER);
331 
332         if (tracker == null)
333         {
334             // This is the first time we have seen this request.
335 
336             // get a rate tracker associated with this request, and record one hit
337             tracker = getRateTracker(request);
338 
339             // Calculate the rate and check it is over the allowed limit
340             final boolean overRateLimit = tracker.isRateExceeded(now);
341 
342             // pass it through if  we are not currently over the rate limit
343             if (!overRateLimit)
344             {
345                 doFilterChain(filterChain, request, response);
346                 return;
347             }
348 
349             // We are over the limit.
350 
351             // So either reject it, delay it or throttle it
352             long delayMs = getDelayMs();
353             boolean insertHeaders = isInsertHeaders();
354             switch ((int)delayMs)
355             {
356                 case -1:
357                 {
358                     // Reject this request
359                     LOG.warn("DOS ALERT: Request rejected ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
360                     if (insertHeaders)
361                         response.addHeader("DoSFilter", "unavailable");
362                     response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
363                     return;
364                 }
365                 case 0:
366                 {
367                     // fall through to throttle code
368                     LOG.warn("DOS ALERT: Request throttled ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
369                     request.setAttribute(__TRACKER, tracker);
370                     break;
371                 }
372                 default:
373                 {
374                     // insert a delay before throttling the request
375                     LOG.warn("DOS ALERT: Request delayed="+delayMs+"ms ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal());
376                     if (insertHeaders)
377                         response.addHeader("DoSFilter", "delayed");
378                     Continuation continuation = ContinuationSupport.getContinuation(request);
379                     request.setAttribute(__TRACKER, tracker);
380                     if (delayMs > 0)
381                         continuation.setTimeout(delayMs);
382                     continuation.suspend();
383                     return;
384                 }
385             }
386         }
387 
388         // Throttle the request
389         boolean accepted = false;
390         try
391         {
392             // check if we can afford to accept another request at this time
393             accepted = _passes.tryAcquire(getMaxWaitMs(), TimeUnit.MILLISECONDS);
394 
395             if (!accepted)
396             {
397                 // we were not accepted, so either we suspend to wait,or if we were woken up we insist or we fail
398                 final Continuation continuation = ContinuationSupport.getContinuation(request);
399 
400                 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED);
401                 long throttleMs = getThrottleMs();
402                 if (throttled != Boolean.TRUE && throttleMs > 0)
403                 {
404                     int priority = getPriority(request, tracker);
405                     request.setAttribute(__THROTTLED, Boolean.TRUE);
406                     if (isInsertHeaders())
407                         response.addHeader("DoSFilter", "throttled");
408                     if (throttleMs > 0)
409                         continuation.setTimeout(throttleMs);
410                     continuation.suspend();
411 
412                     continuation.addContinuationListener(_listeners[priority]);
413                     _queue[priority].add(continuation);
414                     return;
415                 }
416                 // else were we resumed?
417                 else if (request.getAttribute("javax.servlet.resumed") == Boolean.TRUE)
418                 {
419                     // we were resumed and somebody stole our pass, so we wait for the next one.
420                     _passes.acquire();
421                     accepted = true;
422                 }
423             }
424 
425             // if we were accepted (either immediately or after throttle)
426             if (accepted)
427                 // call the chain
428                 doFilterChain(filterChain, request, response);
429             else
430             {
431                 // fail the request
432                 if (isInsertHeaders())
433                     response.addHeader("DoSFilter", "unavailable");
434                 response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
435             }
436         }
437         catch (InterruptedException e)
438         {
439             _context.log("DoS", e);
440             response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE);
441         }
442         finally
443         {
444             if (accepted)
445             {
446                 // wake up the next highest priority request.
447                 for (int p = _queue.length; p-- > 0; )
448                 {
449                     Continuation continuation = _queue[p].poll();
450                     if (continuation != null && continuation.isSuspended())
451                     {
452                         continuation.resume();
453                         break;
454                     }
455                 }
456                 _passes.release();
457             }
458         }
459     }
460 
doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)461     protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) throws IOException, ServletException
462     {
463         final Thread thread = Thread.currentThread();
464 
465         final Timeout.Task requestTimeout = new Timeout.Task()
466         {
467             public void expired()
468             {
469                 closeConnection(request, response, thread);
470             }
471         };
472 
473         try
474         {
475             _requestTimeoutQ.schedule(requestTimeout);
476             chain.doFilter(request, response);
477         }
478         finally
479         {
480             requestTimeout.cancel();
481         }
482     }
483 
484     /**
485      * Takes drastic measures to return this response and stop this thread.
486      * Due to the way the connection is interrupted, may return mixed up headers.
487      *
488      * @param request  current request
489      * @param response current response, which must be stopped
490      * @param thread   the handling thread
491      */
closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)492     protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
493     {
494         // take drastic measures to return this response and stop this thread.
495         if (!response.isCommitted())
496         {
497             response.setHeader("Connection", "close");
498         }
499         try
500         {
501             try
502             {
503                 response.getWriter().close();
504             }
505             catch (IllegalStateException e)
506             {
507                 response.getOutputStream().close();
508             }
509         }
510         catch (IOException e)
511         {
512             LOG.warn(e);
513         }
514 
515         // interrupt the handling thread
516         thread.interrupt();
517     }
518 
519     /**
520      * Get priority for this request, based on user type
521      *
522      * @param request the current request
523      * @param tracker the rate tracker for this request
524      * @return the priority for this request
525      */
getPriority(HttpServletRequest request, RateTracker tracker)526     protected int getPriority(HttpServletRequest request, RateTracker tracker)
527     {
528         if (extractUserId(request) != null)
529             return USER_AUTH;
530         if (tracker != null)
531             return tracker.getType();
532         return USER_UNKNOWN;
533     }
534 
535     /**
536      * @return the maximum priority that we can assign to a request
537      */
getMaxPriority()538     protected int getMaxPriority()
539     {
540         return USER_AUTH;
541     }
542 
543     /**
544      * Return a request rate tracker associated with this connection; keeps
545      * track of this connection's request rate. If this is not the first request
546      * from this connection, return the existing object with the stored stats.
547      * If it is the first request, then create a new request tracker.
548      * <p/>
549      * Assumes that each connection has an identifying characteristic, and goes
550      * through them in order, taking the first that matches: user id (logged
551      * in), session id, client IP address. Unidentifiable connections are lumped
552      * into one.
553      * <p/>
554      * When a session expires, its rate tracker is automatically deleted.
555      *
556      * @param request the current request
557      * @return the request rate tracker for the current connection
558      */
getRateTracker(ServletRequest request)559     public RateTracker getRateTracker(ServletRequest request)
560     {
561         HttpSession session = ((HttpServletRequest)request).getSession(false);
562 
563         String loadId = extractUserId(request);
564         final int type;
565         if (loadId != null)
566         {
567             type = USER_AUTH;
568         }
569         else
570         {
571             if (_trackSessions && session != null && !session.isNew())
572             {
573                 loadId = session.getId();
574                 type = USER_SESSION;
575             }
576             else
577             {
578                 loadId = _remotePort ? (request.getRemoteAddr() + request.getRemotePort()) : request.getRemoteAddr();
579                 type = USER_IP;
580             }
581         }
582 
583         RateTracker tracker = _rateTrackers.get(loadId);
584 
585         if (tracker == null)
586         {
587             boolean allowed = checkWhitelist(_whitelist, request.getRemoteAddr());
588             tracker = allowed ? new FixedRateTracker(loadId, type, _maxRequestsPerSec)
589                     : new RateTracker(loadId, type, _maxRequestsPerSec);
590             RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker);
591             if (existing != null)
592                 tracker = existing;
593 
594             if (type == USER_IP)
595             {
596                 // USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ
597                 _trackerTimeoutQ.schedule(tracker);
598             }
599             else if (session != null)
600             {
601                 // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener
602                 session.setAttribute(__TRACKER, tracker);
603             }
604         }
605 
606         return tracker;
607     }
608 
checkWhitelist(List<String> whitelist, String candidate)609     protected boolean checkWhitelist(List<String> whitelist, String candidate)
610     {
611         for (String address : whitelist)
612         {
613             if (address.contains("/"))
614             {
615                 if (subnetMatch(address, candidate))
616                     return true;
617             }
618             else
619             {
620                 if (address.equals(candidate))
621                     return true;
622             }
623         }
624         return false;
625     }
626 
subnetMatch(String subnetAddress, String address)627     protected boolean subnetMatch(String subnetAddress, String address)
628     {
629         Matcher cidrMatcher = CIDR_PATTERN.matcher(subnetAddress);
630         if (!cidrMatcher.matches())
631             return false;
632 
633         String subnet = cidrMatcher.group(1);
634         int prefix;
635         try
636         {
637             prefix = Integer.parseInt(cidrMatcher.group(2));
638         }
639         catch (NumberFormatException x)
640         {
641             LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
642             return false;
643         }
644 
645         byte[] subnetBytes = addressToBytes(subnet);
646         if (subnetBytes == null)
647         {
648             LOG.info("Ignoring malformed CIDR address {}", subnetAddress);
649             return false;
650         }
651         byte[] addressBytes = addressToBytes(address);
652         if (addressBytes == null)
653         {
654             LOG.info("Ignoring malformed remote address {}", address);
655             return false;
656         }
657 
658         // Comparing IPv4 with IPv6 ?
659         int length = subnetBytes.length;
660         if (length != addressBytes.length)
661             return false;
662 
663         byte[] mask = prefixToBytes(prefix, length);
664 
665         for (int i = 0; i < length; ++i)
666         {
667             if ((subnetBytes[i] & mask[i]) != (addressBytes[i] & mask[i]))
668                 return false;
669         }
670 
671         return true;
672     }
673 
addressToBytes(String address)674     private byte[] addressToBytes(String address)
675     {
676         Matcher ipv4Matcher = IPv4_PATTERN.matcher(address);
677         if (ipv4Matcher.matches())
678         {
679             byte[] result = new byte[4];
680             for (int i = 0; i < result.length; ++i)
681                 result[i] = Integer.valueOf(ipv4Matcher.group(i + 1)).byteValue();
682             return result;
683         }
684         else
685         {
686             Matcher ipv6Matcher = IPv6_PATTERN.matcher(address);
687             if (ipv6Matcher.matches())
688             {
689                 byte[] result = new byte[16];
690                 for (int i = 0; i < result.length; i += 2)
691                 {
692                     int word = Integer.valueOf(ipv6Matcher.group(i / 2 + 1), 16);
693                     result[i] = (byte)((word & 0xFF00) >>> 8);
694                     result[i + 1] = (byte)(word & 0xFF);
695                 }
696                 return result;
697             }
698         }
699         return null;
700     }
701 
prefixToBytes(int prefix, int length)702     private byte[] prefixToBytes(int prefix, int length)
703     {
704         byte[] result = new byte[length];
705         int index = 0;
706         while (prefix / 8 > 0)
707         {
708             result[index] = -1;
709             prefix -= 8;
710             ++index;
711         }
712         // Sets the _prefix_ most significant bits to 1
713         result[index] = (byte)~((1 << (8 - prefix)) - 1);
714         return result;
715     }
716 
destroy()717     public void destroy()
718     {
719         LOG.debug("Destroy {}",this);
720         _running = false;
721         _timerThread.interrupt();
722         _requestTimeoutQ.cancelAll();
723         _trackerTimeoutQ.cancelAll();
724         _rateTrackers.clear();
725         _whitelist.clear();
726     }
727 
728     /**
729      * Returns the user id, used to track this connection.
730      * This SHOULD be overridden by subclasses.
731      *
732      * @param request the current request
733      * @return a unique user id, if logged in; otherwise null.
734      */
extractUserId(ServletRequest request)735     protected String extractUserId(ServletRequest request)
736     {
737         return null;
738     }
739 
740     /**
741      * Get maximum number of requests from a connection per
742      * second. Requests in excess of this are first delayed,
743      * then throttled.
744      *
745      * @return maximum number of requests
746      */
getMaxRequestsPerSec()747     public int getMaxRequestsPerSec()
748     {
749         return _maxRequestsPerSec;
750     }
751 
752     /**
753      * Get maximum number of requests from a connection per
754      * second. Requests in excess of this are first delayed,
755      * then throttled.
756      *
757      * @param value maximum number of requests
758      */
setMaxRequestsPerSec(int value)759     public void setMaxRequestsPerSec(int value)
760     {
761         _maxRequestsPerSec = value;
762     }
763 
764     /**
765      * Get delay (in milliseconds) that is applied to all requests
766      * over the rate limit, before they are considered at all.
767      */
getDelayMs()768     public long getDelayMs()
769     {
770         return _delayMs;
771     }
772 
773     /**
774      * Set delay (in milliseconds) that is applied to all requests
775      * over the rate limit, before they are considered at all.
776      *
777      * @param value delay (in milliseconds), 0 - no delay, -1 - reject request
778      */
setDelayMs(long value)779     public void setDelayMs(long value)
780     {
781         _delayMs = value;
782     }
783 
784     /**
785      * Get maximum amount of time (in milliseconds) the filter will
     * block while waiting for the throttle semaphore.
787      *
788      * @return maximum wait time
789      */
getMaxWaitMs()790     public long getMaxWaitMs()
791     {
792         return _maxWaitMs;
793     }
794 
795     /**
796      * Set maximum amount of time (in milliseconds) the filter will
     * block while waiting for the throttle semaphore.
798      *
799      * @param value maximum wait time
800      */
setMaxWaitMs(long value)801     public void setMaxWaitMs(long value)
802     {
803         _maxWaitMs = value;
804     }
805 
806     /**
807      * Get number of requests over the rate limit able to be
808      * considered at once.
809      *
810      * @return number of requests
811      */
getThrottledRequests()812     public int getThrottledRequests()
813     {
814         return _throttledRequests;
815     }
816 
817     /**
818      * Set number of requests over the rate limit able to be
819      * considered at once.
820      *
821      * @param value number of requests
822      */
setThrottledRequests(int value)823     public void setThrottledRequests(int value)
824     {
825         int permits = _passes == null ? 0 : _passes.availablePermits();
826         _passes = new Semaphore((value - _throttledRequests + permits), true);
827         _throttledRequests = value;
828     }
829 
830     /**
831      * Get amount of time (in milliseconds) to async wait for semaphore.
832      *
833      * @return wait time
834      */
getThrottleMs()835     public long getThrottleMs()
836     {
837         return _throttleMs;
838     }
839 
840     /**
841      * Set amount of time (in milliseconds) to async wait for semaphore.
842      *
843      * @param value wait time
844      */
setThrottleMs(long value)845     public void setThrottleMs(long value)
846     {
847         _throttleMs = value;
848     }
849 
850     /**
851      * Get maximum amount of time (in milliseconds) to allow
852      * the request to process.
853      *
854      * @return maximum processing time
855      */
getMaxRequestMs()856     public long getMaxRequestMs()
857     {
858         return _maxRequestMs;
859     }
860 
861     /**
862      * Set maximum amount of time (in milliseconds) to allow
863      * the request to process.
864      *
865      * @param value maximum processing time
866      */
setMaxRequestMs(long value)867     public void setMaxRequestMs(long value)
868     {
869         _maxRequestMs = value;
870     }
871 
872     /**
873      * Get maximum amount of time (in milliseconds) to keep track
874      * of request rates for a connection, before deciding that
875      * the user has gone away, and discarding it.
876      *
877      * @return maximum tracking time
878      */
getMaxIdleTrackerMs()879     public long getMaxIdleTrackerMs()
880     {
881         return _maxIdleTrackerMs;
882     }
883 
884     /**
885      * Set maximum amount of time (in milliseconds) to keep track
886      * of request rates for a connection, before deciding that
887      * the user has gone away, and discarding it.
888      *
889      * @param value maximum tracking time
890      */
setMaxIdleTrackerMs(long value)891     public void setMaxIdleTrackerMs(long value)
892     {
893         _maxIdleTrackerMs = value;
894     }
895 
896     /**
897      * Check flag to insert the DoSFilter headers into the response.
898      *
899      * @return value of the flag
900      */
isInsertHeaders()901     public boolean isInsertHeaders()
902     {
903         return _insertHeaders;
904     }
905 
906     /**
907      * Set flag to insert the DoSFilter headers into the response.
908      *
909      * @param value value of the flag
910      */
setInsertHeaders(boolean value)911     public void setInsertHeaders(boolean value)
912     {
913         _insertHeaders = value;
914     }
915 
916     /**
917      * Get flag to have usage rate tracked by session if a session exists.
918      *
919      * @return value of the flag
920      */
isTrackSessions()921     public boolean isTrackSessions()
922     {
923         return _trackSessions;
924     }
925 
926     /**
927      * Set flag to have usage rate tracked by session if a session exists.
928      *
929      * @param value value of the flag
930      */
setTrackSessions(boolean value)931     public void setTrackSessions(boolean value)
932     {
933         _trackSessions = value;
934     }
935 
936     /**
937      * Get flag to have usage rate tracked by IP+port (effectively connection)
938      * if session tracking is not used.
939      *
940      * @return value of the flag
941      */
isRemotePort()942     public boolean isRemotePort()
943     {
944         return _remotePort;
945     }
946 
947     /**
948      * Set flag to have usage rate tracked by IP+port (effectively connection)
949      * if session tracking is not used.
950      *
951      * @param value value of the flag
952      */
setRemotePort(boolean value)953     public void setRemotePort(boolean value)
954     {
955         _remotePort = value;
956     }
957 
958     /**
959      * @return whether this filter is enabled
960      */
isEnabled()961     public boolean isEnabled()
962     {
963         return _enabled;
964     }
965 
966     /**
967      * @param enabled whether this filter is enabled
968      */
setEnabled(boolean enabled)969     public void setEnabled(boolean enabled)
970     {
971         _enabled = enabled;
972     }
973 
974     /**
975      * Get a list of IP addresses that will not be rate limited.
976      *
977      * @return comma-separated whitelist
978      */
getWhitelist()979     public String getWhitelist()
980     {
981         StringBuilder result = new StringBuilder();
982         for (Iterator<String> iterator = _whitelist.iterator(); iterator.hasNext();)
983         {
984             String address = iterator.next();
985             result.append(address);
986             if (iterator.hasNext())
987                 result.append(",");
988         }
989         return result.toString();
990     }
991 
992     /**
993      * Set a list of IP addresses that will not be rate limited.
994      *
995      * @param value comma-separated whitelist
996      */
setWhitelist(String value)997     public void setWhitelist(String value)
998     {
999         List<String> result = new ArrayList<String>();
1000         for (String address : value.split(","))
1001             addWhitelistAddress(result, address);
1002         _whitelist.clear();
1003         _whitelist.addAll(result);
1004         LOG.debug("Whitelisted IP addresses: {}", result);
1005     }
1006 
clearWhitelist()1007     public void clearWhitelist()
1008     {
1009         _whitelist.clear();
1010     }
1011 
addWhitelistAddress(String address)1012     public boolean addWhitelistAddress(String address)
1013     {
1014         return addWhitelistAddress(_whitelist, address);
1015     }
1016 
addWhitelistAddress(List<String> list, String address)1017     private boolean addWhitelistAddress(List<String> list, String address)
1018     {
1019         address = address.trim();
1020         return address.length() > 0 && list.add(address);
1021     }
1022 
removeWhitelistAddress(String address)1023     public boolean removeWhitelistAddress(String address)
1024     {
1025         return _whitelist.remove(address);
1026     }
1027 
1028     /**
1029      * A RateTracker is associated with a connection, and stores request rate
1030      * data.
1031      */
1032     class RateTracker extends Timeout.Task implements HttpSessionBindingListener, HttpSessionActivationListener, Serializable
1033     {
1034         private static final long serialVersionUID = 3534663738034577872L;
1035 
1036         transient protected final String _id;
1037         transient protected final int _type;
1038         transient protected final long[] _timestamps;
1039         transient protected int _next;
1040 
RateTracker(String id, int type, int maxRequestsPerSecond)1041         public RateTracker(String id, int type, int maxRequestsPerSecond)
1042         {
1043             _id = id;
1044             _type = type;
1045             _timestamps = new long[maxRequestsPerSecond];
1046             _next = 0;
1047         }
1048 
1049         /**
1050          * @return the current calculated request rate over the last second
1051          */
isRateExceeded(long now)1052         public boolean isRateExceeded(long now)
1053         {
1054             final long last;
1055             synchronized (this)
1056             {
1057                 last = _timestamps[_next];
1058                 _timestamps[_next] = now;
1059                 _next = (_next + 1) % _timestamps.length;
1060             }
1061 
1062             return last != 0 && (now - last) < 1000L;
1063         }
1064 
getId()1065         public String getId()
1066         {
1067             return _id;
1068         }
1069 
getType()1070         public int getType()
1071         {
1072             return _type;
1073         }
1074 
valueBound(HttpSessionBindingEvent event)1075         public void valueBound(HttpSessionBindingEvent event)
1076         {
1077             if (LOG.isDebugEnabled())
1078                 LOG.debug("Value bound: {}", getId());
1079         }
1080 
valueUnbound(HttpSessionBindingEvent event)1081         public void valueUnbound(HttpSessionBindingEvent event)
1082         {
1083             //take the tracker out of the list of trackers
1084             _rateTrackers.remove(_id);
1085             if (LOG.isDebugEnabled())
1086                 LOG.debug("Tracker removed: {}", getId());
1087         }
1088 
sessionWillPassivate(HttpSessionEvent se)1089         public void sessionWillPassivate(HttpSessionEvent se)
1090         {
1091             //take the tracker of the list of trackers (if its still there)
1092             //and ensure that we take ourselves out of the session so we are not saved
1093             _rateTrackers.remove(_id);
1094             se.getSession().removeAttribute(__TRACKER);
1095             if (LOG.isDebugEnabled()) LOG.debug("Value removed: {}", getId());
1096         }
1097 
sessionDidActivate(HttpSessionEvent se)1098         public void sessionDidActivate(HttpSessionEvent se)
1099         {
1100             LOG.warn("Unexpected session activation");
1101         }
1102 
expired()1103         public void expired()
1104         {
1105             long now = _trackerTimeoutQ.getNow();
1106             int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1);
1107             long last = _timestamps[latestIndex];
1108             boolean hasRecentRequest = last != 0 && (now - last) < 1000L;
1109 
1110             if (hasRecentRequest)
1111                 reschedule();
1112             else
1113                 _rateTrackers.remove(_id);
1114         }
1115 
1116         @Override
1117         public String toString()
1118         {
1119             return "RateTracker/" + _id + "/" + _type;
1120         }
1121     }
1122 
1123     class FixedRateTracker extends RateTracker
1124     {
1125         public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
1126         {
1127             super(id, type, numRecentRequestsTracked);
1128         }
1129 
1130         @Override
1131         public boolean isRateExceeded(long now)
1132         {
1133             // rate limit is never exceeded, but we keep track of the request timestamps
1134             // so that we know whether there was recent activity on this tracker
1135             // and whether it should be expired
1136             synchronized (this)
1137             {
1138                 _timestamps[_next] = now;
1139                 _next = (_next + 1) % _timestamps.length;
1140             }
1141 
1142             return false;
1143         }
1144 
1145         @Override
1146         public String toString()
1147         {
1148             return "Fixed" + super.toString();
1149         }
1150     }
1151 }
1152