#!/usr/bin/env python
"""Base class for api services."""

import base64
import contextlib
import datetime
import logging
import pprint


from protorpc import message_types
from protorpc import messages
import six
from six.moves import http_client
from six.moves import urllib


from apitools.base.py import credentials_lib
from apitools.base.py import encoding
from apitools.base.py import exceptions
from apitools.base.py import http_wrapper
from apitools.base.py import util

__all__ = [
    'ApiMethodInfo',
    'ApiUploadInfo',
    'BaseApiClient',
    'BaseApiService',
    'NormalizeApiEndpoint',
]

# TODO(craigcitro): Remove this once we quiet the spurious logging in
# oauth2client (or drop oauth2client).
logging.getLogger('oauth2client.util').setLevel(logging.ERROR)

# GET requests whose final URL is longer than this are rewritten as a POST
# carrying an x-http-method-override header (see
# BaseApiService.__FinalizeRequest).
_MAX_URL_LENGTH = 2048


class ApiUploadInfo(messages.Message):

    """Media upload information for a method.

    Fields:
      accept: (repeated) MIME Media Ranges for acceptable media uploads
          to this method.
      max_size: (integer) Maximum size of a media upload, such as 3MB
          or 1TB (converted to an integer).
      resumable_path: Path to use for resumable uploads.
      resumable_multipart: (boolean) Whether or not the resumable endpoint
          supports multipart uploads.
      simple_path: Path to use for simple uploads.
      simple_multipart: (boolean) Whether or not the simple endpoint
          supports multipart uploads.
    """
    accept = messages.StringField(1, repeated=True)
    max_size = messages.IntegerField(2)
    resumable_path = messages.StringField(3)
    resumable_multipart = messages.BooleanField(4)
    simple_path = messages.StringField(5)
    simple_multipart = messages.BooleanField(6)


class ApiMethodInfo(messages.Message):

    """Configuration info for an API method.

    All fields are strings unless noted otherwise.

    Fields:
      relative_path: Relative path for this method.
      method_id: ID for this method.
      http_method: HTTP verb to use for this method.
      path_params: (repeated) path parameters for this method.
      query_params: (repeated) query parameters for this method.
      ordered_params: (repeated) ordered list of parameters for
          this method.
      description: description of this method.
      request_type_name: name of the request type.
      response_type_name: name of the response type.
      request_field: if not null, the field to pass as the body
          of this POST request. may also be the REQUEST_IS_BODY
          value below to indicate the whole message is the body.
      upload_config: (ApiUploadInfo) Information about the upload
          configuration supported by this method.
      supports_download: (boolean) If True, this method supports
          downloading the request via the `alt=media` query
          parameter.
    """

    relative_path = messages.StringField(1)
    method_id = messages.StringField(2)
    http_method = messages.StringField(3)
    path_params = messages.StringField(4, repeated=True)
    query_params = messages.StringField(5, repeated=True)
    ordered_params = messages.StringField(6, repeated=True)
    description = messages.StringField(7)
    request_type_name = messages.StringField(8)
    response_type_name = messages.StringField(9)
    request_field = messages.StringField(10, default='')
    upload_config = messages.MessageField(ApiUploadInfo, 11)
    supports_download = messages.BooleanField(12, default=False)


# Sentinel value for ApiMethodInfo.request_field meaning "the entire request
# message is the HTTP body" (checked in BaseApiService.__SetBody).
REQUEST_IS_BODY = '<request>'


def _LoadClass(name, messages_module):
    """Resolve a message class from its name.

    Args:
      name: (str) Either 'message_types.<ClassName>', resolved against
          protorpc.message_types, or an undotted class name resolved
          against messages_module.
      messages_module: Module containing the generated message classes.

    Returns:
      The resolved class object.

    Raises:
      GeneratedClientError: name is dotted but not a message_types name.
      AttributeError: the class does not exist in the target module.
    """
    if name.startswith('message_types.'):
        _, _, classname = name.partition('.')
        return getattr(message_types, classname)
    elif '.' not in name:
        return getattr(messages_module, name)
    else:
        raise exceptions.GeneratedClientError('Unknown class %s' % name)


