1# -*- coding: utf-8 -*-
2# Copyright 2014 Google Inc. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Media helper functions and classes for Google Cloud Storage JSON API."""
16
17from __future__ import absolute_import
18
19import copy
20import cStringIO
21import httplib
22import logging
23import socket
24import types
25import urlparse
26
27from apitools.base.py import exceptions as apitools_exceptions
28import httplib2
29from httplib2 import parse_uri
30
31from gslib.cloud_api import BadRequestException
32from gslib.progress_callback import ProgressCallbackWithBackoff
33from gslib.util import SSL_TIMEOUT
34from gslib.util import TRANSFER_BUFFER_SIZE
35
36
class BytesTransferredContainer(object):
  """Mutable holder for a byte count shared across transfer layers.

  When a transfer is resumed or its connection is rebuilt mid-flight, the
  replacement connection class must be told how many bytes have already been
  moved. For uploads that number is only known after querying the server, yet
  the connection class has to be handed to httplib2 before that query can be
  made. Sharing one of these containers by reference with the
  Upload/DownloadCallbackConnection classes resolves that ordering problem.
  """

  def __init__(self):
    # Running count of bytes moved so far; mutated via the property setter.
    self._transferred_so_far = 0

  @property
  def bytes_transferred(self):
    """Number of bytes transferred so far."""
    return self._transferred_so_far

  @bytes_transferred.setter
  def bytes_transferred(self, value):
    self._transferred_so_far = value
58
59
class UploadCallbackConnectionClassFactory(object):
  """Builds httplib2 connection-class overrides for uploads.

  The generated class reports upload progress via callbacks and keeps debug
  logging from dumping the upload payload. It could later also host
  on-the-fly hash digestion of uploaded bytes.
  """

  def __init__(self, bytes_uploaded_container,
               buffer_size=TRANSFER_BUFFER_SIZE,
               total_size=0, progress_callback=None):
    self.bytes_uploaded_container = bytes_uploaded_container
    self.buffer_size = buffer_size
    self.total_size = total_size
    self.progress_callback = progress_callback

  def GetConnectionClass(self):
    """Returns a connection class that overrides send."""
    # Capture factory state under distinct names so the nested class body and
    # the send() closure can reference it unambiguously.
    factory_byte_container = self.bytes_uploaded_container
    factory_buffer_size = self.buffer_size
    factory_total_size = self.total_size
    factory_progress_callback = self.progress_callback

    class UploadCallbackConnection(httplib2.HTTPSConnectionWithTimeout):
      """Connection class override for uploads."""
      bytes_uploaded_container = factory_byte_container
      # apitools asks the server how many bytes remain in a resumable upload
      # right after this class is instantiated; that lets us seed the
      # progress callback exactly once with the already-uploaded count.
      processed_initial_bytes = False
      GCS_JSON_BUFFER_SIZE = factory_buffer_size
      callback_processor = None
      size = factory_total_size

      def __init__(self, *args, **kwargs):
        kwargs['timeout'] = SSL_TIMEOUT
        httplib2.HTTPSConnectionWithTimeout.__init__(self, *args, **kwargs)

      def send(self, data):
        """Overrides HTTPConnection.send."""
        if not self.processed_initial_bytes:
          self.processed_initial_bytes = True
          if factory_progress_callback:
            self.callback_processor = ProgressCallbackWithBackoff(
                factory_total_size, factory_progress_callback)
            self.callback_processor.Progress(
                self.bytes_uploaded_container.bytes_transferred)
        # httplib.HTTPConnection.send accepts either a string or a file-like
        # object (anything with a read() method); normalize to the latter.
        if isinstance(data, basestring):
          source = cStringIO.StringIO(data)
        else:
          source = data
        while True:
          chunk = source.read(self.GCS_JSON_BUFFER_SIZE)
          if not chunk:
            break
          httplib2.HTTPSConnectionWithTimeout.send(self, chunk)
          if self.callback_processor:
            # This is gsutil's only hook for progress callbacks, but metadata
            # bytes (headers, OAuth2 refreshes) are indistinguishable from
            # payload bytes here, so the callback handler is reported slightly
            # more bytes than were actually uploaded.
            #
            # Moving the callbacks into HashingFileUploadWrapper (which only
            # sees payload reads) was considered and rejected: it is removed
            # from where the bytes are actually sent and would unnecessarily
            # multi-purpose that class.
            self.callback_processor.Progress(len(chunk))

    return UploadCallbackConnection
