1#!/usr/bin/env python
2# Copyright 2012 Google Inc. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Retrieve web resources over http."""
17
18import copy
19import httplib
20import logging
21import random
22import ssl
23import StringIO
24
25import httparchive
26import platformsettings
27import script_injector
28
29
30# PIL isn't always available, but we still want to be able to run without
31# the image scrambling functionality in this case.
32try:
33  import Image
34except ImportError:
35  Image = None
36
37TIMER = platformsettings.timer
38
39
40class HttpClientException(Exception):
41  """Base class for all exceptions in httpclient."""
42  pass
43
44
45def _InjectScripts(response, inject_script):
46  """Injects |inject_script| immediately after <head> or <html>.
47
48  Copies |response| if it is modified.
49
50  Args:
51    response: an ArchivedHttpResponse
52    inject_script: JavaScript string (e.g. "Math.random = function(){...}")
53  Returns:
54    an ArchivedHttpResponse
55  """
56  if type(response) == tuple:
57    logging.warn('tuple response: %s', response)
58  content_type = response.get_header('content-type')
59  if content_type and content_type.startswith('text/html'):
60    text = response.get_data_as_text()
61    text, already_injected = script_injector.InjectScript(
62        text, 'text/html', inject_script)
63    if not already_injected:
64      response = copy.deepcopy(response)
65      response.set_data(text)
66  return response
67
68
69def _ScrambleImages(response):
70  """If the |response| is an image, attempt to scramble it.
71
72  Copies |response| if it is modified.
73
74  Args:
75    response: an ArchivedHttpResponse
76  Returns:
77    an ArchivedHttpResponse
78  """
79
80  assert Image, '--scramble_images requires the PIL module to be installed.'
81
82  content_type = response.get_header('content-type')
83  if content_type and content_type.startswith('image/'):
84    try:
85      image_data = response.response_data[0]
86      image_data.decode(encoding='base64')
87      im = Image.open(StringIO.StringIO(image_data))
88
89      pixel_data = list(im.getdata())
90      random.shuffle(pixel_data)
91
92      scrambled_image = im.copy()
93      scrambled_image.putdata(pixel_data)
94
95      output_image_io = StringIO.StringIO()
96      scrambled_image.save(output_image_io, im.format)
97      output_image_data = output_image_io.getvalue()
98      output_image_data.encode(encoding='base64')
99
100      response = copy.deepcopy(response)
101      response.set_data(output_image_data)
102    except Exception:
103      pass
104
105  return response
106
107
108class DetailedHTTPResponse(httplib.HTTPResponse):
109  """Preserve details relevant to replaying responses.
110
111  WARNING: This code uses attributes and methods of HTTPResponse
112  that are not part of the public interface.
113  """
114
115  def read_chunks(self):
116    """Return the response body content and timing data.
117
118    The returned chunks have the chunk size and CRLFs stripped off.
119    If the response was compressed, the returned data is still compressed.
120
121    Returns:
122      (chunks, delays)
123        chunks:
124          [response_body]                  # non-chunked responses
125          [chunk_1, chunk_2, ...]          # chunked responses
126        delays:
127          [0]                              # non-chunked responses
128          [chunk_1_first_byte_delay, ...]  # chunked responses
129
130      The delay for the first body item should be recorded by the caller.
131    """
132    buf = []
133    chunks = []
134    delays = []
135    if not self.chunked:
136      chunks.append(self.read())
137      delays.append(0)
138    else:
139      start = TIMER()
140      try:
141        while True:
142          line = self.fp.readline()
143          chunk_size = self._read_chunk_size(line)
144          if chunk_size is None:
145            raise httplib.IncompleteRead(''.join(chunks))
146          if chunk_size == 0:
147            break
148          delays.append(TIMER() - start)
149          chunks.append(self._safe_read(chunk_size))
150          self._safe_read(2)  # skip the CRLF at the end of the chunk
151          start = TIMER()
152
153        # Ignore any trailers.
154        while True:
155          line = self.fp.readline()
156          if not line or line == '\r\n':
157            break
158      finally:
159        self.close()
160    return chunks, delays
161
162  @classmethod
163  def _read_chunk_size(cls, line):
164    chunk_extensions_pos = line.find(';')
165    if chunk_extensions_pos != -1:
166      line = line[:chunk_extensions_pos]  # strip chunk-extensions
167    try:
168      chunk_size = int(line, 16)
169    except ValueError:
170      return None
171    return chunk_size
172
173
174class DetailedHTTPConnection(httplib.HTTPConnection):
175  """Preserve details relevant to replaying connections."""
176  response_class = DetailedHTTPResponse
177
178
179class DetailedHTTPSResponse(DetailedHTTPResponse):
180  """Preserve details relevant to replaying SSL responses."""
181  pass
182
183
184class DetailedHTTPSConnection(httplib.HTTPSConnection):
185  """Preserve details relevant to replaying SSL connections."""
186  response_class = DetailedHTTPSResponse
187
188  def __init__(self, host, port):
189    # https://www.python.org/dev/peps/pep-0476/#opting-out
190    if hasattr(ssl, '_create_unverified_context'):
191      httplib.HTTPSConnection.__init__(
192          self, host=host, port=port, context=ssl._create_unverified_context())
193    else:
194      httplib.HTTPSConnection.__init__(self, host=host, port=port)
195
196
197class RealHttpFetch(object):
198
199  def __init__(self, real_dns_lookup):
200    """Initialize RealHttpFetch.
201
202    Args:
203      real_dns_lookup: a function that resolves a host to an IP.
204    """
205    self._real_dns_lookup = real_dns_lookup
206
207  @staticmethod
208  def _GetHeaderNameValue(header):
209    """Parse the header line and return a name/value tuple.
210
211    Args:
212      header: a string for a header such as "Content-Length: 314".
213    Returns:
214      A tuple (header_name, header_value) on success or None if the header
215      is not in expected format. header_name is in lowercase.
216    """
217    i = header.find(':')
218    if i > 0:
219      return (header[:i].lower(), header[i+1:].strip())
220    return None
221
222  @staticmethod
223  def _ToTuples(headers):
224    """Parse headers and save them to a list of tuples.
225
226    This method takes HttpResponse.msg.headers as input and convert it
227    to a list of (header_name, header_value) tuples.
228    HttpResponse.msg.headers is a list of strings where each string
229    represents either a header or a continuation line of a header.
230    1. a normal header consists of two parts which are separated by colon :
231       "header_name:header_value..."
232    2. a continuation line is a string starting with whitespace
233       "[whitespace]continued_header_value..."
234    If a header is not in good shape or an unexpected continuation line is
235    seen, it will be ignored.
236
237    Should avoid using response.getheaders() directly
238    because response.getheaders() can't handle multiple headers
239    with the same name properly. Instead, parse the
240    response.msg.headers using this method to get all headers.
241
242    Args:
243      headers: an instance of HttpResponse.msg.headers.
244    Returns:
245      A list of tuples which looks like:
246      [(header_name, header_value), (header_name2, header_value2)...]
247    """
248    all_headers = []
249    for line in headers:
250      if line[0] in '\t ':
251        if not all_headers:
252          logging.warning(
253              'Unexpected response header continuation line [%s]', line)
254          continue
255        name, value = all_headers.pop()
256        value += '\n ' + line.strip()
257      else:
258        name_value = RealHttpFetch._GetHeaderNameValue(line)
259        if not name_value:
260          logging.warning(
261              'Response header in wrong format [%s]', line)
262          continue
263        name, value = name_value  # pylint: disable=unpacking-non-sequence
264      all_headers.append((name, value))
265    return all_headers
266
267  @staticmethod
268  def _get_request_host_port(request):
269    host_parts = request.host.split(':')
270    host = host_parts[0]
271    port = int(host_parts[1]) if len(host_parts) == 2 else None
272    return host, port
273
274  @staticmethod
275  def _get_system_proxy(is_ssl):
276    return platformsettings.get_system_proxy(is_ssl)
277
278  def _get_connection(self, request_host, request_port, is_ssl):
279    """Return a detailed connection object for host/port pair.
280
281    If a system proxy is defined (see platformsettings.py), it will be used.
282
283    Args:
284      request_host: a host string (e.g. "www.example.com").
285      request_port: a port integer (e.g. 8080) or None (for the default port).
286      is_ssl: True if HTTPS connection is needed.
287    Returns:
288      A DetailedHTTPSConnection or DetailedHTTPConnection instance.
289    """
290    connection_host = request_host
291    connection_port = request_port
292    system_proxy = self._get_system_proxy(is_ssl)
293    if system_proxy:
294      connection_host = system_proxy.host
295      connection_port = system_proxy.port
296
297    # Use an IP address because WPR may override DNS settings.
298    connection_ip = self._real_dns_lookup(connection_host)
299    if not connection_ip:
300      logging.critical('Unable to find host ip for name: %s', connection_host)
301      return None
302
303    if is_ssl:
304      connection = DetailedHTTPSConnection(connection_ip, connection_port)
305      if system_proxy:
306        connection.set_tunnel(request_host, request_port)
307    else:
308      connection = DetailedHTTPConnection(connection_ip, connection_port)
309    return connection
310
311  def __call__(self, request):
312    """Fetch an HTTP request.
313
314    Args:
315      request: an ArchivedHttpRequest
316    Returns:
317      an ArchivedHttpResponse
318    """
319    logging.debug('RealHttpFetch: %s %s', request.host, request.full_path)
320    request_host, request_port = self._get_request_host_port(request)
321    retries = 3
322    while True:
323      try:
324        connection = self._get_connection(
325            request_host, request_port, request.is_ssl)
326        connect_start = TIMER()
327        connection.connect()
328        connect_delay = int((TIMER() - connect_start) * 1000)
329        start = TIMER()
330        connection.request(
331            request.command,
332            request.full_path,
333            request.request_body,
334            request.headers)
335        response = connection.getresponse()
336        headers_delay = int((TIMER() - start) * 1000)
337
338        chunks, chunk_delays = response.read_chunks()
339        delays = {
340            'connect': connect_delay,
341            'headers': headers_delay,
342            'data': chunk_delays
343            }
344        archived_http_response = httparchive.ArchivedHttpResponse(
345            response.version,
346            response.status,
347            response.reason,
348            RealHttpFetch._ToTuples(response.msg.headers),
349            chunks,
350            delays)
351        return archived_http_response
352      except Exception, e:
353        if retries:
354          retries -= 1
355          logging.warning('Retrying fetch %s: %s', request, repr(e))
356          continue
357        logging.critical('Could not fetch %s: %s', request, repr(e))
358        return None
359
360
361class RecordHttpArchiveFetch(object):
362  """Make real HTTP fetches and save responses in the given HttpArchive."""
363
364  def __init__(self, http_archive, real_dns_lookup, inject_script):
365    """Initialize RecordHttpArchiveFetch.
366
367    Args:
368      http_archive: an instance of a HttpArchive
369      real_dns_lookup: a function that resolves a host to an IP.
370      inject_script: script string to inject in all pages
371    """
372    self.http_archive = http_archive
373    self.real_http_fetch = RealHttpFetch(real_dns_lookup)
374    self.inject_script = inject_script
375
376  def __call__(self, request):
377    """Fetch the request and return the response.
378
379    Args:
380      request: an ArchivedHttpRequest.
381    Returns:
382      an ArchivedHttpResponse
383    """
384    # If request is already in the archive, return the archived response.
385    if request in self.http_archive:
386      logging.debug('Repeated request found: %s', request)
387      response = self.http_archive[request]
388    else:
389      response = self.real_http_fetch(request)
390      if response is None:
391        return None
392      self.http_archive[request] = response
393    if self.inject_script:
394      response = _InjectScripts(response, self.inject_script)
395    logging.debug('Recorded: %s', request)
396    return response
397
398
399class ReplayHttpArchiveFetch(object):
400  """Serve responses from the given HttpArchive."""
401
402  def __init__(self, http_archive, real_dns_lookup, inject_script,
403               use_diff_on_unknown_requests=False,
404               use_closest_match=False, scramble_images=False):
405    """Initialize ReplayHttpArchiveFetch.
406
407    Args:
408      http_archive: an instance of a HttpArchive
409      real_dns_lookup: a function that resolves a host to an IP.
410      inject_script: script string to inject in all pages
411      use_diff_on_unknown_requests: If True, log unknown requests
412        with a diff to requests that look similar.
413      use_closest_match: If True, on replay mode, serve the closest match
414        in the archive instead of giving a 404.
415    """
416    self.http_archive = http_archive
417    self.inject_script = inject_script
418    self.use_diff_on_unknown_requests = use_diff_on_unknown_requests
419    self.use_closest_match = use_closest_match
420    self.scramble_images = scramble_images
421    self.real_http_fetch = RealHttpFetch(real_dns_lookup)
422
423  def __call__(self, request):
424    """Fetch the request and return the response.
425
426    Args:
427      request: an instance of an ArchivedHttpRequest.
428    Returns:
429      Instance of ArchivedHttpResponse (if found) or None
430    """
431    if request.host.startswith('127.0.0.1:'):
432      return self.real_http_fetch(request)
433
434    response = self.http_archive.get(request)
435
436    if self.use_closest_match and not response:
437      closest_request = self.http_archive.find_closest_request(
438          request, use_path=True)
439      if closest_request:
440        response = self.http_archive.get(closest_request)
441        if response:
442          logging.info('Request not found: %s\nUsing closest match: %s',
443                       request, closest_request)
444
445    if not response:
446      reason = str(request)
447      if self.use_diff_on_unknown_requests:
448        diff = self.http_archive.diff(request)
449        if diff:
450          reason += (
451              "\nNearest request diff "
452              "('-' for archived request, '+' for current request):\n%s" % diff)
453      logging.warning('Could not replay: %s', reason)
454    else:
455      if self.inject_script:
456        response = _InjectScripts(response, self.inject_script)
457      if self.scramble_images:
458        response = _ScrambleImages(response)
459    return response
460
461
462class ControllableHttpArchiveFetch(object):
463  """Controllable fetch function that can swap between record and replay."""
464
465  def __init__(self, http_archive, real_dns_lookup,
466               inject_script, use_diff_on_unknown_requests,
467               use_record_mode, use_closest_match, scramble_images):
468    """Initialize HttpArchiveFetch.
469
470    Args:
471      http_archive: an instance of a HttpArchive
472      real_dns_lookup: a function that resolves a host to an IP.
473      inject_script: script string to inject in all pages.
474      use_diff_on_unknown_requests: If True, log unknown requests
475        with a diff to requests that look similar.
476      use_record_mode: If True, start in server in record mode.
477      use_closest_match: If True, on replay mode, serve the closest match
478        in the archive instead of giving a 404.
479    """
480    self.http_archive = http_archive
481    self.record_fetch = RecordHttpArchiveFetch(
482        http_archive, real_dns_lookup, inject_script)
483    self.replay_fetch = ReplayHttpArchiveFetch(
484        http_archive, real_dns_lookup, inject_script,
485        use_diff_on_unknown_requests, use_closest_match, scramble_images)
486    if use_record_mode:
487      self.SetRecordMode()
488    else:
489      self.SetReplayMode()
490
491  def SetRecordMode(self):
492    self.fetch = self.record_fetch
493    self.is_record_mode = True
494
495  def SetReplayMode(self):
496    self.fetch = self.replay_fetch
497    self.is_record_mode = False
498
499  def __call__(self, *args, **kwargs):
500    """Forward calls to Replay/Record fetch functions depending on mode."""
501    return self.fetch(*args, **kwargs)
502