def _RequireClassAttrs(obj, attrs):
    """Ensure obj has a truthy value for each (upper-cased) attribute name.

    Args:
      obj: Object to inspect.
      attrs: Iterable of attribute names; each is upper-cased before lookup.

    Raises:
      GeneratedClientError: an attribute is missing or falsy.
    """
    for attr in attrs:
        attr_name = attr.upper()
        if not getattr(obj, attr_name, None):
            msg = 'No %s specified for object of class %s.' % (
                attr_name, type(obj).__name__)
            raise exceptions.GeneratedClientError(msg)


def NormalizeApiEndpoint(api_endpoint):
    """Return api_endpoint, appending a trailing '/' if it lacks one."""
    if not api_endpoint.endswith('/'):
        api_endpoint += '/'
    return api_endpoint


def _urljoin(base, url):  # pylint: disable=invalid-name
    """Custom urljoin replacement supporting : before / in url."""
    # In general, it's unsafe to simply join base and url. However, for
    # the case of discovery documents, we know:
    #  * base will never contain params, query, or fragment
    #  * url will never contain a scheme or net_loc.
    # In general, this means we can safely join on /; we just need to
    # ensure we end up with precisely one / joining base and url. The
    # exception here is the case of media uploads, where url will be an
    # absolute url.
    if url.startswith('http://') or url.startswith('https://'):
        return urllib.parse.urljoin(base, url)
    new_base = base if base.endswith('/') else base + '/'
    new_url = url[1:] if url.startswith('/') else url
    return new_base + new_url


class _UrlBuilder(object):

    """Convenient container for url data."""

    def __init__(self, base_url, relative_path=None, query_params=None):
        """Split base_url joined with relative_path into its components.

        Args:
          base_url: (str) Scheme and host (optionally with a path).
          relative_path: (str) Path joined onto base_url via _urljoin.
          query_params: (dict) Extra query params merged over any parsed
              from the joined url.

        Raises:
          ConfigurationValueError: the joined url contains a fragment.
        """
        components = urllib.parse.urlsplit(_urljoin(
            base_url, relative_path or ''))
        if components.fragment:
            raise exceptions.ConfigurationValueError(
                'Unexpected url fragment: %s' % components.fragment)
        self.query_params = urllib.parse.parse_qs(components.query or '')
        if query_params is not None:
            self.query_params.update(query_params)
        self.__scheme = components.scheme
        self.__netloc = components.netloc
        self.relative_path = components.path or ''

    @classmethod
    def FromUrl(cls, url):
        """Build a _UrlBuilder from a complete url string."""
        urlparts = urllib.parse.urlsplit(url)
        query_params = urllib.parse.parse_qs(urlparts.query)
        # Use '' (not None) for the empty components, matching base_url.
        base_url = urllib.parse.urlunsplit((
            urlparts.scheme, urlparts.netloc, '', '', ''))
        relative_path = urlparts.path or ''
        return cls(
            base_url, relative_path=relative_path, query_params=query_params)

    @property
    def base_url(self):
        """Scheme and netloc only, e.g. 'https://host'."""
        return urllib.parse.urlunsplit(
            (self.__scheme, self.__netloc, '', '', ''))

    @base_url.setter
    def base_url(self, value):
        components = urllib.parse.urlsplit(value)
        if components.path or components.query or components.fragment:
            raise exceptions.ConfigurationValueError(
                'Invalid base url: %s' % value)
        self.__scheme = components.scheme
        self.__netloc = components.netloc

    @property
    def query(self):
        """The url-encoded query string built from query_params."""
        # TODO(craigcitro): In the case that some of the query params are
        # non-ASCII, we may silently fail to encode correctly. We should
        # figure out who is responsible for owning the object -> str
        # conversion.
        return urllib.parse.urlencode(self.query_params, doseq=True)

    @property
    def url(self):
        """The full url; fails if relative_path still has {template} parts."""
        if '{' in self.relative_path or '}' in self.relative_path:
            raise exceptions.ConfigurationValueError(
                'Cannot create url with relative path %s' % self.relative_path)
        return urllib.parse.urlunsplit((
            self.__scheme, self.__netloc, self.relative_path, self.query, ''))


def _SkipGetCredentials():
    """Hook for skipping credentials. For internal use."""
    return False


