#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Service registry for apitools."""

import collections
import logging
import re
import textwrap

from apitools.base.py import base_api
from apitools.gen import util

# We're a code generator. I don't care.
# pylint:disable=too-many-statements

# Loose "type/subtype" check (wildcards allowed) for mediaUpload accept
# patterns. A mismatch only produces a warning, never an error.
_MIME_PATTERN_RE = re.compile(r'(?i)[a-z0-9_*-]+/[a-z0-9_*-]+')


class ServiceRegistry(object):

    """Registry for service types.

    Services are added from discovery-document resources with
    AddServiceFromResource; WriteFile and WriteProtoFile then emit every
    registered service as a generated Python client module or as proto
    service declarations.
    """

    def __init__(self, client_info, message_registry, command_registry,
                 names, root_package, base_files_package,
                 unelidable_request_methods):
        """Create an empty registry.

        Args:
          client_info: Client description; its package, version, scopes,
              base_path, base_url and client_class_name (among others) are
              read when generating code.
          message_registry: Registry that request/response message types are
              added to and looked up in.
          command_registry: Notified through AddCommandForMethod for every
              method that is added.
          names: Naming helper (ClassName, MethodName, FieldName, CleanName,
              NormalizeRelativePath).
          root_package: If non-empty, the generated client imports its
              messages module as "from <root_package> import ...".
          base_files_package: Package the generated client imports base_api
              from.
          unelidable_request_methods: Method ids that always get their own
              generated request type, even if the body type could be reused.
        """
        self.__client_info = client_info
        self.__package = client_info.package
        self.__names = names
        # service name -> OrderedDict(method name -> base_api.ApiMethodInfo)
        self.__service_method_info_map = collections.OrderedDict()
        self.__message_registry = message_registry
        self.__command_registry = command_registry
        self.__root_package = root_package
        self.__base_files_package = base_files_package
        self.__unelidable_request_methods = unelidable_request_methods
        # Seeded with the client's scopes; grows with each method's scopes.
        self.__all_scopes = set(self.__client_info.scopes)

    def Validate(self):
        """Validate the underlying message registry."""
        self.__message_registry.Validate()

    @property
    def scopes(self):
        """Sorted list of every scope seen (client scopes + method scopes)."""
        return sorted(self.__all_scopes)

    def __GetServiceClassName(self, service_name):
        """Return the generated class name ("<Name>Service") for a service."""
        return self.__names.ClassName(
            '%sService' % self.__names.ClassName(service_name))

    def __PrintDocstring(self, printer, method_info, method_name, name):
        """Print a docstring for a service method.

        Args:
          printer: Printer the generated code is written to.
          method_info: (base_api.ApiMethodInfo) The method being documented.
          method_name: Name of the generated method.
          name: Name of the service the method belongs to.
        """
        if method_info.description:
            description = util.CleanDescription(method_info.description)
            # Ensure the summary line ends with a period. Operate on the
            # cleaned text so the cleaning is not lost.
            first_line, newline, remaining = description.partition('\n')
            if not first_line.endswith('.'):
                first_line = '%s.' % first_line
            description = '%s%s%s' % (first_line, newline, remaining)
        else:
            description = '%s method for the %s service.' % (method_name, name)
        # Only the free-form description text is printed in comment context;
        # the remaining lines pass %-style arguments to printer.
        with printer.CommentContext():
            printer('"""%s' % description)
        printer()
        printer('Args:')
        printer(' request: (%s) input message', method_info.request_type_name)
        printer(' global_params: (StandardQueryParameters, default: None) '
                'global arguments')
        if method_info.upload_config:
            printer(' upload: (Upload, default: None) If present, upload')
            printer(' this stream with the request.')
        if method_info.supports_download:
            printer(
                ' download: (Download, default: None) If present, download')
            printer(' data from the request via this stream.')
        printer('Returns:')
        printer(' (%s) The response message.', method_info.response_type_name)
        printer('"""')

    def __WriteSingleService(
            self, printer, name, method_info_map, client_class_name):
        """Write the generated class for one service.

        The service class is emitted as a nested class of the client class:
        its generated __init__ calls super(<client_class_name>.<class>, ...).

        Args:
          printer: Printer the generated code is written to.
          name: Service name.
          method_info_map: OrderedDict of method name -> ApiMethodInfo.
          client_class_name: Name of the enclosing generated client class.
        """
        printer()
        class_name = self.__GetServiceClassName(name)
        printer('class %s(base_api.BaseApiService):', class_name)
        with printer.Indent():
            printer('"""Service class for the %s resource."""', name)
            printer()
            printer('_NAME = %s', repr(name))

            # Print the configs for the methods first.
            printer()
            printer('def __init__(self, client):')
            with printer.Indent():
                printer('super(%s.%s, self).__init__(client)',
                        client_class_name, class_name)
                printer('self._upload_configs = {')
                with printer.Indent(indent=' '):
                    for method_name, method_info in method_info_map.items():
                        upload_config = method_info.upload_config
                        if upload_config is not None:
                            printer(
                                "'%s': base_api.ApiUploadInfo(", method_name)
                            with printer.Indent(indent=' '):
                                attrs = sorted(
                                    x.name for x in upload_config.all_fields())
                                for attr in attrs:
                                    printer('%s=%r,',
                                            attr, getattr(upload_config, attr))
                            printer('),')
                printer('}')

            # Now write each method in turn.
            for method_name, method_info in method_info_map.items():
                printer()
                params = ['self', 'request', 'global_params=None']
                if method_info.upload_config:
                    params.append('upload=None')
                if method_info.supports_download:
                    params.append('download=None')
                printer('def %s(%s):', method_name, ', '.join(params))
                with printer.Indent():
                    self.__PrintDocstring(
                        printer, method_info, method_name, name)
                    printer("config = self.GetMethodConfig('%s')", method_name)
                    upload_config = method_info.upload_config
                    if upload_config is not None:
                        printer("upload_config = self.GetUploadConfig('%s')",
                                method_name)
                    arg_lines = [
                        'config, request, global_params=global_params']
                    if method_info.upload_config:
                        arg_lines.append(
                            'upload=upload, upload_config=upload_config')
                    if method_info.supports_download:
                        arg_lines.append('download=download')
                    printer('return self._RunMethod(')
                    with printer.Indent(indent=' '):
                        for line in arg_lines[:-1]:
                            printer('%s,', line)
                        printer('%s)', arg_lines[-1])
                printer()
                printer('{0}.method_config = lambda: base_api.ApiMethodInfo('
                        .format(method_name))
                with printer.Indent(indent=' '):
                    attrs = sorted(
                        x.name for x in method_info.all_fields())
                    for attr in attrs:
                        # Upload configs are emitted in _upload_configs above;
                        # the description already went into the docstring.
                        if attr in ('upload_config', 'description'):
                            continue
                        value = getattr(method_info, attr)
                        if value is not None:
                            printer('%s=%r,', attr, value)
                printer(')')

    def __WriteProtoServiceDeclaration(self, printer, name, method_info_map):
        """Write a single service declaration to a proto file."""
        printer()
        printer('service %s {', self.__GetServiceClassName(name))
        with printer.Indent():
            for method_name, method_info in method_info_map.items():
                # Leave room for the "// " comment prefix when wrapping.
                for line in textwrap.wrap(method_info.description,
                                          printer.CalculateWidth() - 3):
                    printer('// %s', line)
                printer('rpc %s (%s) returns (%s);',
                        method_name,
                        method_info.request_type_name,
                        method_info.response_type_name)
        printer('}')

    def WriteProtoFile(self, printer):
        """Write the services in this registry to out as proto."""
        self.Validate()
        client_info = self.__client_info
        printer('// Generated services for %s version %s.',
                client_info.package, client_info.version)
        printer()
        printer('syntax = "proto2";')
        printer('package %s;', self.__package)
        printer('import "%s";', client_info.messages_proto_file_name)
        printer()
        for name, method_info_map in self.__service_method_info_map.items():
            self.__WriteProtoServiceDeclaration(printer, name, method_info_map)

    def WriteFile(self, printer):
        """Write the services in this registry to out.

        Emits the generated client module: imports, the client class with
        its constants and __init__ (which instantiates one attribute per
        service), followed by one nested class per registered service.
        """
        self.Validate()
        client_info = self.__client_info
        printer('"""Generated client library for %s version %s."""',
                client_info.package, client_info.version)
        printer('# NOTE: This file is autogenerated and should not be edited '
                'by hand.')
        printer('from %s import base_api', self.__base_files_package)
        if self.__root_package:
            import_prefix = 'from {0} '.format(self.__root_package)
        else:
            import_prefix = ''
        printer('%simport %s as messages', import_prefix,
                client_info.messages_rule_name)
        printer()
        printer()
        printer('class %s(base_api.BaseApiClient):',
                client_info.client_class_name)
        with printer.Indent():
            printer(
                '"""Generated client library for service %s version %s."""',
                client_info.package, client_info.version)
            printer()
            printer('MESSAGES_MODULE = messages')
            printer('BASE_URL = {0!r}'.format(client_info.base_url))
            printer()
            printer('_PACKAGE = {0!r}'.format(client_info.package))
            # Fall back to the userinfo.email scope when none are configured.
            printer('_SCOPES = {0!r}'.format(
                client_info.scopes or
                ['https://www.googleapis.com/auth/userinfo.email']))
            printer('_VERSION = {0!r}'.format(client_info.version))
            printer('_CLIENT_ID = {0!r}'.format(client_info.client_id))
            printer('_CLIENT_SECRET = {0!r}'.format(client_info.client_secret))
            printer('_USER_AGENT = {0!r}'.format(client_info.user_agent))
            printer('_CLIENT_CLASS_NAME = {0!r}'.format(
                client_info.client_class_name))
            printer('_URL_VERSION = {0!r}'.format(client_info.url_version))
            printer('_API_KEY = {0!r}'.format(client_info.api_key))
            printer()
            printer("def __init__(self, url='', credentials=None,")
            with printer.Indent(indent=' '):
                printer('get_credentials=True, http=None, model=None,')
                printer('log_request=False, log_response=False,')
                printer('credentials_args=None, default_global_params=None,')
                printer('additional_http_headers=None):')
            with printer.Indent():
                printer('"""Create a new %s handle."""', client_info.package)
                printer('url = url or self.BASE_URL')
                printer(
                    'super(%s, self).__init__(', client_info.client_class_name)
                printer(' url, credentials=credentials,')
                printer(' get_credentials=get_credentials, http=http, '
                        'model=model,')
                printer(' log_request=log_request, '
                        'log_response=log_response,')
                printer(' credentials_args=credentials_args,')
                printer(' default_global_params=default_global_params,')
                printer(' additional_http_headers=additional_http_headers)')
                for name in self.__service_method_info_map:
                    printer('self.%s = self.%s(self)',
                            name, self.__GetServiceClassName(name))
            for name, method_info_map in (
                    self.__service_method_info_map.items()):
                self.__WriteSingleService(
                    printer, name, method_info_map,
                    client_info.client_class_name)

    def __RegisterService(self, service_name, method_info_map):
        """Record method_info_map for service_name.

        Raises:
          ValueError: if service_name has already been registered.
        """
        if service_name in self.__service_method_info_map:
            raise ValueError(
                'Attempt to re-register descriptor %s' % service_name)
        self.__service_method_info_map[service_name] = method_info_map

    def __CreateRequestType(self, method_description, body_type=None):
        """Create a request type for this method.

        The generated message holds one field per method parameter (in
        parameterOrder first, then any remaining parameters) plus, when
        body_type is given, a field carrying the request body.

        Args:
          method_description: Discovery-document method description.
          body_type: The method's "request" schema (a dict with "$ref"), or
              None. NOTE: a default "description" is written into this dict
              when it has none.

        Returns:
          The id of the newly registered request message.

        Raises:
          ValueError: if a parameter has no type, or the body field name
              collides with a parameter field.
        """
        schema = {}
        schema['id'] = self.__names.ClassName('%sRequest' % (
            self.__names.ClassName(method_description['id'], separator='.'),))
        schema['type'] = 'object'
        schema['properties'] = collections.OrderedDict()
        if 'parameterOrder' not in method_description:
            ordered_parameters = list(method_description.get('parameters', []))
        else:
            ordered_parameters = method_description['parameterOrder'][:]
            for k in method_description['parameters']:
                if k not in ordered_parameters:
                    ordered_parameters.append(k)
        for parameter_name in ordered_parameters:
            field_name = self.__names.CleanName(parameter_name)
            field = dict(method_description['parameters'][parameter_name])
            if 'type' not in field:
                raise ValueError('No type found in parameter %s' % field)
            schema['properties'][field_name] = field
        if body_type is not None:
            body_field_name = self.__GetRequestField(
                method_description, body_type)
            if body_field_name in schema['properties']:
                raise ValueError('Failed to normalize request resource name')
            if 'description' not in body_type:
                body_type['description'] = (
                    'A %s resource to be passed as the request body.' % (
                        self.__GetRequestType(body_type),))
            schema['properties'][body_field_name] = body_type
        self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
        return schema['id']

    def __CreateVoidResponseType(self, method_description):
        """Create an empty response type.

        Returns:
          The id ("<MethodClassName>Response") of the registered message.
        """
        schema = {}
        method_name = self.__names.ClassName(
            method_description['id'], separator='.')
        schema['id'] = self.__names.ClassName('%sResponse' % method_name)
        schema['type'] = 'object'
        schema['description'] = 'An empty %s response.' % method_name
        self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
        return schema['id']

    def __NeedRequestType(self, method_description, request_type):
        """Determine if this method needs a new request type created.

        The body type can be reused directly only when every method
        parameter is a path parameter that already exists as a field of
        that message, and the method is not listed as unelidable.
        """
        if not request_type:
            return True
        method_id = method_description.get('id', '')
        if method_id in self.__unelidable_request_methods:
            return True
        message = self.__message_registry.LookupDescriptorOrDie(request_type)
        if message is None:
            return True
        field_names = [x.name for x in message.fields]
        parameters = method_description.get('parameters', {})
        for param_name, param_info in parameters.items():
            if (param_info.get('location') != 'path' or
                    self.__names.CleanName(param_name) not in field_names):
                break
        else:
            # Every parameter is already covered by the body message.
            return False
        return True

    def __MaxSizeToInt(self, max_size):
        """Convert max_size to an int.

        Accepts a decimal number with an optional binary-unit suffix, e.g.
        "1024" -> 1024 and "10MB" -> 10 * 2**20.

        Raises:
          ValueError: if max_size cannot be parsed or the unit is unknown.
        """
        size_groups = re.match(r'(?P<size>\d+)(?P<unit>.B)?$', max_size)
        if size_groups is None:
            raise ValueError('Could not parse maxSize')
        size, unit = size_groups.group('size', 'unit')
        shift = 0
        if unit is not None:
            unit_dict = {'KB': 10, 'MB': 20, 'GB': 30, 'TB': 40}
            shift = unit_dict.get(unit.upper())
            if shift is None:
                raise ValueError('Unknown unit %s' % unit)
        return int(size) * (1 << shift)

    def __ComputeUploadConfig(self, media_upload_config, method_id):
        """Fill out the upload config for this method.

        Args:
          media_upload_config: The method's "mediaUpload" dict (may be empty).
          method_id: Method id, used only in log messages.

        Returns:
          A populated base_api.ApiUploadInfo.
        """
        config = base_api.ApiUploadInfo()
        if 'maxSize' in media_upload_config:
            config.max_size = self.__MaxSizeToInt(
                media_upload_config['maxSize'])
        if 'accept' not in media_upload_config:
            logging.warning(
                'No accept types found for upload configuration in '
                'method %s, using */*', method_id)
        # The default must be a list: iterating the bare string '*/*' would
        # add the single characters '*', '/', '*' as accept patterns.
        config.accept.extend([
            str(a) for a in media_upload_config.get('accept', ['*/*'])])

        for accept_pattern in config.accept:
            if not _MIME_PATTERN_RE.match(accept_pattern):
                logging.warning('Unexpected MIME type: %s', accept_pattern)
        protocols = media_upload_config.get('protocols', {})
        for protocol in ('simple', 'resumable'):
            media = protocols.get(protocol, {})
            for attr in ('multipart', 'path'):
                if attr in media:
                    # e.g. simple_path, resumable_multipart
                    setattr(config, '%s_%s' % (protocol, attr), media[attr])
        return config

    def __ComputeMethodInfo(self, method_description, request, response,
                            request_field):
        """Compute the base_api.ApiMethodInfo for this method.

        Also adds the method's scopes to the registry-wide scope set.

        Raises:
          ValueError: if a parameter has a location other than query/path.
        """
        relative_path = self.__names.NormalizeRelativePath(
            ''.join((self.__client_info.base_path,
                     method_description['path'])))
        method_id = method_description['id']
        ordered_params = []
        for param_name in method_description.get('parameterOrder', []):
            param_info = method_description['parameters'][param_name]
            if param_info.get('required', False):
                ordered_params.append(param_name)
        method_info = base_api.ApiMethodInfo(
            relative_path=relative_path,
            method_id=method_id,
            http_method=method_description['httpMethod'],
            description=util.CleanDescription(
                method_description.get('description', '')),
            query_params=[],
            path_params=[],
            ordered_params=ordered_params,
            request_type_name=self.__names.ClassName(request),
            response_type_name=self.__names.ClassName(response),
            request_field=request_field,
        )
        flat_path = method_description.get('flatPath', None)
        if flat_path is not None:
            flat_path = self.__names.NormalizeRelativePath(
                self.__client_info.base_path + flat_path)
            # Only record the flat path when it differs from the main path.
            if flat_path != relative_path:
                method_info.flat_path = flat_path
        if method_description.get('supportsMediaUpload', False):
            # Tolerate a missing/null mediaUpload section instead of failing
            # with a TypeError inside __ComputeUploadConfig.
            method_info.upload_config = self.__ComputeUploadConfig(
                method_description.get('mediaUpload') or {}, method_id)
        method_info.supports_download = method_description.get(
            'supportsMediaDownload', False)
        self.__all_scopes.update(method_description.get('scopes', ()))
        for param, desc in method_description.get('parameters', {}).items():
            param = self.__names.CleanName(param)
            location = desc['location']
            if location == 'query':
                method_info.query_params.append(param)
            elif location == 'path':
                method_info.path_params.append(param)
            else:
                raise ValueError(
                    'Unknown parameter location %s for parameter %s' % (
                        location, param))
        method_info.path_params.sort()
        method_info.query_params.sort()
        return method_info

    def __BodyFieldName(self, body_type):
        """Return the field name for a body type ('' if there is no body)."""
        if body_type is None:
            return ''
        return self.__names.FieldName(body_type['$ref'])

    def __GetRequestType(self, body_type):
        """Return the class name referenced by body_type's "$ref"."""
        return self.__names.ClassName(body_type.get('$ref'))

    def __GetRequestField(self, method_description, body_type):
        """Determine the request field for this method.

        Starts from the body type's field name and appends "_resource", then
        "_body" as often as needed, until it no longer collides with a
        method parameter name.
        """
        body_field_name = self.__BodyFieldName(body_type)
        if body_field_name in method_description.get('parameters', {}):
            body_field_name = self.__names.FieldName(
                '%s_resource' % body_field_name)
        # It's exceedingly unlikely that we'd get two name collisions, which
        # means it's bound to happen at some point.
        while body_field_name in method_description.get('parameters', {}):
            body_field_name = self.__names.FieldName(
                '%s_body' % body_field_name)
        return body_field_name

    def AddServiceFromResource(self, service_name, methods):
        """Add a new service named service_name with the given methods.

        Nested "resources" are added recursively as services named
        "<service_name>_<subresource>"; they are registered before
        service_name itself.

        Args:
          service_name: Name of the service to register.
          methods: Discovery-document resource dict; its "methods" and
              "resources" keys are read (both optional).
        """
        method_descriptions = methods.get('methods', {})
        method_info_map = collections.OrderedDict()
        items = sorted(method_descriptions.items())
        for method_name, method_description in items:
            method_name = self.__names.MethodName(method_name)

            # NOTE: According to the discovery document, if the request or
            # response is present, it will simply contain a `$ref`.
            body_type = method_description.get('request')
            if body_type is None:
                request_type = None
            else:
                request_type = self.__GetRequestType(body_type)
            if self.__NeedRequestType(method_description, request_type):
                request = self.__CreateRequestType(
                    method_description, body_type=body_type)
                request_field = self.__GetRequestField(
                    method_description, body_type)
            else:
                # The body message itself serves as the request.
                request = request_type
                request_field = base_api.REQUEST_IS_BODY

            if 'response' in method_description:
                response = method_description['response']['$ref']
            else:
                response = self.__CreateVoidResponseType(method_description)

            method_info_map[method_name] = self.__ComputeMethodInfo(
                method_description, request, response, request_field)
            self.__command_registry.AddCommandForMethod(
                service_name, method_name, method_info_map[method_name],
                request, response)

        nested_services = methods.get('resources', {})
        services = sorted(nested_services.items())
        for subservice_name, submethods in services:
            new_service_name = '%s_%s' % (service_name, subservice_name)
            self.AddServiceFromResource(new_service_name, submethods)

        self.__RegisterService(service_name, method_info_map)