1 // 2 // ======================================================================== 3 // Copyright (c) 1995-2014 Mort Bay Consulting Pty. Ltd. 4 // ------------------------------------------------------------------------ 5 // All rights reserved. This program and the accompanying materials 6 // are made available under the terms of the Eclipse Public License v1.0 7 // and Apache License v2.0 which accompanies this distribution. 8 // 9 // The Eclipse Public License is available at 10 // http://www.eclipse.org/legal/epl-v10.html 11 // 12 // The Apache License v2.0 is available at 13 // http://www.opensource.org/licenses/apache2.0.php 14 // 15 // You may elect to redistribute this code under either of these licenses. 16 // ======================================================================== 17 // 18 19 package org.eclipse.jetty.servlets; 20 21 import java.io.IOException; 22 import java.io.Serializable; 23 import java.util.ArrayList; 24 import java.util.Iterator; 25 import java.util.List; 26 import java.util.Queue; 27 import java.util.concurrent.ConcurrentHashMap; 28 import java.util.concurrent.ConcurrentLinkedQueue; 29 import java.util.concurrent.CopyOnWriteArrayList; 30 import java.util.concurrent.Semaphore; 31 import java.util.concurrent.TimeUnit; 32 import java.util.regex.Matcher; 33 import java.util.regex.Pattern; 34 import javax.servlet.Filter; 35 import javax.servlet.FilterChain; 36 import javax.servlet.FilterConfig; 37 import javax.servlet.ServletContext; 38 import javax.servlet.ServletException; 39 import javax.servlet.ServletRequest; 40 import javax.servlet.ServletResponse; 41 import javax.servlet.http.HttpServletRequest; 42 import javax.servlet.http.HttpServletResponse; 43 import javax.servlet.http.HttpSession; 44 import javax.servlet.http.HttpSessionActivationListener; 45 import javax.servlet.http.HttpSessionBindingEvent; 46 import javax.servlet.http.HttpSessionBindingListener; 47 import javax.servlet.http.HttpSessionEvent; 48 49 import 
org.eclipse.jetty.continuation.Continuation; 50 import org.eclipse.jetty.continuation.ContinuationListener; 51 import org.eclipse.jetty.continuation.ContinuationSupport; 52 import org.eclipse.jetty.server.handler.ContextHandler; 53 import org.eclipse.jetty.util.log.Log; 54 import org.eclipse.jetty.util.log.Logger; 55 import org.eclipse.jetty.util.thread.Timeout; 56 57 /** 58 * Denial of Service filter 59 * <p/> 60 * <p> 61 * This filter is useful for limiting 62 * exposure to abuse from request flooding, whether malicious, or as a result of 63 * a misconfigured client. 64 * <p> 65 * The filter keeps track of the number of requests from a connection per 66 * second. If a limit is exceeded, the request is either rejected, delayed, or 67 * throttled. 68 * <p> 69 * When a request is throttled, it is placed in a priority queue. Priority is 70 * given first to authenticated users and users with an HttpSession, then 71 * connections which can be identified by their IP addresses. Connections with 72 * no way to identify them are given lowest priority. 73 * <p> 74 * The {@link #extractUserId(ServletRequest request)} function should be 75 * implemented, in order to uniquely identify authenticated users. 76 * <p> 77 * The following init parameters control the behavior of the filter:<dl> 78 * <p/> 79 * <dt>maxRequestsPerSec</dt> 80 * <dd>the maximum number of requests from a connection per 81 * second. Requests in excess of this are first delayed, 82 * then throttled.</dd> 83 * <p/> 84 * <dt>delayMs</dt> 85 * <dd>is the delay given to all requests over the rate limit, 86 * before they are considered at all. 
-1 means just reject request, 87 * 0 means no delay, otherwise it is the delay.</dd> 88 * <p/> 89 * <dt>maxWaitMs</dt> 90 * <dd>how long to blocking wait for the throttle semaphore.</dd> 91 * <p/> 92 * <dt>throttledRequests</dt> 93 * <dd>is the number of requests over the rate limit able to be 94 * considered at once.</dd> 95 * <p/> 96 * <dt>throttleMs</dt> 97 * <dd>how long to async wait for semaphore.</dd> 98 * <p/> 99 * <dt>maxRequestMs</dt> 100 * <dd>how long to allow this request to run.</dd> 101 * <p/> 102 * <dt>maxIdleTrackerMs</dt> 103 * <dd>how long to keep track of request rates for a connection, 104 * before deciding that the user has gone away, and discarding it</dd> 105 * <p/> 106 * <dt>insertHeaders</dt> 107 * <dd>if true , insert the DoSFilter headers into the response. Defaults to true.</dd> 108 * <p/> 109 * <dt>trackSessions</dt> 110 * <dd>if true, usage rate is tracked by session if a session exists. Defaults to true.</dd> 111 * <p/> 112 * <dt>remotePort</dt> 113 * <dd>if true and session tracking is not used, then rate is tracked by IP+port (effectively connection). Defaults to false.</dd> 114 * <p/> 115 * <dt>ipWhitelist</dt> 116 * <dd>a comma-separated list of IP addresses that will not be rate limited</dd> 117 * <p/> 118 * <dt>managedAttr</dt> 119 * <dd>if set to true, then this servlet is set as a {@link ServletContext} attribute with the 120 * filter name as the attribute name. 
This allows context external mechanism (eg JMX via {@link ContextHandler#MANAGED_ATTRIBUTES}) to 121 * manage the configuration of the filter.</dd> 122 * </dl> 123 * </p> 124 */ 125 public class DoSFilter implements Filter 126 { 127 private static final Logger LOG = Log.getLogger(DoSFilter.class); 128 129 private static final String IPv4_GROUP = "(\\d{1,3})"; 130 private static final Pattern IPv4_PATTERN = Pattern.compile(IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP+"\\."+IPv4_GROUP); 131 private static final String IPv6_GROUP = "(\\p{XDigit}{1,4})"; 132 private static final Pattern IPv6_PATTERN = Pattern.compile(IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP+":"+IPv6_GROUP); 133 private static final Pattern CIDR_PATTERN = Pattern.compile("([^/]+)/(\\d+)"); 134 135 private static final String __TRACKER = "DoSFilter.Tracker"; 136 private static final String __THROTTLED = "DoSFilter.Throttled"; 137 138 private static final int __DEFAULT_MAX_REQUESTS_PER_SEC = 25; 139 private static final int __DEFAULT_DELAY_MS = 100; 140 private static final int __DEFAULT_THROTTLE = 5; 141 private static final int __DEFAULT_MAX_WAIT_MS = 50; 142 private static final long __DEFAULT_THROTTLE_MS = 30000L; 143 private static final long __DEFAULT_MAX_REQUEST_MS_INIT_PARAM = 30000L; 144 private static final long __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM = 30000L; 145 146 static final String MANAGED_ATTR_INIT_PARAM = "managedAttr"; 147 static final String MAX_REQUESTS_PER_S_INIT_PARAM = "maxRequestsPerSec"; 148 static final String DELAY_MS_INIT_PARAM = "delayMs"; 149 static final String THROTTLED_REQUESTS_INIT_PARAM = "throttledRequests"; 150 static final String MAX_WAIT_INIT_PARAM = "maxWaitMs"; 151 static final String THROTTLE_MS_INIT_PARAM = "throttleMs"; 152 static final String MAX_REQUEST_MS_INIT_PARAM = "maxRequestMs"; 153 static final String MAX_IDLE_TRACKER_MS_INIT_PARAM = "maxIdleTrackerMs"; 154 static final String 
INSERT_HEADERS_INIT_PARAM = "insertHeaders"; 155 static final String TRACK_SESSIONS_INIT_PARAM = "trackSessions"; 156 static final String REMOTE_PORT_INIT_PARAM = "remotePort"; 157 static final String IP_WHITELIST_INIT_PARAM = "ipWhitelist"; 158 static final String ENABLED_INIT_PARAM = "enabled"; 159 160 private static final int USER_AUTH = 2; 161 private static final int USER_SESSION = 2; 162 private static final int USER_IP = 1; 163 private static final int USER_UNKNOWN = 0; 164 165 private ServletContext _context; 166 private volatile long _delayMs; 167 private volatile long _throttleMs; 168 private volatile long _maxWaitMs; 169 private volatile long _maxRequestMs; 170 private volatile long _maxIdleTrackerMs; 171 private volatile boolean _insertHeaders; 172 private volatile boolean _trackSessions; 173 private volatile boolean _remotePort; 174 private volatile boolean _enabled; 175 private Semaphore _passes; 176 private volatile int _throttledRequests; 177 private volatile int _maxRequestsPerSec; 178 private Queue<Continuation>[] _queue; 179 private ContinuationListener[] _listeners; 180 private final ConcurrentHashMap<String, RateTracker> _rateTrackers = new ConcurrentHashMap<String, RateTracker>(); 181 private final List<String> _whitelist = new CopyOnWriteArrayList<String>(); 182 private final Timeout _requestTimeoutQ = new Timeout(); 183 private final Timeout _trackerTimeoutQ = new Timeout(); 184 private Thread _timerThread; 185 private volatile boolean _running; 186 init(FilterConfig filterConfig)187 public void init(FilterConfig filterConfig) 188 { 189 _context = filterConfig.getServletContext(); 190 191 _queue = new Queue[getMaxPriority() + 1]; 192 _listeners = new ContinuationListener[getMaxPriority() + 1]; 193 for (int p = 0; p < _queue.length; p++) 194 { 195 _queue[p] = new ConcurrentLinkedQueue<Continuation>(); 196 197 final int priority = p; 198 _listeners[p] = new ContinuationListener() 199 { 200 public void onComplete(Continuation continuation) 201 
{ 202 } 203 204 public void onTimeout(Continuation continuation) 205 { 206 _queue[priority].remove(continuation); 207 } 208 }; 209 } 210 211 _rateTrackers.clear(); 212 213 int maxRequests = __DEFAULT_MAX_REQUESTS_PER_SEC; 214 String parameter = filterConfig.getInitParameter(MAX_REQUESTS_PER_S_INIT_PARAM); 215 if (parameter != null) 216 maxRequests = Integer.parseInt(parameter); 217 setMaxRequestsPerSec(maxRequests); 218 219 long delay = __DEFAULT_DELAY_MS; 220 parameter = filterConfig.getInitParameter(DELAY_MS_INIT_PARAM); 221 if (parameter != null) 222 delay = Long.parseLong(parameter); 223 setDelayMs(delay); 224 225 int throttledRequests = __DEFAULT_THROTTLE; 226 parameter = filterConfig.getInitParameter(THROTTLED_REQUESTS_INIT_PARAM); 227 if (parameter != null) 228 throttledRequests = Integer.parseInt(parameter); 229 setThrottledRequests(throttledRequests); 230 231 long maxWait = __DEFAULT_MAX_WAIT_MS; 232 parameter = filterConfig.getInitParameter(MAX_WAIT_INIT_PARAM); 233 if (parameter != null) 234 maxWait = Long.parseLong(parameter); 235 setMaxWaitMs(maxWait); 236 237 long throttle = __DEFAULT_THROTTLE_MS; 238 parameter = filterConfig.getInitParameter(THROTTLE_MS_INIT_PARAM); 239 if (parameter != null) 240 throttle = Long.parseLong(parameter); 241 setThrottleMs(throttle); 242 243 long maxRequestMs = __DEFAULT_MAX_REQUEST_MS_INIT_PARAM; 244 parameter = filterConfig.getInitParameter(MAX_REQUEST_MS_INIT_PARAM); 245 if (parameter != null) 246 maxRequestMs = Long.parseLong(parameter); 247 setMaxRequestMs(maxRequestMs); 248 249 long maxIdleTrackerMs = __DEFAULT_MAX_IDLE_TRACKER_MS_INIT_PARAM; 250 parameter = filterConfig.getInitParameter(MAX_IDLE_TRACKER_MS_INIT_PARAM); 251 if (parameter != null) 252 maxIdleTrackerMs = Long.parseLong(parameter); 253 setMaxIdleTrackerMs(maxIdleTrackerMs); 254 255 String whiteList = ""; 256 parameter = filterConfig.getInitParameter(IP_WHITELIST_INIT_PARAM); 257 if (parameter != null) 258 whiteList = parameter; 259 
setWhitelist(whiteList); 260 261 parameter = filterConfig.getInitParameter(INSERT_HEADERS_INIT_PARAM); 262 setInsertHeaders(parameter == null || Boolean.parseBoolean(parameter)); 263 264 parameter = filterConfig.getInitParameter(TRACK_SESSIONS_INIT_PARAM); 265 setTrackSessions(parameter == null || Boolean.parseBoolean(parameter)); 266 267 parameter = filterConfig.getInitParameter(REMOTE_PORT_INIT_PARAM); 268 setRemotePort(parameter != null && Boolean.parseBoolean(parameter)); 269 270 parameter = filterConfig.getInitParameter(ENABLED_INIT_PARAM); 271 setEnabled(parameter == null || Boolean.parseBoolean(parameter)); 272 273 _requestTimeoutQ.setNow(); 274 _requestTimeoutQ.setDuration(_maxRequestMs); 275 276 _trackerTimeoutQ.setNow(); 277 _trackerTimeoutQ.setDuration(_maxIdleTrackerMs); 278 279 _running = true; 280 _timerThread = (new Thread() 281 { 282 public void run() 283 { 284 try 285 { 286 while (_running) 287 { 288 long now = _requestTimeoutQ.setNow(); 289 _requestTimeoutQ.tick(); 290 _trackerTimeoutQ.setNow(now); 291 _trackerTimeoutQ.tick(); 292 try 293 { 294 Thread.sleep(100); 295 } 296 catch (InterruptedException e) 297 { 298 LOG.ignore(e); 299 } 300 } 301 } 302 finally 303 { 304 LOG.debug("DoSFilter timer exited"); 305 } 306 } 307 }); 308 _timerThread.start(); 309 310 if (_context != null && Boolean.parseBoolean(filterConfig.getInitParameter(MANAGED_ATTR_INIT_PARAM))) 311 _context.setAttribute(filterConfig.getFilterName(), this); 312 } 313 doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)314 public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException 315 { 316 doFilter((HttpServletRequest)request, (HttpServletResponse)response, filterChain); 317 } 318 doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)319 protected void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) 
throws IOException, ServletException 320 { 321 if (!isEnabled()) 322 { 323 filterChain.doFilter(request, response); 324 return; 325 } 326 327 final long now = _requestTimeoutQ.getNow(); 328 329 // Look for the rate tracker for this request 330 RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER); 331 332 if (tracker == null) 333 { 334 // This is the first time we have seen this request. 335 336 // get a rate tracker associated with this request, and record one hit 337 tracker = getRateTracker(request); 338 339 // Calculate the rate and check it is over the allowed limit 340 final boolean overRateLimit = tracker.isRateExceeded(now); 341 342 // pass it through if we are not currently over the rate limit 343 if (!overRateLimit) 344 { 345 doFilterChain(filterChain, request, response); 346 return; 347 } 348 349 // We are over the limit. 350 351 // So either reject it, delay it or throttle it 352 long delayMs = getDelayMs(); 353 boolean insertHeaders = isInsertHeaders(); 354 switch ((int)delayMs) 355 { 356 case -1: 357 { 358 // Reject this request 359 LOG.warn("DOS ALERT: Request rejected ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal()); 360 if (insertHeaders) 361 response.addHeader("DoSFilter", "unavailable"); 362 response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE); 363 return; 364 } 365 case 0: 366 { 367 // fall through to throttle code 368 LOG.warn("DOS ALERT: Request throttled ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal()); 369 request.setAttribute(__TRACKER, tracker); 370 break; 371 } 372 default: 373 { 374 // insert a delay before throttling the request 375 LOG.warn("DOS ALERT: Request delayed="+delayMs+"ms ip=" + request.getRemoteAddr() + ",session=" + request.getRequestedSessionId() + ",user=" + request.getUserPrincipal()); 376 if (insertHeaders) 377 response.addHeader("DoSFilter", "delayed"); 378 
Continuation continuation = ContinuationSupport.getContinuation(request); 379 request.setAttribute(__TRACKER, tracker); 380 if (delayMs > 0) 381 continuation.setTimeout(delayMs); 382 continuation.suspend(); 383 return; 384 } 385 } 386 } 387 388 // Throttle the request 389 boolean accepted = false; 390 try 391 { 392 // check if we can afford to accept another request at this time 393 accepted = _passes.tryAcquire(getMaxWaitMs(), TimeUnit.MILLISECONDS); 394 395 if (!accepted) 396 { 397 // we were not accepted, so either we suspend to wait,or if we were woken up we insist or we fail 398 final Continuation continuation = ContinuationSupport.getContinuation(request); 399 400 Boolean throttled = (Boolean)request.getAttribute(__THROTTLED); 401 long throttleMs = getThrottleMs(); 402 if (throttled != Boolean.TRUE && throttleMs > 0) 403 { 404 int priority = getPriority(request, tracker); 405 request.setAttribute(__THROTTLED, Boolean.TRUE); 406 if (isInsertHeaders()) 407 response.addHeader("DoSFilter", "throttled"); 408 if (throttleMs > 0) 409 continuation.setTimeout(throttleMs); 410 continuation.suspend(); 411 412 continuation.addContinuationListener(_listeners[priority]); 413 _queue[priority].add(continuation); 414 return; 415 } 416 // else were we resumed? 417 else if (request.getAttribute("javax.servlet.resumed") == Boolean.TRUE) 418 { 419 // we were resumed and somebody stole our pass, so we wait for the next one. 
420 _passes.acquire(); 421 accepted = true; 422 } 423 } 424 425 // if we were accepted (either immediately or after throttle) 426 if (accepted) 427 // call the chain 428 doFilterChain(filterChain, request, response); 429 else 430 { 431 // fail the request 432 if (isInsertHeaders()) 433 response.addHeader("DoSFilter", "unavailable"); 434 response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE); 435 } 436 } 437 catch (InterruptedException e) 438 { 439 _context.log("DoS", e); 440 response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE); 441 } 442 finally 443 { 444 if (accepted) 445 { 446 // wake up the next highest priority request. 447 for (int p = _queue.length; p-- > 0; ) 448 { 449 Continuation continuation = _queue[p].poll(); 450 if (continuation != null && continuation.isSuspended()) 451 { 452 continuation.resume(); 453 break; 454 } 455 } 456 _passes.release(); 457 } 458 } 459 } 460 doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response)461 protected void doFilterChain(FilterChain chain, final HttpServletRequest request, final HttpServletResponse response) throws IOException, ServletException 462 { 463 final Thread thread = Thread.currentThread(); 464 465 final Timeout.Task requestTimeout = new Timeout.Task() 466 { 467 public void expired() 468 { 469 closeConnection(request, response, thread); 470 } 471 }; 472 473 try 474 { 475 _requestTimeoutQ.schedule(requestTimeout); 476 chain.doFilter(request, response); 477 } 478 finally 479 { 480 requestTimeout.cancel(); 481 } 482 } 483 484 /** 485 * Takes drastic measures to return this response and stop this thread. 486 * Due to the way the connection is interrupted, may return mixed up headers. 
487 * 488 * @param request current request 489 * @param response current response, which must be stopped 490 * @param thread the handling thread 491 */ closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)492 protected void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread) 493 { 494 // take drastic measures to return this response and stop this thread. 495 if (!response.isCommitted()) 496 { 497 response.setHeader("Connection", "close"); 498 } 499 try 500 { 501 try 502 { 503 response.getWriter().close(); 504 } 505 catch (IllegalStateException e) 506 { 507 response.getOutputStream().close(); 508 } 509 } 510 catch (IOException e) 511 { 512 LOG.warn(e); 513 } 514 515 // interrupt the handling thread 516 thread.interrupt(); 517 } 518 519 /** 520 * Get priority for this request, based on user type 521 * 522 * @param request the current request 523 * @param tracker the rate tracker for this request 524 * @return the priority for this request 525 */ getPriority(HttpServletRequest request, RateTracker tracker)526 protected int getPriority(HttpServletRequest request, RateTracker tracker) 527 { 528 if (extractUserId(request) != null) 529 return USER_AUTH; 530 if (tracker != null) 531 return tracker.getType(); 532 return USER_UNKNOWN; 533 } 534 535 /** 536 * @return the maximum priority that we can assign to a request 537 */ getMaxPriority()538 protected int getMaxPriority() 539 { 540 return USER_AUTH; 541 } 542 543 /** 544 * Return a request rate tracker associated with this connection; keeps 545 * track of this connection's request rate. If this is not the first request 546 * from this connection, return the existing object with the stored stats. 547 * If it is the first request, then create a new request tracker. 
548 * <p/> 549 * Assumes that each connection has an identifying characteristic, and goes 550 * through them in order, taking the first that matches: user id (logged 551 * in), session id, client IP address. Unidentifiable connections are lumped 552 * into one. 553 * <p/> 554 * When a session expires, its rate tracker is automatically deleted. 555 * 556 * @param request the current request 557 * @return the request rate tracker for the current connection 558 */ getRateTracker(ServletRequest request)559 public RateTracker getRateTracker(ServletRequest request) 560 { 561 HttpSession session = ((HttpServletRequest)request).getSession(false); 562 563 String loadId = extractUserId(request); 564 final int type; 565 if (loadId != null) 566 { 567 type = USER_AUTH; 568 } 569 else 570 { 571 if (_trackSessions && session != null && !session.isNew()) 572 { 573 loadId = session.getId(); 574 type = USER_SESSION; 575 } 576 else 577 { 578 loadId = _remotePort ? (request.getRemoteAddr() + request.getRemotePort()) : request.getRemoteAddr(); 579 type = USER_IP; 580 } 581 } 582 583 RateTracker tracker = _rateTrackers.get(loadId); 584 585 if (tracker == null) 586 { 587 boolean allowed = checkWhitelist(_whitelist, request.getRemoteAddr()); 588 tracker = allowed ? 
new FixedRateTracker(loadId, type, _maxRequestsPerSec) 589 : new RateTracker(loadId, type, _maxRequestsPerSec); 590 RateTracker existing = _rateTrackers.putIfAbsent(loadId, tracker); 591 if (existing != null) 592 tracker = existing; 593 594 if (type == USER_IP) 595 { 596 // USER_IP expiration from _rateTrackers is handled by the _trackerTimeoutQ 597 _trackerTimeoutQ.schedule(tracker); 598 } 599 else if (session != null) 600 { 601 // USER_SESSION expiration from _rateTrackers are handled by the HttpSessionBindingListener 602 session.setAttribute(__TRACKER, tracker); 603 } 604 } 605 606 return tracker; 607 } 608 checkWhitelist(List<String> whitelist, String candidate)609 protected boolean checkWhitelist(List<String> whitelist, String candidate) 610 { 611 for (String address : whitelist) 612 { 613 if (address.contains("/")) 614 { 615 if (subnetMatch(address, candidate)) 616 return true; 617 } 618 else 619 { 620 if (address.equals(candidate)) 621 return true; 622 } 623 } 624 return false; 625 } 626 subnetMatch(String subnetAddress, String address)627 protected boolean subnetMatch(String subnetAddress, String address) 628 { 629 Matcher cidrMatcher = CIDR_PATTERN.matcher(subnetAddress); 630 if (!cidrMatcher.matches()) 631 return false; 632 633 String subnet = cidrMatcher.group(1); 634 int prefix; 635 try 636 { 637 prefix = Integer.parseInt(cidrMatcher.group(2)); 638 } 639 catch (NumberFormatException x) 640 { 641 LOG.info("Ignoring malformed CIDR address {}", subnetAddress); 642 return false; 643 } 644 645 byte[] subnetBytes = addressToBytes(subnet); 646 if (subnetBytes == null) 647 { 648 LOG.info("Ignoring malformed CIDR address {}", subnetAddress); 649 return false; 650 } 651 byte[] addressBytes = addressToBytes(address); 652 if (addressBytes == null) 653 { 654 LOG.info("Ignoring malformed remote address {}", address); 655 return false; 656 } 657 658 // Comparing IPv4 with IPv6 ? 
659 int length = subnetBytes.length; 660 if (length != addressBytes.length) 661 return false; 662 663 byte[] mask = prefixToBytes(prefix, length); 664 665 for (int i = 0; i < length; ++i) 666 { 667 if ((subnetBytes[i] & mask[i]) != (addressBytes[i] & mask[i])) 668 return false; 669 } 670 671 return true; 672 } 673 addressToBytes(String address)674 private byte[] addressToBytes(String address) 675 { 676 Matcher ipv4Matcher = IPv4_PATTERN.matcher(address); 677 if (ipv4Matcher.matches()) 678 { 679 byte[] result = new byte[4]; 680 for (int i = 0; i < result.length; ++i) 681 result[i] = Integer.valueOf(ipv4Matcher.group(i + 1)).byteValue(); 682 return result; 683 } 684 else 685 { 686 Matcher ipv6Matcher = IPv6_PATTERN.matcher(address); 687 if (ipv6Matcher.matches()) 688 { 689 byte[] result = new byte[16]; 690 for (int i = 0; i < result.length; i += 2) 691 { 692 int word = Integer.valueOf(ipv6Matcher.group(i / 2 + 1), 16); 693 result[i] = (byte)((word & 0xFF00) >>> 8); 694 result[i + 1] = (byte)(word & 0xFF); 695 } 696 return result; 697 } 698 } 699 return null; 700 } 701 prefixToBytes(int prefix, int length)702 private byte[] prefixToBytes(int prefix, int length) 703 { 704 byte[] result = new byte[length]; 705 int index = 0; 706 while (prefix / 8 > 0) 707 { 708 result[index] = -1; 709 prefix -= 8; 710 ++index; 711 } 712 // Sets the _prefix_ most significant bits to 1 713 result[index] = (byte)~((1 << (8 - prefix)) - 1); 714 return result; 715 } 716 destroy()717 public void destroy() 718 { 719 LOG.debug("Destroy {}",this); 720 _running = false; 721 _timerThread.interrupt(); 722 _requestTimeoutQ.cancelAll(); 723 _trackerTimeoutQ.cancelAll(); 724 _rateTrackers.clear(); 725 _whitelist.clear(); 726 } 727 728 /** 729 * Returns the user id, used to track this connection. 730 * This SHOULD be overridden by subclasses. 731 * 732 * @param request the current request 733 * @return a unique user id, if logged in; otherwise null. 
734 */ extractUserId(ServletRequest request)735 protected String extractUserId(ServletRequest request) 736 { 737 return null; 738 } 739 740 /** 741 * Get maximum number of requests from a connection per 742 * second. Requests in excess of this are first delayed, 743 * then throttled. 744 * 745 * @return maximum number of requests 746 */ getMaxRequestsPerSec()747 public int getMaxRequestsPerSec() 748 { 749 return _maxRequestsPerSec; 750 } 751 752 /** 753 * Get maximum number of requests from a connection per 754 * second. Requests in excess of this are first delayed, 755 * then throttled. 756 * 757 * @param value maximum number of requests 758 */ setMaxRequestsPerSec(int value)759 public void setMaxRequestsPerSec(int value) 760 { 761 _maxRequestsPerSec = value; 762 } 763 764 /** 765 * Get delay (in milliseconds) that is applied to all requests 766 * over the rate limit, before they are considered at all. 767 */ getDelayMs()768 public long getDelayMs() 769 { 770 return _delayMs; 771 } 772 773 /** 774 * Set delay (in milliseconds) that is applied to all requests 775 * over the rate limit, before they are considered at all. 776 * 777 * @param value delay (in milliseconds), 0 - no delay, -1 - reject request 778 */ setDelayMs(long value)779 public void setDelayMs(long value) 780 { 781 _delayMs = value; 782 } 783 784 /** 785 * Get maximum amount of time (in milliseconds) the filter will 786 * blocking wait for the throttle semaphore. 787 * 788 * @return maximum wait time 789 */ getMaxWaitMs()790 public long getMaxWaitMs() 791 { 792 return _maxWaitMs; 793 } 794 795 /** 796 * Set maximum amount of time (in milliseconds) the filter will 797 * blocking wait for the throttle semaphore. 798 * 799 * @param value maximum wait time 800 */ setMaxWaitMs(long value)801 public void setMaxWaitMs(long value) 802 { 803 _maxWaitMs = value; 804 } 805 806 /** 807 * Get number of requests over the rate limit able to be 808 * considered at once. 
809 * 810 * @return number of requests 811 */ getThrottledRequests()812 public int getThrottledRequests() 813 { 814 return _throttledRequests; 815 } 816 817 /** 818 * Set number of requests over the rate limit able to be 819 * considered at once. 820 * 821 * @param value number of requests 822 */ setThrottledRequests(int value)823 public void setThrottledRequests(int value) 824 { 825 int permits = _passes == null ? 0 : _passes.availablePermits(); 826 _passes = new Semaphore((value - _throttledRequests + permits), true); 827 _throttledRequests = value; 828 } 829 830 /** 831 * Get amount of time (in milliseconds) to async wait for semaphore. 832 * 833 * @return wait time 834 */ getThrottleMs()835 public long getThrottleMs() 836 { 837 return _throttleMs; 838 } 839 840 /** 841 * Set amount of time (in milliseconds) to async wait for semaphore. 842 * 843 * @param value wait time 844 */ setThrottleMs(long value)845 public void setThrottleMs(long value) 846 { 847 _throttleMs = value; 848 } 849 850 /** 851 * Get maximum amount of time (in milliseconds) to allow 852 * the request to process. 853 * 854 * @return maximum processing time 855 */ getMaxRequestMs()856 public long getMaxRequestMs() 857 { 858 return _maxRequestMs; 859 } 860 861 /** 862 * Set maximum amount of time (in milliseconds) to allow 863 * the request to process. 864 * 865 * @param value maximum processing time 866 */ setMaxRequestMs(long value)867 public void setMaxRequestMs(long value) 868 { 869 _maxRequestMs = value; 870 } 871 872 /** 873 * Get maximum amount of time (in milliseconds) to keep track 874 * of request rates for a connection, before deciding that 875 * the user has gone away, and discarding it. 
876 * 877 * @return maximum tracking time 878 */ getMaxIdleTrackerMs()879 public long getMaxIdleTrackerMs() 880 { 881 return _maxIdleTrackerMs; 882 } 883 884 /** 885 * Set maximum amount of time (in milliseconds) to keep track 886 * of request rates for a connection, before deciding that 887 * the user has gone away, and discarding it. 888 * 889 * @param value maximum tracking time 890 */ setMaxIdleTrackerMs(long value)891 public void setMaxIdleTrackerMs(long value) 892 { 893 _maxIdleTrackerMs = value; 894 } 895 896 /** 897 * Check flag to insert the DoSFilter headers into the response. 898 * 899 * @return value of the flag 900 */ isInsertHeaders()901 public boolean isInsertHeaders() 902 { 903 return _insertHeaders; 904 } 905 906 /** 907 * Set flag to insert the DoSFilter headers into the response. 908 * 909 * @param value value of the flag 910 */ setInsertHeaders(boolean value)911 public void setInsertHeaders(boolean value) 912 { 913 _insertHeaders = value; 914 } 915 916 /** 917 * Get flag to have usage rate tracked by session if a session exists. 918 * 919 * @return value of the flag 920 */ isTrackSessions()921 public boolean isTrackSessions() 922 { 923 return _trackSessions; 924 } 925 926 /** 927 * Set flag to have usage rate tracked by session if a session exists. 928 * 929 * @param value value of the flag 930 */ setTrackSessions(boolean value)931 public void setTrackSessions(boolean value) 932 { 933 _trackSessions = value; 934 } 935 936 /** 937 * Get flag to have usage rate tracked by IP+port (effectively connection) 938 * if session tracking is not used. 939 * 940 * @return value of the flag 941 */ isRemotePort()942 public boolean isRemotePort() 943 { 944 return _remotePort; 945 } 946 947 /** 948 * Set flag to have usage rate tracked by IP+port (effectively connection) 949 * if session tracking is not used. 
950 * 951 * @param value value of the flag 952 */ setRemotePort(boolean value)953 public void setRemotePort(boolean value) 954 { 955 _remotePort = value; 956 } 957 958 /** 959 * @return whether this filter is enabled 960 */ isEnabled()961 public boolean isEnabled() 962 { 963 return _enabled; 964 } 965 966 /** 967 * @param enabled whether this filter is enabled 968 */ setEnabled(boolean enabled)969 public void setEnabled(boolean enabled) 970 { 971 _enabled = enabled; 972 } 973 974 /** 975 * Get a list of IP addresses that will not be rate limited. 976 * 977 * @return comma-separated whitelist 978 */ getWhitelist()979 public String getWhitelist() 980 { 981 StringBuilder result = new StringBuilder(); 982 for (Iterator<String> iterator = _whitelist.iterator(); iterator.hasNext();) 983 { 984 String address = iterator.next(); 985 result.append(address); 986 if (iterator.hasNext()) 987 result.append(","); 988 } 989 return result.toString(); 990 } 991 992 /** 993 * Set a list of IP addresses that will not be rate limited. 
994 * 995 * @param value comma-separated whitelist 996 */ setWhitelist(String value)997 public void setWhitelist(String value) 998 { 999 List<String> result = new ArrayList<String>(); 1000 for (String address : value.split(",")) 1001 addWhitelistAddress(result, address); 1002 _whitelist.clear(); 1003 _whitelist.addAll(result); 1004 LOG.debug("Whitelisted IP addresses: {}", result); 1005 } 1006 clearWhitelist()1007 public void clearWhitelist() 1008 { 1009 _whitelist.clear(); 1010 } 1011 addWhitelistAddress(String address)1012 public boolean addWhitelistAddress(String address) 1013 { 1014 return addWhitelistAddress(_whitelist, address); 1015 } 1016 addWhitelistAddress(List<String> list, String address)1017 private boolean addWhitelistAddress(List<String> list, String address) 1018 { 1019 address = address.trim(); 1020 return address.length() > 0 && list.add(address); 1021 } 1022 removeWhitelistAddress(String address)1023 public boolean removeWhitelistAddress(String address) 1024 { 1025 return _whitelist.remove(address); 1026 } 1027 1028 /** 1029 * A RateTracker is associated with a connection, and stores request rate 1030 * data. 
*
     * <p>The tracker keeps the times of the most recent {@code maxRequestsPerSecond}
     * requests in a circular buffer. The rate is exceeded when the buffer is full
     * and its oldest entry is less than 1000 ms older than the current request.</p>
     */
    class RateTracker extends Timeout.Task implements HttpSessionBindingListener, HttpSessionActivationListener, Serializable
    {
        private static final long serialVersionUID = 3534663738034577872L;

        // All state is transient. sessionWillPassivate removes the tracker from the session,
        // so this state is not expected to survive serialization.
        transient protected final String _id;
        transient protected final int _type;
        /** Circular buffer of request times; a zero entry marks a slot that has never been written. */
        transient protected final long[] _timestamps;
        /** Index of the slot that the next request will overwrite; guarded by {@code this}. */
        transient protected int _next;

        public RateTracker(String id, int type, int maxRequestsPerSecond)
        {
            _id = id;
            _type = type;
            _timestamps = new long[maxRequestsPerSecond];
            _next = 0;
        }

        /**
         * Store {@code now} in the circular buffer and advance the write index.
         *
         * @param now the time of the request (compared against a 1000 ms window elsewhere)
         * @return the timestamp that was overwritten: the time of the request made
         *         {@code _timestamps.length} requests ago, or 0 if that slot was never written
         */
        protected long recordRequest(long now)
        {
            synchronized (this)
            {
                long overwritten = _timestamps[_next];
                _timestamps[_next] = now;
                _next = (_next + 1) % _timestamps.length;
                return overwritten;
            }
        }

        /**
         * Record a request made at {@code now} and report whether the rate limit is exceeded.
         *
         * @param now the time of the request, in the same clock as previous calls
         * @return {@code true} if the buffer was already full and its oldest timestamp
         *         is less than 1000 ms before {@code now}
         */
        public boolean isRateExceeded(long now)
        {
            final long last = recordRequest(now);
            return last != 0 && (now - last) < 1000L;
        }

        public String getId()
        {
            return _id;
        }

        public int getType()
        {
            return _type;
        }

        public void valueBound(HttpSessionBindingEvent event)
        {
            if (LOG.isDebugEnabled())
                LOG.debug("Value bound: {}", getId());
        }

        public void valueUnbound(HttpSessionBindingEvent event)
        {
            //take the tracker out of the list of trackers
            _rateTrackers.remove(_id);
            if (LOG.isDebugEnabled())
                LOG.debug("Tracker removed: {}", getId());
        }

        public void sessionWillPassivate(HttpSessionEvent se)
        {
            //take the tracker out of the list of trackers (if it's still there)
            //and ensure that we take ourselves out of the session so we are not saved
            _rateTrackers.remove(_id);
            se.getSession().removeAttribute(__TRACKER);
            if (LOG.isDebugEnabled()) LOG.debug("Value removed: {}", getId());
        }

        public void sessionDidActivate(HttpSessionEvent se)
        {
            LOG.warn("Unexpected session activation");
        }

        /**
         * Timeout callback. It reschedules this tracker if its most recent request
         * happened less than 1000 ms ago. Otherwise it removes the tracker from
         * {@code _rateTrackers}.
         */
        public void expired()
        {
            long now = _trackerTimeoutQ.getNow();
            long last;
            // Read under the same lock that recordRequest() writes under. Without it, this
            // callback could see a stale _next or _timestamps value.
            synchronized (this)
            {
                // _next is the slot for the *next* request, so the latest one sits just before it.
                int latestIndex = _next == 0 ? (_timestamps.length - 1) : (_next - 1);
                last = _timestamps[latestIndex];
            }
            boolean hasRecentRequest = last != 0 && (now - last) < 1000L;

            if (hasRecentRequest)
                reschedule();
            else
                _rateTrackers.remove(_id);
        }

        @Override
        public String toString()
        {
            return "RateTracker/" + _id + "/" + _type;
        }
    }

    /**
     * A tracker that never reports the rate as exceeded. It still records request
     * times, so that expired() can tell whether the tracker has been active recently.
     */
    class FixedRateTracker extends RateTracker
    {
        public FixedRateTracker(String id, int type, int numRecentRequestsTracked)
        {
            super(id, type, numRecentRequestsTracked);
        }

        @Override
        public boolean isRateExceeded(long now)
        {
            // rate limit is never exceeded, but we keep track of the request timestamps
            // so that we know whether there was recent activity on this tracker
            // and whether it should be expired
            recordRequest(now);
            return false;
        }

        @Override
        public String toString()
        {
            return "Fixed" + super.toString();
        }
    }
}