class BaseApiClient(object):

    """Base class for client libraries."""
    MESSAGES_MODULE = None

    _API_KEY = ''
    _CLIENT_ID = ''
    _CLIENT_SECRET = ''
    _PACKAGE = ''
    _SCOPES = []
    _USER_AGENT = ''

    def __init__(self, url, credentials=None, get_credentials=True, http=None,
                 model=None, log_request=False, log_response=False,
                 num_retries=5, max_retry_wait=60, credentials_args=None,
                 default_global_params=None, additional_http_headers=None):
        _RequireClassAttrs(self, ('_package', '_scopes', 'messages_module'))
        if default_global_params is not None:
            util.Typecheck(default_global_params, self.params_type)
        self.__default_global_params = default_global_params
        self.log_request = log_request
        self.log_response = log_response
        self.__num_retries = 5
        self.__max_retry_wait = 60
        # We let the @property machinery below do our validation.
        self.num_retries = num_retries
        self.max_retry_wait = max_retry_wait
        self._credentials = credentials
        get_credentials = get_credentials and not _SkipGetCredentials()
        if get_credentials and not credentials:
            credentials_args = credentials_args or {}
            self._SetCredentials(**credentials_args)
        self._url = NormalizeApiEndpoint(url)
        self._http = http or http_wrapper.GetHttp()
        # Note that "no credentials" is totally possible.
        if self._credentials is not None:
            self._http = self._credentials.authorize(self._http)
        # TODO(craigcitro): Remove this field when we switch to proto2.
        self.__include_fields = None

        self.additional_http_headers = additional_http_headers or {}

        # TODO(craigcitro): Finish deprecating these fields.
        _ = model

        self.__response_type_model = 'proto'

    def _SetCredentials(self, **kwds):
        """Fetch credentials, and set them for this client.

        Note that we can't simply return credentials, since creating them
        may involve side-effecting self.

        Args:
          **kwds: Additional keyword arguments are passed on to GetCredentials.

        Returns:
          None. Sets self._credentials.
        """
        args = {
            'api_key': self._API_KEY,
            'client': self,
            'client_id': self._CLIENT_ID,
            'client_secret': self._CLIENT_SECRET,
            'package_name': self._PACKAGE,
            'scopes': self._SCOPES,
            'user_agent': self._USER_AGENT,
        }
        args.update(kwds)
        # TODO(craigcitro): It's a bit dangerous to pass this
        # still-half-initialized self into this method, but we might need
        # to set attributes on it associated with our credentials.
        # Consider another way around this (maybe a callback?) and whether
        # or not it's worth it.
        self._credentials = credentials_lib.GetCredentials(**args)

    @classmethod
    def ClientInfo(cls):
        """Return the client id/secret, sorted scope string and user agent."""
        return {
            'client_id': cls._CLIENT_ID,
            'client_secret': cls._CLIENT_SECRET,
            'scope': ' '.join(sorted(util.NormalizeScopes(cls._SCOPES))),
            'user_agent': cls._USER_AGENT,
        }

    @property
    def base_model_class(self):
        return None

    @property
    def http(self):
        return self._http

    @property
    def url(self):
        return self._url

    @classmethod
    def GetScopes(cls):
        return cls._SCOPES

    @property
    def params_type(self):
        """The StandardQueryParameters class from MESSAGES_MODULE."""
        return _LoadClass('StandardQueryParameters', self.MESSAGES_MODULE)

    @property
    def user_agent(self):
        return self._USER_AGENT

    @property
    def _default_global_params(self):
        # Lazily created so that clients constructed without defaults still
        # get a mutable params object for AddGlobalParam.
        if self.__default_global_params is None:
            self.__default_global_params = self.params_type()
        return self.__default_global_params

    def AddGlobalParam(self, name, value):
        """Set a default global query parameter for all requests."""
        params = self._default_global_params
        setattr(params, name, value)

    @property
    def global_params(self):
        """A copy of the default global params (safe for callers to mutate)."""
        return encoding.CopyProtoMessage(self._default_global_params)

    @contextlib.contextmanager
    def IncludeFields(self, include_fields):
        """Within this context, pass include_fields to SerializeMessage.

        The previous value is restored on exit, even if the body raises.
        """
        old_include_fields = self.__include_fields
        self.__include_fields = include_fields
        try:
            yield
        finally:
            self.__include_fields = old_include_fields

    @property
    def response_type_model(self):
        return self.__response_type_model

    @contextlib.contextmanager
    def JsonResponseModel(self):
        """In this context, return raw JSON instead of proto.

        The previous model is restored on exit, even if the body raises.
        """
        old_model = self.response_type_model
        self.__response_type_model = 'json'
        try:
            yield
        finally:
            self.__response_type_model = old_model

    @property
    def num_retries(self):
        return self.__num_retries

    @num_retries.setter
    def num_retries(self, value):
        util.Typecheck(value, six.integer_types)
        if value < 0:
            raise exceptions.InvalidDataError(
                'Cannot have negative value for num_retries')
        self.__num_retries = value

    @property
    def max_retry_wait(self):
        return self.__max_retry_wait

    @max_retry_wait.setter
    def max_retry_wait(self, value):
        util.Typecheck(value, six.integer_types)
        if value <= 0:
            raise exceptions.InvalidDataError(
                'max_retry_wait must be a positive integer')
        self.__max_retry_wait = value

    @contextlib.contextmanager
    def WithRetries(self, num_retries):
        """Temporarily override num_retries within this context.

        The previous value is restored on exit, even if the body raises.
        """
        old_num_retries = self.num_retries
        self.num_retries = num_retries
        try:
            yield
        finally:
            self.num_retries = old_num_retries

    def ProcessRequest(self, method_config, request):
        """Hook for pre-processing of requests."""
        if self.log_request:
            logging.info(
                'Calling method %s with %s: %s', method_config.method_id,
                method_config.request_type_name, request)
        return request

    def ProcessHttpRequest(self, http_request):
        """Hook for pre-processing of http requests."""
        http_request.headers.update(self.additional_http_headers)
        if self.log_request:
            logging.info('Making http %s to %s',
                         http_request.http_method, http_request.url)
            logging.info('Headers: %s', pprint.pformat(http_request.headers))
            if http_request.body:
                # TODO(craigcitro): Make this safe to print in the case of
                # non-printable body characters.
                logging.info('Body:\n%s',
                             http_request.loggable_body or http_request.body)
            else:
                logging.info('Body: (none)')
        return http_request

    def ProcessResponse(self, method_config, response):
        """Hook for post-processing of responses (logs when enabled)."""
        if self.log_response:
            logging.info('Response of type %s: %s',
                         method_config.response_type_name, response)
        return response

    # TODO(craigcitro): Decide where these two functions should live.
    def SerializeMessage(self, message):
        """Encode message as JSON, honoring any active IncludeFields."""
        return encoding.MessageToJson(
            message, include_fields=self.__include_fields)

    def DeserializeMessage(self, response_type, data):
        """Deserialize the given data as method_config.response_type."""
        try:
            message = encoding.JsonToMessage(response_type, data)
        except (exceptions.InvalidDataFromServerError,
                messages.ValidationError) as e:
            raise exceptions.InvalidDataFromServerError(
                'Error decoding response "%s" as type %s: %s' % (
                    data, response_type.__name__, e))
        return message

    def FinalizeTransferUrl(self, url):
        """Modify the url for a given transfer, based on auth and version."""
        url_builder = _UrlBuilder.FromUrl(url)
        # global_params returns a fresh copy on each access; fetch it once.
        global_params = self.global_params
        if global_params.key:
            url_builder.query_params['key'] = global_params.key
        return url_builder.url


