1#!/usr/bin/env python
2"""Base class for api services."""
3
4import base64
5import contextlib
6import datetime
7import logging
8import pprint
9
10
11from protorpc import message_types
12from protorpc import messages
13import six
14from six.moves import http_client
15from six.moves import urllib
16
17
18from apitools.base.py import credentials_lib
19from apitools.base.py import encoding
20from apitools.base.py import exceptions
21from apitools.base.py import http_wrapper
22from apitools.base.py import util
23
# Names re-exported by ``from ... import *``.
__all__ = [
    'ApiMethodInfo', 'ApiUploadInfo', 'BaseApiClient', 'BaseApiService',
    'NormalizeApiEndpoint',
]

# TODO(craigcitro): Remove this once we quiet the spurious logging in
# oauth2client (or drop oauth2client).
# Until then, only ERROR-and-above records from 'oauth2client.util' are kept.
logging.getLogger('oauth2client.util').setLevel(logging.ERROR)

# URL length threshold used by BaseApiService when deciding whether a GET
# request should be re-issued as a POST with an x-http-method-override header.
_MAX_URL_LENGTH = 2048
37
38
class ApiUploadInfo(messages.Message):

    """Describes the media-upload support of a single API method.

    Fields:
      accept: (repeated) MIME media ranges that this method accepts as
          uploaded media.
      max_size: (integer) Largest permitted media upload, such as 3MB or
          1TB (converted to an integer).
      resumable_path: Path to use for resumable uploads.
      resumable_multipart: (boolean) Whether the resumable endpoint
          accepts multipart uploads.
      simple_path: Path to use for simple uploads.
      simple_multipart: (boolean) Whether the simple endpoint accepts
          multipart uploads.
    """

    # Accepted media and size limit.
    accept = messages.StringField(1, repeated=True)
    max_size = messages.IntegerField(2)
    # Resumable upload endpoint.
    resumable_path = messages.StringField(3)
    resumable_multipart = messages.BooleanField(4)
    # Simple upload endpoint.
    simple_path = messages.StringField(5)
    simple_multipart = messages.BooleanField(6)
61
62
class ApiMethodInfo(messages.Message):

    """Static configuration describing one API method.

    Fields are strings unless stated otherwise.

    Fields:
      relative_path: Path of this method, relative to the client's URL.
      method_id: Identifier of this method.
      http_method: HTTP verb used when calling this method.
      path_params: (repeated) Names of parameters substituted into the
          relative path.
      query_params: (repeated) Names of parameters sent in the query string.
      ordered_params: (repeated) Ordered list of this method's parameters.
      description: Human-readable description of this method.
      request_type_name: Name of the request message type.
      response_type_name: Name of the response message type.
      request_field: When empty, the request has no body. Otherwise either
          the name of the request field whose message is sent as the body,
          or the REQUEST_IS_BODY sentinel meaning the whole request message
          is the body.
      upload_config: (ApiUploadInfo) Upload configuration supported by this
          method.
      supports_download: (boolean) If True, this method supports
          downloading the request via the `alt=media` query parameter.
    """

    # Routing and identification.
    relative_path = messages.StringField(1)
    method_id = messages.StringField(2)
    http_method = messages.StringField(3)
    # Parameter layout.
    path_params = messages.StringField(4, repeated=True)
    query_params = messages.StringField(5, repeated=True)
    ordered_params = messages.StringField(6, repeated=True)
    description = messages.StringField(7)
    # Message types and body handling.
    request_type_name = messages.StringField(8)
    response_type_name = messages.StringField(9)
    request_field = messages.StringField(10, default='')
    # Media transfer support.
    upload_config = messages.MessageField(ApiUploadInfo, 11)
    supports_download = messages.BooleanField(12, default=False)
