1# Copyright 2016 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import logging
16
17import httplib2
18import six
19from six.moves import http_client
20
21from oauth2client._helpers import _to_bytes
22
23
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# Attributes a request body must expose to be treated as a file-like
# stream / buffer (checked in wrap_http_for_auth before sending a request).
_STREAM_PROPERTIES = ('read', 'seek', 'tell')

# Response status codes that make wrap_http_for_auth refresh the credentials
# and retry the request.
# Google Data client libraries may need to set this to [401, 403].
REFRESH_STATUS_CODES = (http_client.UNAUTHORIZED,)
30
31
class MemoryCache(object):
    """In-memory cache object for httplib2, backed by a plain dict.

    Entries are kept in the ``cache`` attribute of the instance and are
    never written anywhere else.
    """

    def __init__(self):
        # Maps cache keys to their stored values.
        self.cache = dict()

    def get(self, key):
        """Return the value stored under ``key``, or None if absent."""
        try:
            return self.cache[key]
        except KeyError:
            return None

    def set(self, key, value):
        """Store ``value`` under ``key``, replacing any previous entry."""
        self.cache[key] = value

    def delete(self, key):
        """Remove ``key`` from the cache; a missing key is ignored."""
        try:
            del self.cache[key]
        except KeyError:
            pass
46
47
def get_cached_http():
    """Return the shared, caching HTTP object of this module.

    Useful for callers that repeatedly fetch the same URI, such as
    oauth2client.client.verify_id_token() retrieving certs.

    Returns:
        httplib2.Http, the module-level HTTP object that was created
        with a MemoryCache.
    """
    shared_http = _CACHED_HTTP
    return shared_http
59
60
def get_http_object():
    """Build a fresh HTTP object.

    Returns:
        httplib2.Http, a newly constructed instance created with
        no arguments.
    """
    http = httplib2.Http()
    return http
68
69
70def _initialize_headers(headers):
71    """Creates a copy of the headers.
72
73    Args:
74        headers: dict, request headers to copy.
75
76    Returns:
77        dict, the copied headers or a new dictionary if the headers
78        were None.
79    """
80    return {} if headers is None else dict(headers)
81
82
83def _apply_user_agent(headers, user_agent):
84    """Adds a user-agent to the headers.
85
86    Args:
87        headers: dict, request headers to add / modify user
88                 agent within.
89        user_agent: str, the user agent to add.
90
91    Returns:
92        dict, the original headers passed in, but modified if the
93        user agent is not None.
94    """
95    if user_agent is not None:
96        if 'user-agent' in headers:
97            headers['user-agent'] = (user_agent + ' ' + headers['user-agent'])
98        else:
99            headers['user-agent'] = user_agent
100
101    return headers
102
103
def clean_headers(headers):
    """Forces header keys and values to be byte strings, i.e. not unicode.

    The httplib module just concatenates the header keys and values in a
    way that may make the message header a unicode string, which, if it is
    then concatenated to a binary request body, may result in a unicode
    decode error.

    Args:
        headers: dict, A dictionary of headers.

    Returns:
        dict, a new dictionary in which every key and value has been
        passed through ``str`` (unless already binary) and ``_to_bytes``.

    Raises:
        NonAsciiHeaderError: if converting a key or value raises a
            UnicodeEncodeError.
    """
    cleaned = {}
    try:
        for key, value in six.iteritems(headers):
            # Anything that is not already binary is stringified first.
            key = key if isinstance(key, six.binary_type) else str(key)
            value = (value if isinstance(value, six.binary_type)
                     else str(value))
            cleaned[_to_bytes(key)] = _to_bytes(value)
    except UnicodeEncodeError:
        # NOTE(review): imported locally, presumably to avoid a circular
        # import with oauth2client.client -- confirm.
        from oauth2client.client import NonAsciiHeaderError
        raise NonAsciiHeaderError(key, ': ', value)
    return cleaned
