#!/usr/bin/env python
"""Service registry for apitools."""

import collections
import logging
import re
import textwrap

from apitools.base.py import base_api
from apitools.gen import util

# We're a code generator. I don't care.
# pylint:disable=too-many-statements

# Loose sanity check applied to upload 'accept' patterns such as 'image/*'.
_MIME_PATTERN_RE = re.compile(r'(?i)[a-z0-9_*-]+/[a-z0-9_*-]+')

# NOTE(review): several whitespace-only pieces of the string literals handed
# to the printer below (e.g. indent=' ' and the leading spaces of generated
# continuation/docstring lines) look like they may have lost whitespace when
# this file was pasted. They are kept exactly as found; confirm the intended
# layout of the generated code.


class ServiceRegistry(object):

    """Registry for service types.

    Services are added from discovery-document resources with
    AddServiceFromResource and emitted either as a generated Python client
    (WriteFile) or as proto service declarations (WriteProtoFile).
    """

    def __init__(self, client_info, message_registry, command_registry,
                 base_url, base_path, names,
                 root_package_dir, base_files_package,
                 unelidable_request_methods):
        self.__client_info = client_info
        self.__package = client_info.package
        self.__names = names
        # service name -> OrderedDict(method name -> base_api.ApiMethodInfo)
        self.__service_method_info_map = collections.OrderedDict()
        self.__message_registry = message_registry
        self.__command_registry = command_registry
        self.__base_url = base_url
        self.__base_path = base_path
        self.__root_package_dir = root_package_dir
        self.__base_files_package = base_files_package
        self.__unelidable_request_methods = unelidable_request_methods
        # Seeded from client_info; grows as methods declare 'scopes'.
        self.__all_scopes = set(self.__client_info.scopes)

    def Validate(self):
        """Validate the underlying message registry."""
        self.__message_registry.Validate()

    @property
    def scopes(self):
        """Sorted list of every scope seen (client_info plus methods)."""
        return sorted(self.__all_scopes)

    def __GetServiceClassName(self, service_name):
        """Return the generated class name for service_name."""
        return self.__names.ClassName(
            '%sService' % self.__names.ClassName(service_name))

    def __PrintDocstring(self, printer, method_info, method_name, name):
        """Print a docstring for a service method.

        The first line of the method description is forced to end with a
        period; methods without a description get a generic summary line.
        """
        if method_info.description:
            # Work on the cleaned description (previously the cleaned text
            # was computed and then discarded in favour of the raw one).
            description = util.CleanDescription(method_info.description)
            first_line, newline, remaining = description.partition('\n')
            if not first_line.endswith('.'):
                first_line = '%s.' % first_line
            description = '%s%s%s' % (first_line, newline, remaining)
        else:
            description = '%s method for the %s service.' % (method_name, name)
        # Only the free-form description line is printed in comment context;
        # the Args/Returns lines below pass interpolation args to printer.
        with printer.CommentContext():
            printer('"""%s' % description)
        printer()
        printer('Args:')
        printer(' request: (%s) input message', method_info.request_type_name)
        printer(' global_params: (StandardQueryParameters, default: None) '
                'global arguments')
        if method_info.upload_config:
            printer(' upload: (Upload, default: None) If present, upload')
            printer(' this stream with the request.')
        if method_info.supports_download:
            printer(
                ' download: (Download, default: None) If present, download')
            printer(' data from the request via this stream.')
        printer('Returns:')
        printer(' (%s) The response message.', method_info.response_type_name)
        printer('"""')

    def __WriteSingleService(
            self, printer, name, method_info_map, client_class_name):
        """Write the generated service class for one service.

        Args:
          printer: the output printer.
          name: the service name.
          method_info_map: method name -> base_api.ApiMethodInfo.
          client_class_name: name of the enclosing generated client class.
        """
        printer()
        class_name = self.__GetServiceClassName(name)
        printer('class %s(base_api.BaseApiService):', class_name)
        with printer.Indent():
            printer('"""Service class for the %s resource."""', name)
            printer()
            printer('_NAME = %s', repr(name))

            # Print the configs for the methods first.
            printer()
            printer('def __init__(self, client):')
            with printer.Indent():
                printer('super(%s.%s, self).__init__(client)',
                        client_class_name, class_name)
                printer('self._method_configs = {')
                with printer.Indent(indent=' '):
                    for method_name, method_info in method_info_map.items():
                        printer("'%s': base_api.ApiMethodInfo(", method_name)
                        with printer.Indent(indent=' '):
                            attrs = sorted(
                                x.name for x in method_info.all_fields())
                            for attr in attrs:
                                # upload_config is emitted separately below;
                                # the description goes into the docstring.
                                if attr in ('upload_config', 'description'):
                                    continue
                                printer(
                                    '%s=%r,', attr, getattr(method_info, attr))
                        printer('),')
                    printer('}')
                printer()
                printer('self._upload_configs = {')
                with printer.Indent(indent=' '):
                    for method_name, method_info in method_info_map.items():
                        upload_config = method_info.upload_config
                        if upload_config is not None:
                            printer(
                                "'%s': base_api.ApiUploadInfo(", method_name)
                            with printer.Indent(indent=' '):
                                attrs = sorted(
                                    x.name for x in upload_config.all_fields())
                                for attr in attrs:
                                    printer('%s=%r,',
                                            attr, getattr(upload_config, attr))
                            printer('),')
                    printer('}')

            # Now write each method in turn.
            for method_name, method_info in method_info_map.items():
                printer()
                params = ['self', 'request', 'global_params=None']
                if method_info.upload_config:
                    params.append('upload=None')
                if method_info.supports_download:
                    params.append('download=None')
                printer('def %s(%s):', method_name, ', '.join(params))
                with printer.Indent():
                    self.__PrintDocstring(
                        printer, method_info, method_name, name)
                    printer("config = self.GetMethodConfig('%s')", method_name)
                    upload_config = method_info.upload_config
                    if upload_config is not None:
                        printer("upload_config = self.GetUploadConfig('%s')",
                                method_name)
                    arg_lines = [
                        'config, request, global_params=global_params']
                    if method_info.upload_config:
                        arg_lines.append(
                            'upload=upload, upload_config=upload_config')
                    if method_info.supports_download:
                        arg_lines.append('download=download')
                    printer('return self._RunMethod(')
                    with printer.Indent(indent=' '):
                        for line in arg_lines[:-1]:
                            printer('%s,', line)
                        printer('%s)', arg_lines[-1])

    def __WriteProtoServiceDeclaration(self, printer, name, method_info_map):
        """Write a single service declaration to a proto file."""
        printer()
        printer('service %s {', self.__GetServiceClassName(name))
        with printer.Indent():
            for method_name, method_info in method_info_map.items():
                # Leave room for the '// ' comment prefix when wrapping.
                for line in textwrap.wrap(method_info.description,
                                          printer.CalculateWidth() - 3):
                    printer('// %s', line)
                printer('rpc %s (%s) returns (%s);',
                        method_name,
                        method_info.request_type_name,
                        method_info.response_type_name)
        printer('}')

    def WriteProtoFile(self, printer):
        """Write the services in this registry to out as proto."""
        self.Validate()
        client_info = self.__client_info
        printer('// Generated services for %s version %s.',
                client_info.package, client_info.version)
        printer()
        printer('syntax = "proto2";')
        printer('package %s;', self.__package)
        printer('import "%s";', client_info.messages_proto_file_name)
        printer()
        for name, method_info_map in self.__service_method_info_map.items():
            self.__WriteProtoServiceDeclaration(printer, name, method_info_map)

    def WriteFile(self, printer):
        """Write the services in this registry to out as a Python client."""
        self.Validate()
        client_info = self.__client_info
        printer('"""Generated client library for %s version %s."""',
                client_info.package, client_info.version)
        printer('# NOTE: This file is autogenerated and should not be edited '
                'by hand.')
        printer('from %s import base_api', self.__base_files_package)
        printer('import %s as messages', client_info.messages_rule_name)
        printer()
        printer()
        printer('class %s(base_api.BaseApiClient):',
                client_info.client_class_name)
        with printer.Indent():
            printer(
                '"""Generated client library for service %s version %s."""',
                client_info.package, client_info.version)
            printer()
            printer('MESSAGES_MODULE = messages')
            printer()
            client_info_items = client_info._asdict(
            ).items()  # pylint:disable=protected-access
            for attr, val in client_info_items:
                if attr == 'scopes' and not val:
                    val = ['https://www.googleapis.com/auth/userinfo.email']
                # Let printer interpolate, as every other call here does,
                # rather than pre-formatting the line; this avoids a second
                # %-pass over values that may themselves contain '%'.
                printer('_%s = %r', attr.upper(), val)
            printer()
            printer("def __init__(self, url='', credentials=None,")
            with printer.Indent(indent=' '):
                printer('get_credentials=True, http=None, model=None,')
                printer('log_request=False, log_response=False,')
                printer('credentials_args=None, default_global_params=None,')
                printer('additional_http_headers=None):')
            with printer.Indent():
                printer('"""Create a new %s handle."""', client_info.package)
                printer('url = url or %r', self.__base_url)
                printer(
                    'super(%s, self).__init__(', client_info.client_class_name)
                printer(' url, credentials=credentials,')
                printer(' get_credentials=get_credentials, http=http, '
                        'model=model,')
                printer(' log_request=log_request, '
                        'log_response=log_response,')
                printer(' credentials_args=credentials_args,')
                printer(' default_global_params=default_global_params,')
                printer(' additional_http_headers=additional_http_headers)')
                for name in self.__service_method_info_map:
                    printer('self.%s = self.%s(self)',
                            name, self.__GetServiceClassName(name))
            for name, method_info_map in (
                    self.__service_method_info_map.items()):
                self.__WriteSingleService(
                    printer, name, method_info_map,
                    client_info.client_class_name)

    def __RegisterService(self, service_name, method_info_map):
        """Record method_info_map for service_name.

        Raises:
          ValueError: if service_name was already registered.
        """
        if service_name in self.__service_method_info_map:
            raise ValueError(
                'Attempt to re-register descriptor %s' % service_name)
        self.__service_method_info_map[service_name] = method_info_map

    def __CreateRequestType(self, method_description, body_type=None):
        """Create a request type for this method.

        The request message gets one field per method parameter, with the
        names listed in 'parameterOrder' first and any remaining parameters
        after them. When body_type is given, a field holding the request
        body is also added.

        Returns:
          The id of the schema registered with the message registry.
        Raises:
          ValueError: if a parameter has no 'type', or the body field name
            collides with a parameter field.
        """
        schema = {}
        schema['id'] = self.__names.ClassName('%sRequest' % (
            self.__names.ClassName(method_description['id'], separator='.'),))
        schema['type'] = 'object'
        schema['properties'] = collections.OrderedDict()
        # A method with 'parameterOrder' but no 'parameters' used to raise
        # KeyError here; treat missing parameters as empty instead.
        parameters = method_description.get('parameters', {})
        ordered_parameters = list(method_description.get('parameterOrder', []))
        for k in parameters:
            if k not in ordered_parameters:
                ordered_parameters.append(k)
        for parameter_name in ordered_parameters:
            field_name = self.__names.CleanName(parameter_name)
            field = dict(parameters[parameter_name])
            if 'type' not in field:
                raise ValueError('No type found in parameter %s' % field)
            schema['properties'][field_name] = field
        if body_type is not None:
            body_field_name = self.__GetRequestField(
                method_description, body_type)
            if body_field_name in schema['properties']:
                raise ValueError('Failed to normalize request resource name')
            if 'description' not in body_type:
                # Copy first so the caller's discovery document (the
                # method's 'request' entry) is not mutated.
                body_type = dict(body_type)
                body_type['description'] = (
                    'A %s resource to be passed as the request body.' % (
                        self.__GetRequestType(body_type),))
            schema['properties'][body_field_name] = body_type
        self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
        return schema['id']

    def __CreateVoidResponseType(self, method_description):
        """Create and register an empty response type; return its id."""
        schema = {}
        method_name = self.__names.ClassName(
            method_description['id'], separator='.')
        schema['id'] = self.__names.ClassName('%sResponse' % method_name)
        schema['type'] = 'object'
        schema['description'] = 'An empty %s response.' % method_name
        self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
        return schema['id']

    def __NeedRequestType(self, method_description, request_type):
        """Determine if this method needs a new request type created.

        The existing body message can be reused only when the method is not
        listed as unelidable and every parameter is a path parameter that
        already exists as a field of that message.
        """
        if not request_type:
            return True
        method_id = method_description.get('id', '')
        if method_id in self.__unelidable_request_methods:
            return True
        message = self.__message_registry.LookupDescriptorOrDie(request_type)
        if message is None:
            return True
        field_names = set(x.name for x in message.fields)
        parameters = method_description.get('parameters', {})
        return not all(
            param_info.get('location') == 'path' and
            self.__names.CleanName(param_name) in field_names
            for param_name, param_info in parameters.items())

    def __MaxSizeToInt(self, max_size):
        """Convert a size string such as '10MB' or '512' to an int.

        Units KB/MB/GB/TB are powers of 1024.

        Raises:
          ValueError: if the string cannot be parsed or the unit is unknown.
        """
        size_groups = re.match(r'(?P<size>\d+)(?P<unit>.B)?$', max_size)
        if size_groups is None:
            raise ValueError('Could not parse maxSize')
        size, unit = size_groups.group('size', 'unit')
        shift = 0
        if unit is not None:
            unit_dict = {'KB': 10, 'MB': 20, 'GB': 30, 'TB': 40}
            shift = unit_dict.get(unit.upper())
            if shift is None:
                raise ValueError('Unknown unit %s' % unit)
        return int(size) * (1 << shift)

    def __ComputeUploadConfig(self, media_upload_config, method_id):
        """Fill out the upload config for this method.

        Args:
          media_upload_config: the method's 'mediaUpload' entry; None (entry
            absent) is treated as an empty dict.
          method_id: method id, used only in log messages.
        Returns:
          A populated base_api.ApiUploadInfo.
        """
        media_upload_config = media_upload_config or {}
        config = base_api.ApiUploadInfo()
        if 'maxSize' in media_upload_config:
            config.max_size = self.__MaxSizeToInt(
                media_upload_config['maxSize'])
        if 'accept' not in media_upload_config:
            logging.warning(
                'No accept types found for upload configuration in '
                'method %s, using */*', method_id)
        # The default must be a list: iterating the bare string '*/*' would
        # add the three patterns '*', '/', '*'.
        config.accept.extend([
            str(a) for a in media_upload_config.get('accept', ['*/*'])])

        for accept_pattern in config.accept:
            if not _MIME_PATTERN_RE.match(accept_pattern):
                logging.warning('Unexpected MIME type: %s', accept_pattern)
        protocols = media_upload_config.get('protocols', {})
        for protocol in ('simple', 'resumable'):
            media = protocols.get(protocol, {})
            for attr in ('multipart', 'path'):
                if attr in media:
                    # e.g. simple_path, resumable_multipart
                    setattr(config, '%s_%s' % (protocol, attr), media[attr])
        return config

    def __ComputeMethodInfo(self, method_description, request, response,
                            request_field):
        """Compute the base_api.ApiMethodInfo for this method.

        Also records any scopes the method declares.

        Raises:
          ValueError: if a parameter has a location other than query/path.
        """
        relative_path = self.__names.NormalizeRelativePath(
            ''.join((self.__base_path, method_description['path'])))
        method_id = method_description['id']
        ordered_params = []
        for param_name in method_description.get('parameterOrder', []):
            param_info = method_description['parameters'][param_name]
            if param_info.get('required', False):
                ordered_params.append(param_name)
        method_info = base_api.ApiMethodInfo(
            relative_path=relative_path,
            method_id=method_id,
            http_method=method_description['httpMethod'],
            description=util.CleanDescription(
                method_description.get('description', '')),
            query_params=[],
            path_params=[],
            ordered_params=ordered_params,
            request_type_name=self.__names.ClassName(request),
            response_type_name=self.__names.ClassName(response),
            request_field=request_field,
        )
        if method_description.get('supportsMediaUpload', False):
            method_info.upload_config = self.__ComputeUploadConfig(
                method_description.get('mediaUpload'), method_id)
        method_info.supports_download = method_description.get(
            'supportsMediaDownload', False)
        self.__all_scopes.update(method_description.get('scopes', ()))
        for param, desc in method_description.get('parameters', {}).items():
            param = self.__names.CleanName(param)
            location = desc['location']
            if location == 'query':
                method_info.query_params.append(param)
            elif location == 'path':
                method_info.path_params.append(param)
            else:
                raise ValueError(
                    'Unknown parameter location %s for parameter %s' % (
                        location, param))
        method_info.path_params.sort()
        method_info.query_params.sort()
        return method_info

    def __BodyFieldName(self, body_type):
        """Return the field name for a body type ('' when there is none)."""
        if body_type is None:
            return ''
        return self.__names.FieldName(body_type['$ref'])

    def __GetRequestType(self, body_type):
        """Return the class name referenced by body_type['$ref']."""
        return self.__names.ClassName(body_type.get('$ref'))

    def __GetRequestField(self, method_description, body_type):
        """Determine the request field for this method.

        Starts from the body type's field name and appends '_resource', then
        '_body' as often as needed, until it no longer collides with a
        method parameter name.
        """
        parameters = method_description.get('parameters', {})
        body_field_name = self.__BodyFieldName(body_type)
        if body_field_name in parameters:
            body_field_name = self.__names.FieldName(
                '%s_resource' % body_field_name)
        # It's exceedingly unlikely that we'd get two name collisions, which
        # means it's bound to happen at some point.
        while body_field_name in parameters:
            body_field_name = self.__names.FieldName(
                '%s_body' % body_field_name)
        return body_field_name

    def AddServiceFromResource(self, service_name, methods):
        """Add a new service named service_name with the given methods.

        Nested 'resources' are added recursively as services named
        '<service_name>_<subresource>'; service_name itself is registered
        after its nested services.
        """
        method_descriptions = methods.get('methods', {})
        method_info_map = collections.OrderedDict()
        items = sorted(method_descriptions.items())
        for method_name, method_description in items:
            method_name = self.__names.MethodName(method_name)

            # NOTE: According to the discovery document, if the request or
            # response is present, it will simply contain a `$ref`.
            body_type = method_description.get('request')
            if body_type is None:
                request_type = None
            else:
                request_type = self.__GetRequestType(body_type)
            if self.__NeedRequestType(method_description, request_type):
                request = self.__CreateRequestType(
                    method_description, body_type=body_type)
                request_field = self.__GetRequestField(
                    method_description, body_type)
            else:
                request = request_type
                request_field = base_api.REQUEST_IS_BODY

            if 'response' in method_description:
                response = method_description['response']['$ref']
            else:
                response = self.__CreateVoidResponseType(method_description)

            method_info_map[method_name] = self.__ComputeMethodInfo(
                method_description, request, response, request_field)
            self.__command_registry.AddCommandForMethod(
                service_name, method_name, method_info_map[method_name],
                request, response)

        nested_services = methods.get('resources', {})
        services = sorted(nested_services.items())
        for subservice_name, submethods in services:
            new_service_name = '%s_%s' % (service_name, subservice_name)
            self.AddServiceFromResource(new_service_name, submethods)

        self.__RegisterService(service_name, method_info_map)