class BaseApiService(object):

    """Base class for generated API services."""

    def __init__(self, client):
        self.__client = client
        self._method_configs = {}
        self._upload_configs = {}

    @property
    def _client(self):
        return self.__client

    @property
    def client(self):
        return self.__client

    def GetMethodConfig(self, method):
        """Return the ApiMethodInfo for method (KeyError if unknown)."""
        return self._method_configs[method]

    def GetUploadConfig(self, method):
        """Return the upload config for method, or None if it has none."""
        return self._upload_configs.get(method)

    def GetRequestType(self, method):
        method_config = self.GetMethodConfig(method)
        return getattr(self.client.MESSAGES_MODULE,
                       method_config.request_type_name)

    def GetResponseType(self, method):
        method_config = self.GetMethodConfig(method)
        return getattr(self.client.MESSAGES_MODULE,
                       method_config.response_type_name)

    def __CombineGlobalParams(self, global_params, default_params):
        """Combine the given params with the defaults."""
        util.Typecheck(global_params, (type(None), self.__client.params_type))
        result = self.__client.params_type()
        global_params = global_params or self.__client.params_type()
        for field in result.all_fields():
            value = global_params.get_assigned_value(field.name)
            if value is None:
                value = default_params.get_assigned_value(field.name)
            # Skip unset and empty repeated values so they are not encoded.
            if value not in (None, [], ()):
                setattr(result, field.name, value)
        return result

    def __EncodePrettyPrint(self, query_info):
        # The prettyPrint flag needs custom encoding: it should be encoded
        # as 0 if False, and ignored otherwise (True is the default).
        if not query_info.pop('prettyPrint', True):
            query_info['prettyPrint'] = 0
        # The One Platform equivalent of prettyPrint is pp, which also needs
        # custom encoding.
        if not query_info.pop('pp', True):
            query_info['pp'] = 0
        return query_info

    def __FinalUrlValue(self, value, field):
        """Encode value for the URL, using field to skip encoding for bytes."""
        if isinstance(field, messages.BytesField) and value is not None:
            return base64.urlsafe_b64encode(value)
        elif isinstance(value, six.text_type):
            return value.encode('utf8')
        elif isinstance(value, six.binary_type):
            return value.decode('utf8')
        elif isinstance(value, datetime.datetime):
            return value.isoformat()
        return value

    def __ConstructQueryParams(self, query_params, request, global_params):
        """Construct a dictionary of query parameters for this request."""
        # First, handle the global params.
        global_params = self.__CombineGlobalParams(
            global_params, self.__client.global_params)
        global_param_names = util.MapParamNames(
            [x.name for x in self.__client.params_type.all_fields()],
            self.__client.params_type)
        global_params_type = type(global_params)
        query_info = dict(
            (param,
             self.__FinalUrlValue(getattr(global_params, param),
                                  getattr(global_params_type, param)))
            for param in global_param_names)
        # Next, add the query params.
        query_param_names = util.MapParamNames(query_params, type(request))
        request_type = type(request)
        query_info.update(
            (param,
             self.__FinalUrlValue(getattr(request, param, None),
                                  getattr(request_type, param)))
            for param in query_param_names)
        # Drop unset values so they do not appear in the query string.
        query_info = dict((k, v) for k, v in query_info.items()
                          if v is not None)
        query_info = self.__EncodePrettyPrint(query_info)
        query_info = util.MapRequestParams(query_info, type(request))
        return query_info

    def __ConstructRelativePath(self, method_config, request,
                                relative_path=None):
        """Determine the relative path for request."""
        python_param_names = util.MapParamNames(
            method_config.path_params, type(request))
        params = dict([(param, getattr(request, param, None))
                       for param in python_param_names])
        params = util.MapRequestParams(params, type(request))
        return util.ExpandRelativePath(method_config, params,
                                       relative_path=relative_path)

    def __FinalizeRequest(self, http_request, url_builder):
        """Make any final general adjustments to the request."""
        # Over-long GET urls are sent as a form-encoded POST body instead,
        # with the real verb carried in x-http-method-override.
        if (http_request.http_method == 'GET' and
                len(http_request.url) > _MAX_URL_LENGTH):
            http_request.http_method = 'POST'
            http_request.headers['x-http-method-override'] = 'GET'
            http_request.headers[
                'content-type'] = 'application/x-www-form-urlencoded'
            http_request.body = url_builder.query
            url_builder.query_params = {}
        http_request.url = url_builder.url

    def __ProcessHttpResponse(self, method_config, http_response):
        """Process the given http response."""
        if http_response.status_code not in (http_client.OK,
                                             http_client.NO_CONTENT):
            raise exceptions.HttpError.FromResponse(http_response)
        if http_response.status_code == http_client.NO_CONTENT:
            # TODO(craigcitro): Find out why _replace doesn't seem to work
            # here.
            http_response = http_wrapper.Response(
                info=http_response.info, content='{}',
                request_url=http_response.request_url)
        if self.__client.response_type_model == 'json':
            return http_response.content
        else:
            response_type = _LoadClass(method_config.response_type_name,
                                       self.__client.MESSAGES_MODULE)
            return self.__client.DeserializeMessage(
                response_type, http_response.content)

    def __SetBaseHeaders(self, http_request, client):
        """Fill in the basic headers on http_request."""
        # TODO(craigcitro): Make the default a little better here, and
        # include the apitools version.
        user_agent = client.user_agent or 'apitools-client/1.0'
        http_request.headers['user-agent'] = user_agent
        http_request.headers['accept'] = 'application/json'
        http_request.headers['accept-encoding'] = 'gzip, deflate'

    def __SetBody(self, http_request, method_config, request, upload):
        """Fill in the body on http_request."""
        if not method_config.request_field:
            return

        request_type = _LoadClass(
            method_config.request_type_name, self.__client.MESSAGES_MODULE)
        if method_config.request_field == REQUEST_IS_BODY:
            body_value = request
            body_type = request_type
        else:
            body_value = getattr(request, method_config.request_field)
            body_field = request_type.field_by_name(
                method_config.request_field)
            util.Typecheck(body_field, messages.MessageField)
            body_type = body_field.type

        # If there was no body provided, we use an empty message of the
        # appropriate type.
        body_value = body_value or body_type()
        # NOTE(review): body_value is always a message instance here; unless
        # protorpc messages define falsiness, this early return may never
        # fire -- confirm before relying on it.
        if upload and not body_value:
            # We're going to fill in the body later.
            return
        util.Typecheck(body_value, body_type)
        http_request.headers['content-type'] = 'application/json'
        http_request.body = self.__client.SerializeMessage(body_value)

    def PrepareHttpRequest(self, method_config, request, global_params=None,
                           upload=None, upload_config=None, download=None):
        """Prepares an HTTP request to be sent."""
        request_type = _LoadClass(
            method_config.request_type_name, self.__client.MESSAGES_MODULE)
        util.Typecheck(request, request_type)
        request = self.__client.ProcessRequest(method_config, request)

        http_request = http_wrapper.Request(
            http_method=method_config.http_method)
        self.__SetBaseHeaders(http_request, self.__client)
        self.__SetBody(http_request, method_config, request, upload)

        url_builder = _UrlBuilder(
            self.__client.url, relative_path=method_config.relative_path)
        url_builder.query_params = self.__ConstructQueryParams(
            method_config.query_params, request, global_params)

        # It's important that upload and download go before we fill in the
        # relative path, so that they can replace it.
        if upload is not None:
            upload.ConfigureRequest(upload_config, http_request, url_builder)
        if download is not None:
            download.ConfigureRequest(http_request, url_builder)

        url_builder.relative_path = self.__ConstructRelativePath(
            method_config, request, relative_path=url_builder.relative_path)
        self.__FinalizeRequest(http_request, url_builder)

        return self.__client.ProcessHttpRequest(http_request)

    def _RunMethod(self, method_config, request, global_params=None,
                   upload=None, upload_config=None, download=None):
        """Call this method with request.

        Returns None when download is given (the download object receives
        the data); otherwise the processed response.
        """
        if upload is not None and download is not None:
            # TODO(craigcitro): This just involves refactoring the logic
            # below into callbacks that we can pass around; in particular,
            # the order should be that the upload gets the initial request,
            # and then passes its reply to a download if one exists, and
            # then that goes to ProcessResponse and is returned.
            raise exceptions.NotYetImplementedError(
                'Cannot yet use both upload and download at once')

        http_request = self.PrepareHttpRequest(
            method_config, request, global_params, upload, upload_config,
            download)

        # TODO(craigcitro): Make num_retries customizable on Transfer
        # objects, and pass in self.__client.num_retries when initializing
        # an upload or download.
        if download is not None:
            download.InitializeDownload(http_request, client=self.client)
            return

        http_response = None
        if upload is not None:
            http_response = upload.InitializeUpload(
                http_request, client=self.client)
        if http_response is None:
            http = self.__client.http
            if upload and upload.bytes_http:
                http = upload.bytes_http
            http_response = http_wrapper.MakeRequest(
                http, http_request, retries=self.__client.num_retries,
                max_retry_wait=self.__client.max_retry_wait)

        return self.ProcessHttpResponse(method_config, http_response)

    def ProcessHttpResponse(self, method_config, http_response):
        """Convert an HTTP response to the expected message type."""
        return self.__client.ProcessResponse(
            method_config,
            self.__ProcessHttpResponse(method_config, http_response))