133
134
def WrapUploadHttpRequest(upload_http):
  """Limits upload_http's custom connection_type to PUT and POST requests.

  Requests made with any other HTTP method fall back to the default
  connection class, so their traffic is not run through the upload callback
  machinery.

  NOTE(review): the original docstring claimed the custom connection_type was
  used on PUTs only, with POSTs excluded because they refresh oauth tokens —
  but the code honors POST as well. Confirm which behavior is intended.

  Args:
    upload_http: httplib2.Http instance to wrap
  """
  wrapped_request = upload_http.request

  def NewRequest(uri, method='GET', body=None, headers=None,
                 redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                 connection_type=None):
    use_connection_type = (
        connection_type if method in ('PUT', 'POST') else None)
    return wrapped_request(uri, method=method, body=body, headers=headers,
                           redirections=redirections,
                           connection_type=use_connection_type)

  # Swap in the closure so all future calls route through our filter.
  upload_http.request = NewRequest
157
158
class DownloadCallbackConnectionClassFactory(object):
  """Creates a class that can override an httplib2 connection.

  This is used to provide progress callbacks, disable dumping the download
  payload during debug statements, and provide on-the-fly hash digestion during
  download. On-the-fly digestion is particularly important because httplib2
  will decompress gzipped content on-the-fly, thus this class provides our
  only opportunity to calculate the correct hash for an object that has a
  gzip hash in the cloud.
  """

  def __init__(self, bytes_downloaded_container,
               buffer_size=TRANSFER_BUFFER_SIZE, total_size=0,
               progress_callback=None, digesters=None):
    """Constructs the factory.

    Args:
      bytes_downloaded_container: BytesTransferredContainer tracking bytes
          already downloaded; used to seed progress on resumed downloads.
      buffer_size: Currently unused by the generated class; read sizes are
          validated against the module-level TRANSFER_BUFFER_SIZE instead.
      total_size: Total size of the download, for progress reporting.
      progress_callback: Optional progress callback function.
      digesters: Optional dict of {algorithm name: hash object}, updated with
          every byte read from the response.
    """
    self.buffer_size = buffer_size
    self.total_size = total_size
    self.progress_callback = progress_callback
    self.digesters = digesters
    self.bytes_downloaded_container = bytes_downloaded_container

  def GetConnectionClass(self):
    """Returns a connection class that overrides getresponse."""

    class DownloadCallbackConnection(httplib2.HTTPSConnectionWithTimeout):
      """Connection class override for downloads."""
      # These class attributes are evaluated at class-creation time, inside
      # GetConnectionClass, so `self` here is the factory instance.
      outer_total_size = self.total_size
      outer_digesters = self.digesters
      outer_progress_callback = self.progress_callback
      outer_bytes_downloaded_container = self.bytes_downloaded_container
      processed_initial_bytes = False
      callback_processor = None

      def __init__(self, *args, **kwargs):
        kwargs['timeout'] = SSL_TIMEOUT
        httplib2.HTTPSConnectionWithTimeout.__init__(self, *args, **kwargs)

      def getresponse(self, buffering=False):
        """Wraps an HTTPResponse to perform callbacks and hashing.

        In this function, self is a DownloadCallbackConnection.

        Args:
          buffering: Unused. This function uses a local buffer.

        Returns:
          HTTPResponse object with wrapped read function.
        """
        orig_response = httplib.HTTPConnection.getresponse(self)
        # Only wrap payload-bearing success responses; redirects, errors,
        # etc. keep the stock read() so httplib2 can process them normally.
        if orig_response.status not in (httplib.OK, httplib.PARTIAL_CONTENT):
          return orig_response
        orig_read_func = orig_response.read

        def read(amt=None):  # pylint: disable=invalid-name
          """Overrides HTTPConnection.getresponse.read.

          This function only supports reads of TRANSFER_BUFFER_SIZE or smaller.

          Args:
            amt: Integer n where 0 < n <= TRANSFER_BUFFER_SIZE. This is a
                 keyword argument to match the read function it overrides,
                 but it is required.

          Returns:
            Data read from HTTPConnection.

          Raises:
            BadRequestException: If amt is missing, zero, or larger than
                TRANSFER_BUFFER_SIZE.
          """
          if not amt or amt > TRANSFER_BUFFER_SIZE:
            raise BadRequestException(
                'Invalid HTTP read size %s during download, expected %s.' %
                (amt, TRANSFER_BUFFER_SIZE))
          # (A dead `amt = amt or TRANSFER_BUFFER_SIZE` re-assignment was
          # removed here: any falsy amt already raised above, so the
          # assignment could never change amt.)

          if not self.processed_initial_bytes:
            # Seed the progress callback exactly once with bytes already
            # downloaded (nonzero for resumed downloads).
            self.processed_initial_bytes = True
            if self.outer_progress_callback:
              self.callback_processor = ProgressCallbackWithBackoff(
                  self.outer_total_size, self.outer_progress_callback)
              self.callback_processor.Progress(
                  self.outer_bytes_downloaded_container.bytes_transferred)

          data = orig_read_func(amt)
          read_length = len(data)
          if self.callback_processor:
            self.callback_processor.Progress(read_length)
          if self.outer_digesters:
            # Hash the bytes exactly as read at this layer (see the class
            # docstring for why hashing must happen here).
            for alg in self.outer_digesters:
              self.outer_digesters[alg].update(data)
          return data
        orig_response.read = read

        return orig_response
    return DownloadCallbackConnection
