#!/usr/bin/env python
# Copyright 2012 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Retrieve web resources over http."""
17
18import copy
19import datetime
20import httplib
21import logging
22import random
23import ssl
24import StringIO
25
26import httparchive
27import platformsettings
28import script_injector
29
30
31# PIL isn't always available, but we still want to be able to run without
32# the image scrambling functionality in this case.
33try:
34  import Image
35except ImportError:
36  Image = None
37
38TIMER = platformsettings.timer
39

class HttpClientException(Exception):
  """Base class for all exceptions in httpclient."""


def _InjectScripts(response, injector):
  """Injects the script from |injector| immediately after <head> or <html>.

  Copies |response| if it is modified.

  Args:
    response: an ArchivedHttpResponse
    injector: a function that generates a JavaScript string from the
      recording time (e.g. "Math.random = function(){...}")
  Returns:
    an ArchivedHttpResponse
  """
  if isinstance(response, tuple):
    logging.warning('tuple response: %s', response)
  content_type = response.get_header('content-type')
  if content_type and content_type.startswith('text/html'):
    text_chunks = response.get_data_as_chunks()
    text_chunks, just_injected = script_injector.InjectScript(
        text_chunks, 'text/html', injector(response.request_time))
    if just_injected:
      response = copy.deepcopy(response)
      response.set_data_from_chunks(text_chunks)
  return response

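# A minimal injector sketch (illustrative only; the real injector callable is
# supplied by the caller):
#
#   def make_injector(script):
#     def injector(recording_time):
#       # |recording_time| is the datetime when the response was recorded.
#       return script
#     return injector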

def _ScrambleImages(response):
  """If the |response| is an image, attempt to scramble it.

  Copies |response| if it is modified.

  Args:
    response: an ArchivedHttpResponse
  Returns:
    an ArchivedHttpResponse
  """

  assert Image, '--scramble_images requires the PIL module to be installed.'

  content_type = response.get_header('content-type')
  if content_type and content_type.startswith('image/'):
    try:
      # response_data chunks hold the raw body bytes.
      image_data = response.response_data[0]
      im = Image.open(StringIO.StringIO(image_data))

      pixel_data = list(im.getdata())
      random.shuffle(pixel_data)

      scrambled_image = im.copy()
      scrambled_image.putdata(pixel_data)

      output_image_io = StringIO.StringIO()
      scrambled_image.save(output_image_io, im.format)
      output_image_data = output_image_io.getvalue()

      response = copy.deepcopy(response)
      response.set_data(output_image_data)
    except Exception:
      # If the image cannot be decoded (e.g. the body is still compressed),
      # serve the original response unmodified.
      pass

  return response


class DetailedHTTPResponse(httplib.HTTPResponse):
  """Preserve details relevant to replaying responses.

  WARNING: This code uses attributes and methods of HTTPResponse
  that are not part of the public interface.
  """

  def read_chunks(self):
    """Return the response body content and timing data.

    The returned chunks have the chunk size and CRLFs stripped off.
    If the response was compressed, the returned data is still compressed.

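    For reference, each chunk arrives on the wire framed as
    "<hex-size>[;chunk-extensions]<CRLF><data><CRLF>" and the body ends with
    a zero-size chunk; that framing is stripped from the returned chunks.
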
    Returns:
      (chunks, delays)
        chunks:
          [response_body]                  # non-chunked responses
          [chunk_1, chunk_2, ...]          # chunked responses
        delays:
          [0]                              # non-chunked responses
          [chunk_1_first_byte_delay, ...]  # chunked responses

      The delay for the first body item should be recorded by the caller.
    """
    chunks = []
    delays = []
    if not self.chunked:
      chunks.append(self.read())
      delays.append(0)
    else:
      start = TIMER()
      try:
        while True:
          line = self.fp.readline()
          chunk_size = self._read_chunk_size(line)
          if chunk_size is None:
            raise httplib.IncompleteRead(''.join(chunks))
          if chunk_size == 0:
            break
          delays.append(TIMER() - start)
          chunks.append(self._safe_read(chunk_size))
          self._safe_read(2)  # skip the CRLF at the end of the chunk
          start = TIMER()

        # Ignore any trailers.
        while True:
          line = self.fp.readline()
          if not line or line == '\r\n':
            break
      finally:
        self.close()
    return chunks, delays

  @classmethod
  def _read_chunk_size(cls, line):
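    """Parse the hexadecimal chunk size from a chunk header line.

    Returns the size as an int, or None if the line is malformed.
    For example (doctest-style illustration):
      >>> DetailedHTTPResponse._read_chunk_size('1a;name=val')
      26
    """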
    chunk_extensions_pos = line.find(';')
    if chunk_extensions_pos != -1:
      line = line[:chunk_extensions_pos]  # strip chunk-extensions
    try:
      chunk_size = int(line, 16)
    except ValueError:
      return None
    return chunk_size


class DetailedHTTPConnection(httplib.HTTPConnection):
  """Preserve details relevant to replaying connections."""
  response_class = DetailedHTTPResponse


class DetailedHTTPSResponse(DetailedHTTPResponse):
  """Preserve details relevant to replaying SSL responses."""


class DetailedHTTPSConnection(httplib.HTTPSConnection):
  """Preserve details relevant to replaying SSL connections."""
  response_class = DetailedHTTPSResponse

  def __init__(self, host, port):
    # Opt out of certificate verification (PEP 476): connections may be made
    # directly to an IP address (see _get_connection), so host name
    # verification would fail.
    # https://www.python.org/dev/peps/pep-0476/#opting-out
    if hasattr(ssl, '_create_unverified_context'):
      httplib.HTTPSConnection.__init__(
          self, host=host, port=port, context=ssl._create_unverified_context())
    else:
      httplib.HTTPSConnection.__init__(self, host=host, port=port)


class RealHttpFetch(object):

  def __init__(self, real_dns_lookup):
    """Initialize RealHttpFetch.

    Args:
      real_dns_lookup: a function that resolves a host to an IP, or None.
        If given, RealHttpFetch resolves the host name to an IP before
        making the request.
    """
    self._real_dns_lookup = real_dns_lookup

  @staticmethod
  def _GetHeaderNameValue(header):
    """Parse the header line and return a name/value tuple.

    Args:
      header: a string for a header such as "Content-Length: 314".
    Returns:
      A tuple (header_name, header_value) on success, or None if the header
      is not in the expected format. header_name is lowercased.
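
    For example (doctest-style illustration):
      >>> RealHttpFetch._GetHeaderNameValue('Content-Length: 314')
      ('content-length', '314')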
    """
    i = header.find(':')
    if i > 0:
      return (header[:i].lower(), header[i+1:].strip())
    return None

  @staticmethod
  def _ToTuples(headers):
    """Parse headers and return them as a list of tuples.

    This method takes HttpResponse.msg.headers as input and converts it
    to a list of (header_name, header_value) tuples.
    HttpResponse.msg.headers is a list of strings where each string
    represents either a header or a continuation line of a header.
    1. a normal header consists of two parts separated by a colon:
       "header_name:header_value..."
    2. a continuation line is a string starting with whitespace:
       "[whitespace]continued_header_value..."
    A malformed header or an unexpected continuation line is ignored.

    Avoid using response.getheaders() directly, because it cannot properly
    handle multiple headers with the same name. Instead, parse
    response.msg.headers with this method to get all headers.

    Args:
      headers: an instance of HttpResponse.msg.headers.
    Returns:
      A list of tuples which looks like:
      [(header_name, header_value), (header_name2, header_value2)...]
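
    For example, ['Set-Cookie: a=1', 'Set-Cookie: b=2'] yields
    [('set-cookie', 'a=1'), ('set-cookie', 'b=2')], preserving both values;
    a continuation line is folded into the value of the preceding header.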
    """
    all_headers = []
    for line in headers:
      if line[0] in '\t ':
        if not all_headers:
          logging.warning(
              'Unexpected response header continuation line [%s]', line)
          continue
        name, value = all_headers.pop()
        value += '\n ' + line.strip()
      else:
        name_value = RealHttpFetch._GetHeaderNameValue(line)
        if not name_value:
          logging.warning(
              'Response header in wrong format [%s]', line)
          continue
        name, value = name_value  # pylint: disable=unpacking-non-sequence
      all_headers.append((name, value))
    return all_headers

  @staticmethod
  def _get_request_host_port(request):
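    """Split request.host into a (host, port) pair.

    For example, 'www.example.com:8080' yields ('www.example.com', 8080),
    and a bare host name yields (host, None). Note that this simple split
    does not handle IPv6 address literals.
    """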
    host_parts = request.host.split(':')
    host = host_parts[0]
    port = int(host_parts[1]) if len(host_parts) == 2 else None
    return host, port

  @staticmethod
  def _get_system_proxy(is_ssl):
    return platformsettings.get_system_proxy(is_ssl)

  def _get_connection(self, request_host, request_port, is_ssl):
    """Return a detailed connection object for the host/port pair.

    If a system proxy is defined (see platformsettings.py), it will be used.

    Args:
      request_host: a host string (e.g. "www.example.com").
      request_port: a port integer (e.g. 8080) or None (for the default port).
      is_ssl: True if an HTTPS connection is needed.
    Returns:
      A DetailedHTTPSConnection or DetailedHTTPConnection instance,
      or None if the host name could not be resolved.
    """
    connection_host = request_host
    connection_port = request_port
    system_proxy = self._get_system_proxy(is_ssl)
    if system_proxy:
      connection_host = system_proxy.host
      connection_port = system_proxy.port

    # Use an IP address because WPR may override DNS settings.
    if self._real_dns_lookup:
      connection_ip = self._real_dns_lookup(connection_host)
      if not connection_ip:
        logging.critical(
            'Unable to find IP for host name: %s', connection_host)
        return None
      connection_host = connection_ip

    if is_ssl:
      connection = DetailedHTTPSConnection(connection_host, connection_port)
      if system_proxy:
        connection.set_tunnel(request_host, request_port)
    else:
      connection = DetailedHTTPConnection(connection_host, connection_port)
    return connection

  def __call__(self, request):
    """Fetch an HTTP request.

    Args:
      request: an ArchivedHttpRequest
    Returns:
      an ArchivedHttpResponse, or None if the fetch failed after retries.
    """
    logging.debug('RealHttpFetch: %s %s', request.host, request.full_path)
    request_host, request_port = self._get_request_host_port(request)
    retries = 3
    while True:
      try:
        request_time = datetime.datetime.utcnow()
        connection = self._get_connection(
            request_host, request_port, request.is_ssl)
        connect_start = TIMER()
        connection.connect()
        connect_delay = int((TIMER() - connect_start) * 1000)  # milliseconds
        start = TIMER()
        connection.request(
            request.command,
            request.full_path,
            request.request_body,
            request.headers)
        response = connection.getresponse()
        headers_delay = int((TIMER() - start) * 1000)  # milliseconds

        chunks, chunk_delays = response.read_chunks()
        delays = {
            'connect': connect_delay,
            'headers': headers_delay,
            'data': chunk_delays
            }
        archived_http_response = httparchive.ArchivedHttpResponse(
            response.version,
            response.status,
            response.reason,
            RealHttpFetch._ToTuples(response.msg.headers),
            chunks,
            delays,
            request_time)
        return archived_http_response
      except Exception, e:
        if retries:
          retries -= 1
          logging.warning('Retrying fetch %s: %s', request, repr(e))
          continue
        logging.critical('Could not fetch %s: %s', request, repr(e))
        return None


class RecordHttpArchiveFetch(object):
  """Make real HTTP fetches and save responses in the given HttpArchive."""

  def __init__(self, http_archive, injector):
    """Initialize RecordHttpArchiveFetch.

    Args:
      http_archive: an instance of HttpArchive
      injector: a script injector to inject scripts into all HTML pages
    """
    self.http_archive = http_archive
    # Do not resolve the host name to an IP when recording, to avoid SSL3
    # handshake failures.
    # See https://github.com/chromium/web-page-replay/issues/73 for details.
    self.real_http_fetch = RealHttpFetch(real_dns_lookup=None)
    self.injector = injector

  def __call__(self, request):
    """Fetch the request and return the response.

    Args:
      request: an ArchivedHttpRequest.
    Returns:
      an ArchivedHttpResponse
    """
    # If the request is already in the archive, return the archived response.
    if request in self.http_archive:
      logging.debug('Repeated request found: %s', request)
      response = self.http_archive[request]
    else:
      response = self.real_http_fetch(request)
      if response is None:
        return None
      self.http_archive[request] = response
    if self.injector:
      response = _InjectScripts(response, self.injector)
    logging.debug('Recorded: %s', request)
    return response


class ReplayHttpArchiveFetch(object):
  """Serve responses from the given HttpArchive."""

  def __init__(self, http_archive, real_dns_lookup, injector,
               use_diff_on_unknown_requests=False,
               use_closest_match=False, scramble_images=False):
    """Initialize ReplayHttpArchiveFetch.

    Args:
      http_archive: an instance of HttpArchive
      real_dns_lookup: a function that resolves a host to an IP.
      injector: a script injector to inject scripts into all HTML pages
      use_diff_on_unknown_requests: If True, log unknown requests
        with a diff to requests that look similar.
      use_closest_match: If True, in replay mode, serve the closest match
        in the archive instead of giving a 404.
      scramble_images: If True, shuffle the pixels of served images.
    """
    self.http_archive = http_archive
    self.injector = injector
    self.use_diff_on_unknown_requests = use_diff_on_unknown_requests
    self.use_closest_match = use_closest_match
    self.scramble_images = scramble_images
    self.real_http_fetch = RealHttpFetch(real_dns_lookup)

  def __call__(self, request):
    """Fetch the request and return the response.

    Args:
      request: an instance of an ArchivedHttpRequest.
    Returns:
      An ArchivedHttpResponse instance (if found) or None.
    """
    # Requests addressed directly to the local host are fetched live rather
    # than replayed from the archive.
    if request.host.startswith('127.0.0.1:'):
      return self.real_http_fetch(request)

    response = self.http_archive.get(request)

    if self.use_closest_match and not response:
      closest_request = self.http_archive.find_closest_request(
          request, use_path=True)
      if closest_request:
        response = self.http_archive.get(closest_request)
        if response:
          logging.info('Request not found: %s\nUsing closest match: %s',
                       request, closest_request)

    if not response:
      reason = str(request)
      if self.use_diff_on_unknown_requests:
        diff = self.http_archive.diff(request)
        if diff:
          reason += (
              "\nNearest request diff "
              "('-' for archived request, '+' for current request):\n%s" % diff)
      logging.warning('Could not replay: %s', reason)
    else:
      if self.injector:
        response = _InjectScripts(response, self.injector)
      if self.scramble_images:
        response = _ScrambleImages(response)
    return response


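# Typical use of the class below (illustrative sketch): construct one fetch
# object per archive, then toggle between modes at runtime.
#
#   fetch = ControllableHttpArchiveFetch(
#       http_archive, real_dns_lookup, injector,
#       use_diff_on_unknown_requests=False, use_record_mode=True,
#       use_closest_match=False, scramble_images=False)
#   response = fetch(request)  # fetched live and recorded into the archive
#   fetch.SetReplayMode()
#   response = fetch(request)  # now served from the archive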
class ControllableHttpArchiveFetch(object):
  """Controllable fetch function that can swap between record and replay."""

  def __init__(self, http_archive, real_dns_lookup,
               injector, use_diff_on_unknown_requests,
               use_record_mode, use_closest_match, scramble_images):
    """Initialize HttpArchiveFetch.

    Args:
      http_archive: an instance of HttpArchive
      real_dns_lookup: a function that resolves a host to an IP.
      injector: a function to inject scripts into all HTML pages. It takes
        the recording time as a datetime.datetime object.
      use_diff_on_unknown_requests: If True, log unknown requests
        with a diff to requests that look similar.
      use_record_mode: If True, start the server in record mode.
      use_closest_match: If True, in replay mode, serve the closest match
        in the archive instead of giving a 404.
      scramble_images: If True, shuffle the pixels of served images.
    """
    self.http_archive = http_archive
    self.record_fetch = RecordHttpArchiveFetch(http_archive, injector)
    self.replay_fetch = ReplayHttpArchiveFetch(
        http_archive, real_dns_lookup, injector,
        use_diff_on_unknown_requests, use_closest_match, scramble_images)
    if use_record_mode:
      self.SetRecordMode()
    else:
      self.SetReplayMode()

  def SetRecordMode(self):
    self.fetch = self.record_fetch
    self.is_record_mode = True

  def SetReplayMode(self):
    self.fetch = self.replay_fetch
    self.is_record_mode = False

  def __call__(self, *args, **kwargs):
    """Forward calls to Replay/Record fetch functions depending on mode."""
    return self.fetch(*args, **kwargs)