# Sentinel for ApiMethodInfo.request_field: the entire request message (not a
# single field of it) is serialized as the HTTP body.
REQUEST_IS_BODY = '<request>'
103
104
105def _LoadClass(name, messages_module):
106    if name.startswith('message_types.'):
107        _, _, classname = name.partition('.')
108        return getattr(message_types, classname)
109    elif '.' not in name:
110        return getattr(messages_module, name)
111    else:
112        raise exceptions.GeneratedClientError('Unknown class %s' % name)
113
114
115def _RequireClassAttrs(obj, attrs):
116    for attr in attrs:
117        attr_name = attr.upper()
118        if not hasattr(obj, '%s' % attr_name) or not getattr(obj, attr_name):
119            msg = 'No %s specified for object of class %s.' % (
120                attr_name, type(obj).__name__)
121            raise exceptions.GeneratedClientError(msg)
122
123
def NormalizeApiEndpoint(api_endpoint):
    """Return ``api_endpoint`` with a ``'/'`` appended if it lacks one.

    An endpoint that already ends in ``'/'`` is returned unchanged.
    """
    return api_endpoint if api_endpoint.endswith('/') else api_endpoint + '/'
128
129
130def _urljoin(base, url):  # pylint: disable=invalid-name
131    """Custom urljoin replacement supporting : before / in url."""
132    # In general, it's unsafe to simply join base and url. However, for
133    # the case of discovery documents, we know:
134    #  * base will never contain params, query, or fragment
135    #  * url will never contain a scheme or net_loc.
136    # In general, this means we can safely join on /; we just need to
137    # ensure we end up with precisely one / joining base and url. The
138    # exception here is the case of media uploads, where url will be an
139    # absolute url.
140    if url.startswith('http://') or url.startswith('https://'):
141        return urllib.parse.urljoin(base, url)
142    new_base = base if base.endswith('/') else base + '/'
143    new_url = url[1:] if url.startswith('/') else url
144    return new_base + new_url
145
146
147class _UrlBuilder(object):
148
149    """Convenient container for url data."""
150
151    def __init__(self, base_url, relative_path=None, query_params=None):
152        components = urllib.parse.urlsplit(_urljoin(
153            base_url, relative_path or ''))
154        if components.fragment:
155            raise exceptions.ConfigurationValueError(
156                'Unexpected url fragment: %s' % components.fragment)
157        self.query_params = urllib.parse.parse_qs(components.query or '')
158        if query_params is not None:
159            self.query_params.update(query_params)
160        self.__scheme = components.scheme
161        self.__netloc = components.netloc
162        self.relative_path = components.path or ''
163
164    @classmethod
165    def FromUrl(cls, url):
166        urlparts = urllib.parse.urlsplit(url)
167        query_params = urllib.parse.parse_qs(urlparts.query)
168        base_url = urllib.parse.urlunsplit((
169            urlparts.scheme, urlparts.netloc, '', None, None))
170        relative_path = urlparts.path or ''
171        return cls(
172            base_url, relative_path=relative_path, query_params=query_params)
173
174    @property
175    def base_url(self):
176        return urllib.parse.urlunsplit(
177            (self.__scheme, self.__netloc, '', '', ''))
178
179    @base_url.setter
180    def base_url(self, value):
181        components = urllib.parse.urlsplit(value)
182        if components.path or components.query or components.fragment:
183            raise exceptions.ConfigurationValueError(
184                'Invalid base url: %s' % value)
185        self.__scheme = components.scheme
186        self.__netloc = components.netloc
187
188    @property
189    def query(self):
190        # TODO(craigcitro): In the case that some of the query params are
191        # non-ASCII, we may silently fail to encode correctly. We should
192        # figure out who is responsible for owning the object -> str
193        # conversion.
194        return urllib.parse.urlencode(self.query_params, doseq=True)
195
196    @property
197    def url(self):
198        if '{' in self.relative_path or '}' in self.relative_path:
199            raise exceptions.ConfigurationValueError(
200                'Cannot create url with relative path %s' % self.relative_path)
201        return urllib.parse.urlunsplit((
202            self.__scheme, self.__netloc, self.relative_path, self.query, ''))
203
204
205def _SkipGetCredentials():
206    """Hook for skipping credentials. For internal use."""
207    return False
208
209
class BaseApiClient(object):

    """Base class for client libraries.

    Generated subclasses fill in the class-level constants below;
    ``__init__`` refuses to run unless ``_PACKAGE``, ``_SCOPES`` and
    ``MESSAGES_MODULE`` are all set to truthy values.
    """
    MESSAGES_MODULE = None

    _API_KEY = ''
    _CLIENT_ID = ''
    _CLIENT_SECRET = ''
    _PACKAGE = ''
    _SCOPES = []
    _USER_AGENT = ''

    def __init__(self, url, credentials=None, get_credentials=True, http=None,
                 model=None, log_request=False, log_response=False,
                 num_retries=5, max_retry_wait=60, credentials_args=None,
                 default_global_params=None, additional_http_headers=None):
        """Initialize the client.

        Args:
          url: Root URL of the API; a trailing '/' is added if missing.
          credentials: Optional credentials object. When given, its
              ``authorize`` method wraps the http object.
          get_credentials: If true and no credentials were given, fetch
              them via credentials_lib.GetCredentials (unless the internal
              _SkipGetCredentials hook says to skip).
          http: Optional http object; defaults to http_wrapper.GetHttp().
          model: Ignored; accepted for backwards compatibility.
          log_request: If true, log outgoing requests via logging.info.
          log_response: If true, log responses via logging.info.
          num_retries: Non-negative integer, passed to
              http_wrapper.MakeRequest as ``retries``.
          max_retry_wait: Positive integer, passed to
              http_wrapper.MakeRequest as ``max_retry_wait``.
          credentials_args: Optional dict of extra keyword arguments for
              credentials_lib.GetCredentials.
          default_global_params: Optional ``params_type`` instance used as
              the default global (standard query) parameters.
          additional_http_headers: Optional dict of headers merged into
              every outgoing http request.

        Raises:
          exceptions.GeneratedClientError: if _PACKAGE, _SCOPES or
              MESSAGES_MODULE is missing or empty.
          exceptions.InvalidDataError: if num_retries is negative or
              max_retry_wait is not positive.
        """
        _RequireClassAttrs(self, ('_package', '_scopes', 'messages_module'))
        if default_global_params is not None:
            util.Typecheck(default_global_params, self.params_type)
        self.__default_global_params = default_global_params
        self.log_request = log_request
        self.log_response = log_response
        self.__num_retries = 5
        self.__max_retry_wait = 60
        # We let the @property machinery below do our validation.
        self.num_retries = num_retries
        self.max_retry_wait = max_retry_wait
        self._credentials = credentials
        get_credentials = get_credentials and not _SkipGetCredentials()
        if get_credentials and not credentials:
            credentials_args = credentials_args or {}
            self._SetCredentials(**credentials_args)
        self._url = NormalizeApiEndpoint(url)
        self._http = http or http_wrapper.GetHttp()
        # Note that "no credentials" is totally possible.
        if self._credentials is not None:
            self._http = self._credentials.authorize(self._http)
        # TODO(craigcitro): Remove this field when we switch to proto2.
        self.__include_fields = None

        self.additional_http_headers = additional_http_headers or {}

        # TODO(craigcitro): Finish deprecating these fields.
        _ = model

        self.__response_type_model = 'proto'

    def _SetCredentials(self, **kwds):
        """Fetch credentials, and set them for this client.

        Note that we can't simply return credentials, since creating them
        may involve side-effecting self.

        Args:
          **kwds: Additional keyword arguments are passed on to GetCredentials.

        Returns:
          None. Sets self._credentials.
        """
        args = {
            'api_key': self._API_KEY,
            'client': self,
            'client_id': self._CLIENT_ID,
            'client_secret': self._CLIENT_SECRET,
            'package_name': self._PACKAGE,
            'scopes': self._SCOPES,
            'user_agent': self._USER_AGENT,
        }
        args.update(kwds)
        # TODO(craigcitro): It's a bit dangerous to pass this
        # still-half-initialized self into this method, but we might need
        # to set attributes on it associated with our credentials.
        # Consider another way around this (maybe a callback?) and whether
        # or not it's worth it.
        self._credentials = credentials_lib.GetCredentials(**args)

    @classmethod
    def ClientInfo(cls):
        """Return the OAuth client id/secret, sorted scopes and user agent."""
        return {
            'client_id': cls._CLIENT_ID,
            'client_secret': cls._CLIENT_SECRET,
            'scope': ' '.join(sorted(util.NormalizeScopes(cls._SCOPES))),
            'user_agent': cls._USER_AGENT,
        }

    @property
    def base_model_class(self):
        return None

    @property
    def http(self):
        return self._http

    @property
    def url(self):
        return self._url

    @classmethod
    def GetScopes(cls):
        return cls._SCOPES

    @property
    def params_type(self):
        """The StandardQueryParameters class from MESSAGES_MODULE."""
        return _LoadClass('StandardQueryParameters', self.MESSAGES_MODULE)

    @property
    def user_agent(self):
        return self._USER_AGENT

    @property
    def _default_global_params(self):
        """Default global params, created lazily on first access."""
        if self.__default_global_params is None:
            self.__default_global_params = self.params_type()
        return self.__default_global_params

    def AddGlobalParam(self, name, value):
        """Set ``name`` to ``value`` on the default global params."""
        params = self._default_global_params
        setattr(params, name, value)

    @property
    def global_params(self):
        """A copy of the default global params (safe to mutate)."""
        return encoding.CopyProtoMessage(self._default_global_params)

    @contextlib.contextmanager
    def IncludeFields(self, include_fields):
        """Within this context, pass ``include_fields`` to serialization.

        The previous value is restored on exit, even if the body raises.
        """
        old_include_fields = self.__include_fields
        self.__include_fields = include_fields
        try:
            yield
        finally:
            self.__include_fields = old_include_fields

    @property
    def response_type_model(self):
        return self.__response_type_model

    @contextlib.contextmanager
    def JsonResponseModel(self):
        """In this context, return raw JSON instead of proto.

        The previous response model is restored on exit, even if the body
        raises.
        """
        old_model = self.response_type_model
        self.__response_type_model = 'json'
        try:
            yield
        finally:
            self.__response_type_model = old_model

    @property
    def num_retries(self):
        return self.__num_retries

    @num_retries.setter
    def num_retries(self, value):
        util.Typecheck(value, six.integer_types)
        if value < 0:
            raise exceptions.InvalidDataError(
                'Cannot have negative value for num_retries')
        self.__num_retries = value

    @property
    def max_retry_wait(self):
        return self.__max_retry_wait

    @max_retry_wait.setter
    def max_retry_wait(self, value):
        util.Typecheck(value, six.integer_types)
        if value <= 0:
            raise exceptions.InvalidDataError(
                'max_retry_wait must be a positive integer')
        self.__max_retry_wait = value

    @contextlib.contextmanager
    def WithRetries(self, num_retries):
        """Temporarily use ``num_retries``; restored on exit, even on error."""
        old_num_retries = self.num_retries
        # Validation happens here, before entering the body.
        self.num_retries = num_retries
        try:
            yield
        finally:
            self.num_retries = old_num_retries

    def ProcessRequest(self, method_config, request):
        """Hook for pre-processing of requests."""
        if self.log_request:
            logging.info(
                'Calling method %s with %s: %s', method_config.method_id,
                method_config.request_type_name, request)
        return request

    def ProcessHttpRequest(self, http_request):
        """Hook for pre-processing of http requests.

        Merges ``additional_http_headers`` into the request headers and,
        when ``log_request`` is set, logs method, url, headers and body.
        """
        http_request.headers.update(self.additional_http_headers)
        if self.log_request:
            logging.info('Making http %s to %s',
                         http_request.http_method, http_request.url)
            logging.info('Headers: %s', pprint.pformat(http_request.headers))
            if http_request.body:
                # TODO(craigcitro): Make this safe to print in the case of
                # non-printable body characters.
                logging.info('Body:\n%s',
                             http_request.loggable_body or http_request.body)
            else:
                logging.info('Body: (none)')
        return http_request

    def ProcessResponse(self, method_config, response):
        """Hook for post-processing of responses (logs if log_response)."""
        if self.log_response:
            logging.info('Response of type %s: %s',
                         method_config.response_type_name, response)
        return response

    # TODO(craigcitro): Decide where these two functions should live.
    def SerializeMessage(self, message):
        """Serialize ``message`` to JSON, honoring IncludeFields."""
        return encoding.MessageToJson(
            message, include_fields=self.__include_fields)

    def DeserializeMessage(self, response_type, data):
        """Deserialize the given data as method_config.response_type.

        Raises:
          exceptions.InvalidDataFromServerError: if ``data`` cannot be
              decoded as ``response_type``.
        """
        try:
            message = encoding.JsonToMessage(response_type, data)
        except (exceptions.InvalidDataFromServerError,
                messages.ValidationError) as e:
            raise exceptions.InvalidDataFromServerError(
                'Error decoding response "%s" as type %s: %s' % (
                    data, response_type.__name__, e))
        return message

    def FinalizeTransferUrl(self, url):
        """Modify the url for a given transfer, based on auth and version.

        Adds the ``key`` query parameter when a global API key is set.
        """
        url_builder = _UrlBuilder.FromUrl(url)
        # global_params returns a fresh copy; fetch it only once.
        key = self.global_params.key
        if key:
            url_builder.query_params['key'] = key
        return url_builder.url