251
252
def WrapDownloadHttpRequest(download_http):
  """Overrides download request functions for an httplib2.Http object.

  Args:
    download_http: httplib2.Http.object to wrap / override.

  Returns:
    Wrapped / overridden httplib2.Http object.
  """

  # httplib2 has a bug https://code.google.com/p/httplib2/issues/detail?id=305
  # where custom connection_type is not respected after redirects.  This
  # function is copied from httplib2 and overrides the request function so that
  # the connection_type is properly passed through.
  # pylint: disable=protected-access,g-inconsistent-quotes,unused-variable
  # pylint: disable=g-equals-none,g-doc-return-or-yield
  # pylint: disable=g-short-docstring-punctuation,g-doc-args
  # pylint: disable=too-many-statements
  def OverrideRequest(self, conn, host, absolute_uri, request_uri, method,
                      body, headers, redirections, cachekey):
    """Do the actual request using the connection object.

    Also follow one level of redirects if necessary.
    """

    # Select an applicable stored authorization for this host/uri (ordered
    # by depth) and apply it to the outgoing request.
    auths = ([(auth.depth(request_uri), auth) for auth in self.authorizations
              if auth.inscope(host, request_uri)])
    auth = auths and sorted(auths)[0][1] or None
    if auth:
      auth.request(method, request_uri, headers, body)

    (response, content) = self._conn_request(conn, request_uri, method, body,
                                             headers)

    if auth:
      # If the auth scheme wants a second round trip (e.g. a stale digest
      # nonce), re-issue the request once with refreshed credentials.
      if auth.response(response, body):
        auth.request(method, request_uri, headers, body)
        (response, content) = self._conn_request(conn, request_uri, method,
                                                 body, headers)
        response._stale_digest = 1

    if response.status == 401:
      # Try each authorization produced from the 401 challenge; remember the
      # first one that succeeds for future requests.
      for authorization in self._auth_from_challenge(
          host, request_uri, headers, response, content):
        authorization.request(method, request_uri, headers, body)
        (response, content) = self._conn_request(conn, request_uri, method,
                                                 body, headers)
        if response.status != 401:
          self.authorizations.append(authorization)
          authorization.response(response, body)
          break

    if (self.follow_all_redirects or (method in ["GET", "HEAD"])
        or response.status == 303):
      if self.follow_redirects and response.status in [300, 301, 302,
                                                       303, 307]:
        # Pick out the location header and basically start from the beginning
        # remembering first to strip the ETag header and decrement our 'depth'
        if redirections:
          if not response.has_key('location') and response.status != 300:
            raise httplib2.RedirectMissingLocation(
                "Redirected but the response is missing a Location: header.",
                response, content)
          # Fix-up relative redirects (which violate an RFC 2616 MUST)
          if response.has_key('location'):
            location = response['location']
            (scheme, authority, path, query, fragment) = parse_uri(location)
            if authority == None:
              response['location'] = urlparse.urljoin(absolute_uri, location)
          if response.status == 301 and method in ["GET", "HEAD"]:
            # Permanent redirects for safe methods are cacheable; record the
            # final URL so future requests can skip the redirect.
            response['-x-permanent-redirect-url'] = response['location']
            if not response.has_key('content-location'):
              response['content-location'] = absolute_uri
            httplib2._updateCache(headers, response, content, self.cache,
                                  cachekey)
          # Drop conditional headers before re-requesting: the redirect
          # target is a different resource.
          if headers.has_key('if-none-match'):
            del headers['if-none-match']
          if headers.has_key('if-modified-since'):
            del headers['if-modified-since']
          if ('authorization' in headers and
              not self.forward_authorization_headers):
            del headers['authorization']
          if response.has_key('location'):
            location = response['location']
            old_response = copy.deepcopy(response)
            if not old_response.has_key('content-location'):
              old_response['content-location'] = absolute_uri
            redirect_method = method
            if response.status in [302, 303]:
              redirect_method = "GET"
              body = None
            # Key difference from stock httplib2: pass conn.__class__ as
            # connection_type so the redirected request keeps using our
            # custom download connection class.
            (response, content) = self.request(
                location, redirect_method, body=body, headers=headers,
                redirections=redirections-1,
                connection_type=conn.__class__)
            response.previous = old_response
        else:
          raise httplib2.RedirectLimit(
              "Redirected more times than redirection_limit allows.",
              response, content)
      elif response.status in [200, 203] and method in ["GET", "HEAD"]:
        # Don't cache 206's since we aren't going to handle byte range
        # requests
        if not response.has_key('content-location'):
          response['content-location'] = absolute_uri
        httplib2._updateCache(headers, response, content, self.cache,
                              cachekey)

    return (response, content)

  # Wrap download_http so we do not use our custom connection_type
  # on POSTS, which are used to refresh oauth tokens. We don't want to
  # process the data received in those requests.
  request_orig = download_http.request
  def NewRequest(uri, method='GET', body=None, headers=None,
                 redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                 connection_type=None):
    """Closure that drops the custom connection_type for POST requests."""
    if method == 'POST':
      return request_orig(uri, method=method, body=body,
                          headers=headers, redirections=redirections,
                          connection_type=None)
    else:
      return request_orig(uri, method=method, body=body,
                          headers=headers, redirections=redirections,
                          connection_type=connection_type)

  # Replace the request methods with our own closures.
  download_http._request = types.MethodType(OverrideRequest, download_http)
  download_http.request = NewRequest

  return download_http
