#!/usr/bin/env python
# Copyright 2010 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import BaseHTTPServer
import collections
import errno
import logging
import socket
import SocketServer
import ssl
import sys
import time
import urlparse

import certutils
import daemonserver
import httparchive
import platformsettings
import proxyshaper
import sslproxy


def _HandleSSLCertificateError():
  """Ignore SSL errors and re-raise everything else.

  This function is intended to be called from
  BaseHTTPServer.HTTPServer.handle_error().
  """
  exc_value = sys.exc_info()[1]
  if isinstance(exc_value, ssl.SSLError):
    return
  raise


class HttpProxyError(Exception):
  """Module catch-all error."""
  pass


class HttpProxyServerError(HttpProxyError):
  """Raised for errors like 'Address already in use'."""
  pass


class HttpArchiveHandler(BaseHTTPServer.BaseHTTPRequestHandler):
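  """Serves HTTP requests using the server's custom handlers and archive."""
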
  protocol_version = 'HTTP/1.1'  # override BaseHTTPServer setting

  # Since we do lots of small wfile.write() calls, turn on buffering.
  wbufsize = -1  # override StreamRequestHandler (a base class) setting

  def setup(self):
    """Override StreamRequestHandler method."""
    BaseHTTPServer.BaseHTTPRequestHandler.setup(self)
    if self.server.traffic_shaping_up_bps:
      self.rfile = proxyshaper.RateLimitedFile(
          self.server.get_active_request_count, self.rfile,
          self.server.traffic_shaping_up_bps)
    if self.server.traffic_shaping_down_bps:
      self.wfile = proxyshaper.RateLimitedFile(
          self.server.get_active_request_count, self.wfile,
          self.server.traffic_shaping_down_bps)

  # Make request handler logging match our logging format.
  def log_request(self, code='-', size='-'):
    pass

  def log_error(self, format, *args):  # pylint:disable=redefined-builtin
    logging.error(format, *args)

  def log_message(self, format, *args):  # pylint:disable=redefined-builtin
    logging.info(format, *args)

  def read_request_body(self):
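    """Return the request body, or None if the request has no body."""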
    request_body = None
    length = int(self.headers.get('content-length', 0)) or None
    if length:
      request_body = self.rfile.read(length)
    return request_body

  def get_header_dict(self):
    return dict(self.headers.items())

  def get_archived_http_request(self):
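    """Convert the current request into an ArchivedHttpRequest.

    Returns:
      An httparchive.ArchivedHttpRequest, or None if the request lacks a
      host header.
    """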
    host = self.headers.get('host')
    if host is None:
      logging.error('Request without host header')
      return None

    parsed = urlparse.urlparse(self.path)
    params = ';%s' % parsed.params if parsed.params else ''
    query = '?%s' % parsed.query if parsed.query else ''
    fragment = '#%s' % parsed.fragment if parsed.fragment else ''
    full_path = '%s%s%s%s' % (parsed.path, params, query, fragment)

    StubRequest = collections.namedtuple('StubRequest', ('host', 'full_path'))
    request, response = StubRequest(host, full_path), None

    self.server.log_url(request, response)

    return httparchive.ArchivedHttpRequest(
        self.command,
        host,
        full_path,
        self.read_request_body(),
        self.get_header_dict(),
        self.server.is_ssl)

  def send_archived_http_response(self, response):
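    """Write an ArchivedHttpResponse to the client.

    Honors chunked transfer encoding and, during replay, any configured
    traffic shaping delays.
    """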
    try:
      # We need to set the server name before we start the response.
      is_chunked = response.is_chunked()
      has_content_length = response.get_header('content-length') is not None
      self.server_version = response.get_header('server', 'WebPageReplay')
      self.sys_version = ''

      if response.version == 10:
        self.protocol_version = 'HTTP/1.0'

      # If we don't have chunked encoding and there is no content length,
      # we need to manually compute the content-length.
      if not is_chunked and not has_content_length:
        content_length = sum(len(c) for c in response.response_data)
        response.headers.append(('content-length', str(content_length)))

      is_replay = not self.server.http_archive_fetch.is_record_mode
      if is_replay and self.server.traffic_shaping_delay_ms:
        logging.debug('Using round trip delay: %sms',
                      self.server.traffic_shaping_delay_ms)
        time.sleep(self.server.traffic_shaping_delay_ms / 1000.0)
      if is_replay and self.server.use_delays:
        logging.debug('Using delays (ms): %s', response.delays)
        time.sleep(response.delays['headers'] / 1000.0)
        delays = response.delays['data']
      else:
        delays = [0] * len(response.response_data)
      self.send_response(response.status, response.reason)
      # TODO(mbelshe): This is lame - each write is a packet!
      for header, value in response.headers:
        if header in ('last-modified', 'expires'):
          self.send_header(header, response.update_date(value))
        elif header not in ('date', 'server'):
          self.send_header(header, value)
      self.end_headers()

      for chunk, delay in zip(response.response_data, delays):
        if delay:
          self.wfile.flush()
          time.sleep(delay / 1000.0)
        if is_chunked:
          # Write chunk length (hex) and data (e.g. "A\r\nTESSELATED\r\n").
          self.wfile.write('%x\r\n%s\r\n' % (len(chunk), chunk))
        else:
          self.wfile.write(chunk)
      if is_chunked:
        self.wfile.write('0\r\n\r\n')  # write final, zero-length chunk.
      self.wfile.flush()

      # TODO(mbelshe): This connection close doesn't seem to work.
      if response.version == 10:
        self.close_connection = 1

    except Exception, e:
      logging.error('Error sending response for %s%s: %s',
                    self.headers.get('host', ''), self.path, e)

  def handle_one_request(self):
    """Handle a single HTTP request.

    This method overrides a method from BaseHTTPRequestHandler. When this
    method returns, it must leave self.close_connection in the correct state.
    If this method raises an exception, the state of self.close_connection
    doesn't matter.
    """
    try:
      self.raw_requestline = self.rfile.readline(65537)
      self.do_parse_and_handle_one_request()
    except socket.timeout, e:
      # A read or a write timed out. Discard this connection.
      self.log_error('Request timed out: %r', e)
      self.close_connection = 1
      return
    except ssl.SSLError:
      # There is insufficient information passed up the stack from OpenSSL to
      # determine the true cause of the SSL error. This almost always happens
      # because the client refuses to accept the self-signed certs of
      # WebPageReplay.
      self.close_connection = 1
      return
    except socket.error, e:
      # Connection reset errors happen all the time due to the browser closing
      # without terminating the connection properly. They can be safely
      # ignored.
      if e[0] == errno.ECONNRESET:
        self.close_connection = 1
        return
      raise

  def do_parse_and_handle_one_request(self):
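    """Parse self.raw_requestline and serve a response for the request.

    Also tracks the number of active requests and the total request time.
    """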
    start_time = time.time()
    self.server.num_active_requests += 1
    request = None
    try:
      if len(self.raw_requestline) > 65536:
        self.requestline = ''
        self.request_version = ''
        self.command = ''
        self.send_error(414)
        self.close_connection = 0
        return
      if not self.raw_requestline:
        # This indicates that the socket has been closed by the client.
        self.close_connection = 1
        return

      # self.parse_request() sets self.close_connection; there is no need to
      # set it again afterwards unless custom behavior is desired.
      if not self.parse_request():
        # An error code has been sent, just exit.
        return

      try:
        response = None
        request = self.get_archived_http_request()

        if request is None:
          self.send_error(500)
          return
        response = self.server.custom_handlers.handle(request)
        if not response:
          response = self.server.http_archive_fetch(request)
          if (response and response.status == 200 and
              self.server.allow_generate_304 and
              request.command in set(['GET', 'HEAD']) and
              (request.headers.get('if-modified-since', None) or
               request.headers.get('if-none-match', None))):
            # The WPR archive never gets modified since it is not being
            # recorded, so it is safe to answer with a 304.
            response = httparchive.create_response(
                status=304, headers=response.headers)
        if response:
          self.send_archived_http_response(response)
        else:
          self.send_error(404)
      finally:
        self.wfile.flush()  # Actually send the response if not already done.
    finally:
      request_time_ms = (time.time() - start_time) * 1000.0
      self.server.total_request_time += request_time_ms
      if request:
        if response:
          logging.debug('Served: %s (%dms)', request, request_time_ms)
        else:
          logging.warning('Failed to find response for: %s (%dms)',
                          request, request_time_ms)
      self.server.num_active_requests -= 1

  def send_error(self, status, body=None):
    """Send an error response without unnecessarily closing the connection.

    This overrides the default BaseHTTPRequestHandler.send_error.
    """
    response = httparchive.create_response(status, body=body)
    self.send_archived_http_response(response)


class HttpProxyServer(SocketServer.ThreadingMixIn,
                      BaseHTTPServer.HTTPServer,
                      daemonserver.DaemonServer):
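  """Threaded web proxy that serves requests from an HTTP archive."""
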
  HANDLER = HttpArchiveHandler

  # Increase the request queue size. The default value, 5, is set in
  # SocketServer.TCPServer (the parent of BaseHTTPServer.HTTPServer).
  # Since we're intercepting many domains through this single server,
  # it is quite possible to get more than 5 concurrent requests.
  request_queue_size = 256

  # The number of simultaneous connections that the HTTP server supports. This
  # is primarily limited by system limits such as RLIMIT_NOFILE.
  connection_limit = 500

  # Allow sockets to be reused. See
  # http://svn.python.org/projects/python/trunk/Lib/SocketServer.py for more
  # details.
  allow_reuse_address = True

  # Don't prevent python from exiting when there is thread activity.
  daemon_threads = True

  def __init__(self, http_archive_fetch, custom_handlers, rules,
               host='localhost', port=80, use_delays=False, is_ssl=False,
               protocol='HTTP', allow_generate_304=False,
               down_bandwidth='0', up_bandwidth='0', delay_ms='0'):
    """Start HTTP server.

    Args:
      http_archive_fetch: a callable that takes a request and returns a
          response; its is_record_mode attribute tells record from replay.
      custom_handlers: an object whose handle(request) method returns a
          response for requests that bypass the archive, or None.
      rules: a rule_parser Rules.
      host: a host string (name or IP) for the web proxy.
      port: an integer port (e.g. 80) for the web proxy.
      use_delays: if True, add response data delays during replay.
      is_ssl: True iff proxy is using SSL.
      protocol: a protocol label used in log messages (e.g. 'HTTP').
      allow_generate_304: if True, answer qualifying conditional GET/HEAD
          requests with a generated 304 response.
      down_bandwidth: Download bandwidth.
      up_bandwidth: Upload bandwidth.
          Bandwidths measured in [K|M]{bit/s|Byte/s}. '0' means unlimited.
      delay_ms: Propagation delay in milliseconds. '0' means no delay.
    """
    if platformsettings.SupportsFdLimitControl():
      # BaseHTTPServer opens a new thread and two fds for each connection.
      # Check that the process can open at least 1000 fds.
      soft_limit, hard_limit = platformsettings.GetFdLimit()
      # Add some wiggle room since there are probably fds not associated with
      # connections.
      wiggle_room = 100
      desired_limit = 2 * HttpProxyServer.connection_limit + wiggle_room
      if soft_limit < desired_limit:
        assert desired_limit <= hard_limit, (
            'The hard limit for number of open files per process is %s which '
            'is lower than the desired limit of %s.' %
            (hard_limit, desired_limit))
        platformsettings.AdjustFdLimit(desired_limit, hard_limit)

    try:
      BaseHTTPServer.HTTPServer.__init__(self, (host, port), self.HANDLER)
    except Exception, e:
      raise HttpProxyServerError('Could not start HTTPServer on port %d: %s' %
                                 (port, e))
    self.http_archive_fetch = http_archive_fetch
    self.custom_handlers = custom_handlers
    self.use_delays = use_delays
    self.is_ssl = is_ssl
    self.traffic_shaping_down_bps = proxyshaper.GetBitsPerSecond(down_bandwidth)
    self.traffic_shaping_up_bps = proxyshaper.GetBitsPerSecond(up_bandwidth)
    self.traffic_shaping_delay_ms = int(delay_ms)
    self.num_active_requests = 0
    self.num_active_connections = 0
    self.total_request_time = 0
    self.protocol = protocol
    self.allow_generate_304 = allow_generate_304
    self.log_url = rules.Find('log_url')

    # Note: This message may be scraped. Do not change it.
    logging.warning('%s server started on %s:%d', self.protocol,
                    self.server_address[0], self.server_address[1])

  def cleanup(self):
    try:
      self.shutdown()
      self.server_close()
    except KeyboardInterrupt:
      pass
    logging.info('Stopped %s server. Total time processing requests: %dms',
                 self.protocol, self.total_request_time)

  def get_active_request_count(self):
    return self.num_active_requests

  def get_request(self):
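    """Override TCPServer.get_request to track the connection count."""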
    self.num_active_connections += 1
    if self.num_active_connections >= HttpProxyServer.connection_limit:
      logging.error(
          'Number of active connections (%s) surpasses the '
          'supported limit of %s.',
          self.num_active_connections, HttpProxyServer.connection_limit)
    return BaseHTTPServer.HTTPServer.get_request(self)

  def close_request(self, request):
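    """Override TCPServer.close_request to update the connection count."""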
    BaseHTTPServer.HTTPServer.close_request(self, request)
    self.num_active_connections -= 1


class HttpsProxyServer(HttpProxyServer):
  """SSL server that generates certs for each host."""

  def __init__(self, http_archive_fetch, custom_handlers, rules,
               https_root_ca_cert_path, **kwargs):
    self.ca_cert_path = https_root_ca_cert_path
    self.HANDLER = sslproxy.wrap_handler(HttpArchiveHandler)
    HttpProxyServer.__init__(self, http_archive_fetch, custom_handlers, rules,
                             is_ssl=True, protocol='HTTPS', **kwargs)
    with open(self.ca_cert_path, 'r') as cert_file:
      self._ca_cert_str = cert_file.read()
    self._host_to_cert_map = {}
    self._server_cert_to_cert_map = {}

  def cleanup(self):
    try:
      self.shutdown()
      self.server_close()
    except KeyboardInterrupt:
      pass

  def get_certificate(self, host):
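    """Return a certificate for host, generating and caching it if needed."""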
    if host in self._host_to_cert_map:
      return self._host_to_cert_map[host]

    server_cert = self.http_archive_fetch.http_archive.get_server_cert(host)
    if server_cert in self._server_cert_to_cert_map:
      cert = self._server_cert_to_cert_map[server_cert]
      self._host_to_cert_map[host] = cert
      return cert

    cert = certutils.generate_cert(self._ca_cert_str, server_cert, host)
    self._server_cert_to_cert_map[server_cert] = cert
    self._host_to_cert_map[host] = cert
    return cert

  def handle_error(self, request, client_address):
    _HandleSSLCertificateError()


class SingleCertHttpsProxyServer(HttpProxyServer):
  """SSL server that serves all hosts with a single certificate."""

  def __init__(self, http_archive_fetch, custom_handlers, rules,
               https_root_ca_cert_path, **kwargs):
    HttpProxyServer.__init__(self, http_archive_fetch, custom_handlers, rules,
                             is_ssl=True, protocol='HTTPS', **kwargs)
    self.socket = ssl.wrap_socket(
        self.socket, certfile=https_root_ca_cert_path, server_side=True,
        do_handshake_on_connect=False)
    # Ancestor class, DaemonServer, calls serve_forever() during its __init__.

  def handle_error(self, request, client_address):
    _HandleSSLCertificateError()


class HttpToHttpsProxyServer(HttpProxyServer):
  """Listens for HTTP requests but sends them to the target as HTTPS."""

  def __init__(self, http_archive_fetch, custom_handlers, rules, **kwargs):
    HttpProxyServer.__init__(self, http_archive_fetch, custom_handlers, rules,
                             is_ssl=True, protocol='HTTP-to-HTTPS', **kwargs)

  def handle_error(self, request, client_address):
    _HandleSSLCertificateError()