434
435
class BaseApiService(object):

    """Base class for generated API services.

    A service turns request messages into HTTP requests using the method
    configurations in ``_method_configs`` (and, for uploads,
    ``_upload_configs``) and sends them through its client. Both maps start
    empty here; presumably generated subclasses populate them.
    """

    def __init__(self, client):
        self.__client = client
        self._method_configs = {}
        self._upload_configs = {}

    @property
    def _client(self):
        return self.__client

    @property
    def client(self):
        return self.__client

    def GetMethodConfig(self, method):
        """Return the method config for ``method`` (KeyError if unknown)."""
        return self._method_configs[method]

    def GetUploadConfig(self, method):
        """Return the upload config for ``method``, or None if absent."""
        return self._upload_configs.get(method)

    def GetRequestType(self, method):
        """Return the request message class for ``method``."""
        method_config = self.GetMethodConfig(method)
        return getattr(self.client.MESSAGES_MODULE,
                       method_config.request_type_name)

    def GetResponseType(self, method):
        """Return the response message class for ``method``."""
        method_config = self.GetMethodConfig(method)
        return getattr(self.client.MESSAGES_MODULE,
                       method_config.response_type_name)

    def __CombineGlobalParams(self, global_params, default_params):
        """Combine the given params with the defaults.

        Args:
          global_params: None or an instance of the client's params_type.
          default_params: params_type instance supplying fallback values.

        Returns:
          A new params_type instance. Each field takes the value from
          global_params when that is not None, otherwise from
          default_params; values of None, [] or () are left unset.
        """
        util.Typecheck(global_params, (type(None), self.__client.params_type))
        result = self.__client.params_type()
        global_params = global_params or self.__client.params_type()
        for field in result.all_fields():
            value = global_params.get_assigned_value(field.name)
            if value is None:
                value = default_params.get_assigned_value(field.name)
            if value not in (None, [], ()):
                setattr(result, field.name, value)
        return result

    def __EncodePrettyPrint(self, query_info):
        """Apply the custom query encoding for prettyPrint / pp flags."""
        # The prettyPrint flag needs custom encoding: it should be encoded
        # as 0 if False, and ignored otherwise (True is the default).
        if not query_info.pop('prettyPrint', True):
            query_info['prettyPrint'] = 0
        # The One Platform equivalent of prettyPrint is pp, which also needs
        # custom encoding.
        if not query_info.pop('pp', True):
            query_info['pp'] = 0
        return query_info

    def __FinalUrlValue(self, value, field):
        """Encode value for the URL, using field to skip encoding for bytes.

        BytesField values are url-safe base64 encoded; text is encoded to
        UTF-8 bytes; bytes are decoded from UTF-8; datetimes become ISO 8601
        strings; anything else is returned unchanged.
        """
        # NOTE(review): the text->bytes / bytes->text swap looks intended
        # for py2/py3 urlencode compatibility -- confirm before changing.
        if isinstance(field, messages.BytesField) and value is not None:
            return base64.urlsafe_b64encode(value)
        elif isinstance(value, six.text_type):
            return value.encode('utf8')
        elif isinstance(value, six.binary_type):
            return value.decode('utf8')
        elif isinstance(value, datetime.datetime):
            return value.isoformat()
        return value

    def __ConstructQueryParams(self, query_params, request, global_params):
        """Construct a dictionary of query parameters for this request.

        Global params (merged with the client's defaults) come first, then
        the method's own query params read from ``request``; entries whose
        encoded value is None are dropped.
        """
        # First, handle the global params.
        global_params = self.__CombineGlobalParams(
            global_params, self.__client.global_params)
        global_param_names = util.MapParamNames(
            [x.name for x in self.__client.params_type.all_fields()],
            self.__client.params_type)
        global_params_type = type(global_params)
        query_info = dict(
            (param,
             self.__FinalUrlValue(getattr(global_params, param),
                                  getattr(global_params_type, param)))
            for param in global_param_names)
        # Next, add the query params.
        request_type = type(request)
        query_param_names = util.MapParamNames(query_params, request_type)
        query_info.update(
            (param,
             self.__FinalUrlValue(getattr(request, param, None),
                                  getattr(request_type, param)))
            for param in query_param_names)
        query_info = {k: v for k, v in query_info.items() if v is not None}
        query_info = self.__EncodePrettyPrint(query_info)
        query_info = util.MapRequestParams(query_info, request_type)
        return query_info

    def __ConstructRelativePath(self, method_config, request,
                                relative_path=None):
        """Determine the relative path for request.

        Reads the method's path params from ``request`` and expands them
        into ``relative_path`` via util.ExpandRelativePath.
        """
        python_param_names = util.MapParamNames(
            method_config.path_params, type(request))
        params = {param: getattr(request, param, None)
                  for param in python_param_names}
        params = util.MapRequestParams(params, type(request))
        return util.ExpandRelativePath(method_config, params,
                                       relative_path=relative_path)

    def __FinalizeRequest(self, http_request, url_builder):
        """Make any final general adjustments to the request.

        Sets ``http_request.url`` from ``url_builder``. A GET whose URL
        would exceed _MAX_URL_LENGTH is re-issued as a POST: the query
        string moves into a form-encoded body and the original verb is
        signalled with the ``x-http-method-override`` header.
        """
        # Measure the URL that will actually be sent. When called from
        # PrepareHttpRequest, http_request.url has not been assigned yet
        # (the Request is built with only http_method), so checking it
        # would never detect an overlong URL.
        url = url_builder.url
        if (http_request.http_method == 'GET' and
                len(url) > _MAX_URL_LENGTH):
            http_request.http_method = 'POST'
            http_request.headers['x-http-method-override'] = 'GET'
            http_request.headers[
                'content-type'] = 'application/x-www-form-urlencoded'
            http_request.body = url_builder.query
            url_builder.query_params = {}
            # Rebuild the URL now that the query lives in the body.
            url = url_builder.url
        http_request.url = url

    def __ProcessHttpResponse(self, method_config, http_response):
        """Process the given http response.

        Raises:
          exceptions.HttpError: for any status other than 200 or 204.

        Returns:
          The raw content when the client's response model is 'json',
          otherwise the content deserialized as the method's response type
          (a 204 response is treated as the JSON body '{}').
        """
        if http_response.status_code not in (http_client.OK,
                                             http_client.NO_CONTENT):
            raise exceptions.HttpError.FromResponse(http_response)
        if http_response.status_code == http_client.NO_CONTENT:
            # TODO(craigcitro): Find out why _replace doesn't seem to work
            # here.
            http_response = http_wrapper.Response(
                info=http_response.info, content='{}',
                request_url=http_response.request_url)
        if self.__client.response_type_model == 'json':
            return http_response.content
        else:
            response_type = _LoadClass(method_config.response_type_name,
                                       self.__client.MESSAGES_MODULE)
            return self.__client.DeserializeMessage(
                response_type, http_response.content)

    def __SetBaseHeaders(self, http_request, client):
        """Fill in the basic headers on http_request."""
        # TODO(craigcitro): Make the default a little better here, and
        # include the apitools version.
        user_agent = client.user_agent or 'apitools-client/1.0'
        http_request.headers['user-agent'] = user_agent
        http_request.headers['accept'] = 'application/json'
        http_request.headers['accept-encoding'] = 'gzip, deflate'

    def __SetBody(self, http_request, method_config, request, upload):
        """Fill in the body on http_request.

        Does nothing when the method has no request_field. Otherwise the
        body is either the whole request (REQUEST_IS_BODY) or the message
        held in the named request field, serialized as JSON.
        """
        if not method_config.request_field:
            return

        request_type = _LoadClass(
            method_config.request_type_name, self.__client.MESSAGES_MODULE)
        if method_config.request_field == REQUEST_IS_BODY:
            body_value = request
            body_type = request_type
        else:
            body_value = getattr(request, method_config.request_field)
            body_field = request_type.field_by_name(
                method_config.request_field)
            util.Typecheck(body_field, messages.MessageField)
            body_type = body_field.type

        # If there was no body provided, we use an empty message of the
        # appropriate type.
        body_value = body_value or body_type()
        if upload and not body_value:
            # We're going to fill in the body later.
            return
        util.Typecheck(body_value, body_type)
        http_request.headers['content-type'] = 'application/json'
        http_request.body = self.__client.SerializeMessage(body_value)

    def PrepareHttpRequest(self, method_config, request, global_params=None,
                           upload=None, upload_config=None, download=None):
        """Prepares an HTTP request to be sent.

        Returns:
          The http_wrapper.Request returned by the client's
          ProcessHttpRequest hook.
        """
        request_type = _LoadClass(
            method_config.request_type_name, self.__client.MESSAGES_MODULE)
        util.Typecheck(request, request_type)
        request = self.__client.ProcessRequest(method_config, request)

        http_request = http_wrapper.Request(
            http_method=method_config.http_method)
        self.__SetBaseHeaders(http_request, self.__client)
        self.__SetBody(http_request, method_config, request, upload)

        url_builder = _UrlBuilder(
            self.__client.url, relative_path=method_config.relative_path)
        url_builder.query_params = self.__ConstructQueryParams(
            method_config.query_params, request, global_params)

        # It's important that upload and download go before we fill in the
        # relative path, so that they can replace it.
        if upload is not None:
            upload.ConfigureRequest(upload_config, http_request, url_builder)
        if download is not None:
            download.ConfigureRequest(http_request, url_builder)

        url_builder.relative_path = self.__ConstructRelativePath(
            method_config, request, relative_path=url_builder.relative_path)
        self.__FinalizeRequest(http_request, url_builder)

        return self.__client.ProcessHttpRequest(http_request)

    def _RunMethod(self, method_config, request, global_params=None,
                   upload=None, upload_config=None, download=None):
        """Call this method with request.

        Returns:
          None when ``download`` is given (the download is initialized
          instead); otherwise the processed response.

        Raises:
          exceptions.NotYetImplementedError: if both upload and download
              are given.
        """
        if upload is not None and download is not None:
            # TODO(craigcitro): This just involves refactoring the logic
            # below into callbacks that we can pass around; in particular,
            # the order should be that the upload gets the initial request,
            # and then passes its reply to a download if one exists, and
            # then that goes to ProcessResponse and is returned.
            raise exceptions.NotYetImplementedError(
                'Cannot yet use both upload and download at once')

        http_request = self.PrepareHttpRequest(
            method_config, request, global_params, upload, upload_config,
            download)

        # TODO(craigcitro): Make num_retries customizable on Transfer
        # objects, and pass in self.__client.num_retries when initializing
        # an upload or download.
        if download is not None:
            download.InitializeDownload(http_request, client=self.client)
            return

        http_response = None
        if upload is not None:
            http_response = upload.InitializeUpload(
                http_request, client=self.client)
        if http_response is None:
            http = self.__client.http
            if upload and upload.bytes_http:
                http = upload.bytes_http
            http_response = http_wrapper.MakeRequest(
                http, http_request, retries=self.__client.num_retries,
                max_retry_wait=self.__client.max_retry_wait)

        return self.ProcessHttpResponse(method_config, http_response)

    def ProcessHttpResponse(self, method_config, http_response):
        """Convert an HTTP response to the expected message type."""
        return self.__client.ProcessResponse(
            method_config,
            self.__ProcessHttpResponse(method_config, http_response))
684