384
385
class HttpWithNoRetries(httplib2.Http):
  """httplib2.Http variant that does not retry.

  httplib2 automatically retries requests according to httplib2.RETRIES, but
  in certain cases httplib2 ignores the RETRIES value and forces a retry.
  Because httplib2 does not handle the case where the underlying request body
  is a stream, a retry may cause a non-idempotent write as the stream is
  partially consumed and not reset before the retry occurs.

  Here we override _conn_request to disable retries unequivocally, so that
  uploads may be retried at higher layers that properly handle stream request
  bodies.
  """

  def _conn_request(self, conn, request_uri, method, body, headers):  # pylint: disable=too-many-statements
    """Sends the request on conn and returns (httplib2.Response, content).

    Copied from httplib2.Http._conn_request with the retry loop removed:
    network errors propagate to the caller (closing the connection where
    appropriate) instead of being retried against a possibly
    partially-consumed stream body.
    """

    try:
      if hasattr(conn, 'sock') and conn.sock is None:
        conn.connect()
      conn.request(method, request_uri, body, headers)
    except socket.timeout:
      raise
    except socket.gaierror:
      # DNS resolution failed; surface httplib2's standard error type.
      conn.close()
      raise httplib2.ServerNotFoundError(
          'Unable to find the server at %s' % conn.host)
    except httplib2.ssl_SSLError:
      conn.close()
      raise
    except socket.error, e:
      err = 0
      if hasattr(e, 'args'):
        err = getattr(e, 'args')[0]
      else:
        err = e.errno
      if err == httplib2.errno.ECONNREFUSED:  # Connection refused
        raise
      # NOTE(review): other socket errors fall through without re-raising,
      # matching the httplib2 code this was copied from.
    except httplib.HTTPException:
      conn.close()
      raise
    try:
      response = conn.getresponse()
    except (socket.error, httplib.HTTPException):
      conn.close()
      raise
    else:
      content = ''
      if method == 'HEAD':
        # HEAD responses carry no body; just close the connection.
        conn.close()
      else:
        content = response.read()
      response = httplib2.Response(response)
      if method != 'HEAD':
        # pylint: disable=protected-access
        content = httplib2._decompressContent(response, content)
    return (response, content)