129
130
def wrap_http_for_auth(credentials, http):
    """Installs an auth-aware ``request`` method on an HTTP object.

    The replacement ``http.request`` applies the credentials (and their
    user agent) to a copy of the request headers before sending. When the
    response status is one of ``REFRESH_STATUS_CODES`` (typically a 401),
    the credentials are refreshed and the original request is sent again.

    Args:
        credentials: Credentials, the credentials used to identify
                     the authenticated user.
        http: httplib2.Http, an http object to be used to make
              auth requests. Its ``request`` attribute is replaced
              in place.
    """
    unwrapped_request = http.request

    # The closure that will replace 'httplib2.Http.request'.
    def new_request(uri, method='GET', body=None, headers=None,
                    redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                    connection_type=None):
        # Without a token there is nothing to send yet; obtain one first.
        if not credentials.access_token:
            _LOGGER.info('Attempting refresh to obtain '
                         'initial access_token')
            credentials._refresh(unwrapped_request)

        # Work on a copy of the headers and add the appropriate
        # Authorization header (plus the user agent) to it.
        headers = _initialize_headers(headers)
        credentials.apply(headers)
        _apply_user_agent(headers, credentials.user_agent)

        # For a file-like body, remember the current position so the
        # stream can be rewound before a retry.
        is_stream = all(getattr(body, attr, None)
                        for attr in _STREAM_PROPERTIES)
        rewind_to = body.tell() if is_stream else None

        def send():
            # Headers are cleaned anew for every transmission because
            # ``credentials.apply`` may have modified them in between.
            return unwrapped_request(uri, method, body,
                                     clean_headers(headers),
                                     redirections, connection_type)

        resp, content = send()

        # A stored token may expire between the time it is retrieved and
        # the time the request is made, so we may need to try twice.
        max_refresh_attempts = 2
        attempt = 1
        while (attempt <= max_refresh_attempts and
               resp.status in REFRESH_STATUS_CODES):
            _LOGGER.info('Refreshing due to a %s (attempt %s/%s)',
                         resp.status, attempt, max_refresh_attempts)
            credentials._refresh(unwrapped_request)
            credentials.apply(headers)
            if rewind_to is not None:
                body.seek(rewind_to)

            resp, content = send()
            attempt += 1

        return resp, content

    # Replace the request method with our own closure.
    http.request = new_request

    # Expose the credentials as an attribute of the request method.
    http.request.credentials = credentials
196
197
def wrap_http_for_jwt_access(credentials, http):
    """Prepares an HTTP object's request method for JWT access.

    First wraps ``http`` with ``wrap_http_for_auth`` and then replaces
    ``http.request`` once more with a closure that picks one of two paths
    per request:

    * If the credentials were created with an 'aud' (audience) claim, the
      token is refreshed preemptively when it is missing or expired and
      the request is sent through the auth-wrapped request method.
    * Otherwise a one-time token is created with the request URI (query
      string removed) as the audience, and the request is sent through
      the original, unwrapped request method.

    As with ``wrap_http_for_auth``, the credentials are exposed as the
    ``credentials`` attribute of the installed ``http.request``.

    Args:
        credentials: _JWTAccessCredentials, the credentials used to identify
                     a service account that uses JWT access tokens.
        http: httplib2.Http, an http object to be used to make
              auth requests. Its ``request`` attribute is replaced
              in place.
    """
    orig_request_method = http.request
    wrap_http_for_auth(credentials, http)
    # The new value of ``http.request`` set by ``wrap_http_for_auth``.
    authenticated_request_method = http.request

    # The closure that will replace 'httplib2.Http.request'.
    def new_request(uri, method='GET', body=None, headers=None,
                    redirections=httplib2.DEFAULT_MAX_REDIRECTS,
                    connection_type=None):
        if 'aud' in credentials._kwargs:
            # Preemptively refresh token, this is not done for OAuth2
            if (credentials.access_token is None or
                    credentials.access_token_expired):
                credentials.refresh(None)
            return authenticated_request_method(uri, method, body,
                                                headers, redirections,
                                                connection_type)
        else:
            # If we don't have an 'aud' (audience) claim,
            # create a 1-time token with the uri root as the audience
            headers = _initialize_headers(headers)
            _apply_user_agent(headers, credentials.user_agent)
            # Everything before the first '?' is used as the audience.
            uri_root = uri.split('?', 1)[0]
            token, unused_expiry = credentials._create_token({'aud': uri_root})

            headers['Authorization'] = 'Bearer ' + token
            return orig_request_method(uri, method, body,
                                       clean_headers(headers),
                                       redirections, connection_type)

    # Replace the request method with our own closure.
    http.request = new_request

    # ``wrap_http_for_auth`` attached the credentials to the request method
    # it installed, but that method has just been replaced. Re-attach them
    # so ``http.request.credentials`` stays available to callers.
    http.request.credentials = credentials
243
244
# Shared module-level HTTP object backed by a MemoryCache; this is the
# instance handed out by get_cached_http().
_CACHED_HTTP = httplib2.Http(MemoryCache())
246