442
443
class HttpWithDownloadStream(httplib2.Http):
  """httplib2.Http variant that only pushes bytes through a stream.

  httplib2 handles media by storing entire chunks of responses in memory, which
  is undesirable particularly when multiple instances are used during
  multi-threaded/multi-process copy. This class copies and then overrides some
  httplib2 functions to use a streaming copy approach that uses small memory
  buffers.

  Also disables httplib2 retries (for reasons stated in the HttpWithNoRetries
  class doc).
  """

  def __init__(self, *args, **kwds):
    # The destination stream must be assigned via the `stream` property
    # before a payload-bearing response is read; _conn_request raises
    # otherwise.
    self._stream = None
    self._logger = logging.getLogger()
    super(HttpWithDownloadStream, self).__init__(*args, **kwds)

  @property
  def stream(self):
    """File-like object that response payload bytes are written to."""
    return self._stream

  @stream.setter
  def stream(self, value):
    self._stream = value

  def _conn_request(self, conn, request_uri, method, body, headers):  # pylint: disable=too-many-statements
    """Sends the request on conn, streaming 200/206 payloads to self.stream.

    Copied from httplib2.Http._conn_request with retries removed and the
    response body of successful GETs written to self.stream in
    TRANSFER_BUFFER_SIZE chunks instead of being returned as content.
    """
    try:
      if hasattr(conn, 'sock') and conn.sock is None:
        conn.connect()
      conn.request(method, request_uri, body, headers)
    except socket.timeout:
      raise
    except socket.gaierror:
      # DNS resolution failed; surface httplib2's standard error type.
      conn.close()
      raise httplib2.ServerNotFoundError(
          'Unable to find the server at %s' % conn.host)
    except httplib2.ssl_SSLError:
      conn.close()
      raise
    except socket.error, e:
      err = 0
      if hasattr(e, 'args'):
        err = getattr(e, 'args')[0]
      else:
        err = e.errno
      if err == httplib2.errno.ECONNREFUSED:  # Connection refused
        raise
      # NOTE(review): other socket errors fall through without re-raising,
      # matching the httplib2 code this was copied from.
    except httplib.HTTPException:
      # Just because the server closed the connection doesn't apparently mean
      # that the server didn't send a response.
      conn.close()
      raise
    try:
      response = conn.getresponse()
    except (socket.error, httplib.HTTPException):
      conn.close()
      raise
    else:
      content = ''
      if method == 'HEAD':
        # HEAD responses carry no body; just close the connection.
        conn.close()
        response = httplib2.Response(response)
      else:
        if response.status in (httplib.OK, httplib.PARTIAL_CONTENT):
          content_length = None
          if hasattr(response, 'msg'):
            content_length = response.getheader('content-length')
          http_stream = response
          bytes_read = 0
          # Stream the body to self.stream in small chunks so the full
          # payload is never held in memory.
          while True:
            new_data = http_stream.read(TRANSFER_BUFFER_SIZE)
            if new_data:
              if self.stream is None:
                raise apitools_exceptions.InvalidUserInputError(
                    'Cannot exercise HttpWithDownloadStream with no stream')
              self.stream.write(new_data)
              bytes_read += len(new_data)
            else:
              break

          if (content_length is not None and
              long(bytes_read) != long(content_length)):
            # The input stream terminated before we were able to read the
            # entire contents, possibly due to a network condition. Set
            # content-length to indicate how many bytes we actually read.
            self._logger.log(
                logging.DEBUG, 'Only got %s bytes out of content-length %s '
                'for request URI %s. Resetting content-length to match '
                'bytes read.', bytes_read, content_length, request_uri)
            response.msg['content-length'] = str(bytes_read)
          response = httplib2.Response(response)
        else:
          # We fall back to the current httplib2 behavior if we're
          # not processing bytes (eg it's a redirect).
          content = response.read()
          response = httplib2.Response(response)
          # pylint: disable=protected-access
          content = httplib2._decompressContent(response, content)
    return (response, content)
544