1#!/usr/bin/env python
2"""Service registry for apitools."""
3
4import collections
5import logging
6import re
7import textwrap
8
9from apitools.base.py import base_api
10from apitools.gen import util
11
# This module is a code generator, so long, statement-heavy methods are expected.
13# pylint:disable=too-many-statements
14
# Loose sanity check applied to each entry of an upload config's accept list
# (e.g. 'image/*' or '*/*'). Case-insensitive; callers use .match(), so only a
# leading type/subtype prefix is required. Mismatches are only logged.
_MIME_PATTERN_RE = re.compile(r'(?i)[a-z0-9_*-]+/[a-z0-9_*-]+')
16
17
class ServiceRegistry(object):

    """Registry for service types.

    Services are added from discovery-document resources with
    AddServiceFromResource; the request/response message schemas they need
    are registered with the message registry as a side effect. The collected
    services can then be emitted as Python client source (WriteFile) or as
    proto service declarations (WriteProtoFile).
    """

    def __init__(self, client_info, message_registry, command_registry,
                 base_url, base_path, names,
                 root_package_dir, base_files_package,
                 unelidable_request_methods):
        """Create an empty registry.

        Args:
          client_info: Client metadata; this class reads its package, scopes,
              version, client_class_name, messages_rule_name and
              messages_proto_file_name, and iterates its _asdict() items.
          message_registry: Registry that receives synthesized request and
              response message schemas.
          command_registry: Registry notified of every method that is added.
          base_url: Default URL written into the generated client.
          base_path: Prefix joined onto each method's 'path'.
          names: Name-normalization helper (ClassName, FieldName, CleanName,
              MethodName, NormalizeRelativePath).
          root_package_dir: Stored; not otherwise used by this class.
          base_files_package: Package the generated code imports base_api from.
          unelidable_request_methods: Method ids that always get a dedicated
              request message type.
        """
        self.__client_info = client_info
        self.__package = client_info.package
        self.__names = names
        # service name -> OrderedDict(method name -> base_api.ApiMethodInfo)
        self.__service_method_info_map = collections.OrderedDict()
        self.__message_registry = message_registry
        self.__command_registry = command_registry
        self.__base_url = base_url
        self.__base_path = base_path
        self.__root_package_dir = root_package_dir
        self.__base_files_package = base_files_package
        self.__unelidable_request_methods = unelidable_request_methods
        # Seeded with the client's scopes; per-method scopes are merged in by
        # __ComputeMethodInfo.
        self.__all_scopes = set(self.__client_info.scopes)

    def Validate(self):
        """Validate the underlying message registry."""
        self.__message_registry.Validate()

    @property
    def scopes(self):
        """All scopes seen so far, as a sorted list."""
        return sorted(self.__all_scopes)

    def __GetServiceClassName(self, service_name):
        """Return the generated class name, e.g. '<Name>Service'."""
        return self.__names.ClassName(
            '%sService' % self.__names.ClassName(service_name))

    def __PrintDocstring(self, printer, method_info, method_name, name):
        """Print a docstring for a service method."""
        if method_info.description:
            # Build the summary from the cleaned description, making sure its
            # first line ends with a period.
            description = util.CleanDescription(method_info.description)
            first_line, newline, remaining = description.partition('\n')
            if not first_line.endswith('.'):
                first_line = '%s.' % first_line
            description = '%s%s%s' % (first_line, newline, remaining)
        else:
            description = '%s method for the %s service.' % (method_name, name)
        with printer.CommentContext():
            printer('"""%s' % description)
        printer()
        printer('Args:')
        printer('  request: (%s) input message', method_info.request_type_name)
        printer('  global_params: (StandardQueryParameters, default: None) '
                'global arguments')
        if method_info.upload_config:
            printer('  upload: (Upload, default: None) If present, upload')
            printer('      this stream with the request.')
        if method_info.supports_download:
            printer(
                '  download: (Download, default: None) If present, download')
            printer('      data from the request via this stream.')
        printer('Returns:')
        printer('  (%s) The response message.', method_info.response_type_name)
        printer('"""')

    def __WriteSingleService(
            self, printer, name, method_info_map, client_class_name):
        """Write one generated service class (nested in the client class).

        Emits the __init__ with the _method_configs and _upload_configs
        tables, followed by one wrapper method per entry in method_info_map.
        """
        printer()
        class_name = self.__GetServiceClassName(name)
        printer('class %s(base_api.BaseApiService):', class_name)
        with printer.Indent():
            printer('"""Service class for the %s resource."""', name)
            printer()
            printer('_NAME = %s', repr(name))

            # Print the configs for the methods first.
            printer()
            printer('def __init__(self, client):')
            with printer.Indent():
                printer('super(%s.%s, self).__init__(client)',
                        client_class_name, class_name)
                printer('self._method_configs = {')
                with printer.Indent(indent='    '):
                    for method_name, method_info in method_info_map.items():
                        printer("'%s': base_api.ApiMethodInfo(", method_name)
                        with printer.Indent(indent='    '):
                            attrs = sorted(
                                x.name for x in method_info.all_fields())
                            for attr in attrs:
                                # Upload configs get their own table below;
                                # descriptions only go into docstrings.
                                if attr in ('upload_config', 'description'):
                                    continue
                                printer(
                                    '%s=%r,', attr, getattr(method_info, attr))
                        printer('),')
                    printer('}')
                printer()
                printer('self._upload_configs = {')
                with printer.Indent(indent='    '):
                    for method_name, method_info in method_info_map.items():
                        upload_config = method_info.upload_config
                        if upload_config is not None:
                            printer(
                                "'%s': base_api.ApiUploadInfo(", method_name)
                            with printer.Indent(indent='    '):
                                attrs = sorted(
                                    x.name for x in upload_config.all_fields())
                                for attr in attrs:
                                    printer('%s=%r,',
                                            attr, getattr(upload_config, attr))
                            printer('),')
                    printer('}')

            # Now write each method in turn.
            for method_name, method_info in method_info_map.items():
                printer()
                params = ['self', 'request', 'global_params=None']
                if method_info.upload_config:
                    params.append('upload=None')
                if method_info.supports_download:
                    params.append('download=None')
                printer('def %s(%s):', method_name, ', '.join(params))
                with printer.Indent():
                    self.__PrintDocstring(
                        printer, method_info, method_name, name)
                    printer("config = self.GetMethodConfig('%s')", method_name)
                    upload_config = method_info.upload_config
                    if upload_config is not None:
                        printer("upload_config = self.GetUploadConfig('%s')",
                                method_name)
                    arg_lines = [
                        'config, request, global_params=global_params']
                    if method_info.upload_config:
                        arg_lines.append(
                            'upload=upload, upload_config=upload_config')
                    if method_info.supports_download:
                        arg_lines.append('download=download')
                    printer('return self._RunMethod(')
                    with printer.Indent(indent='    '):
                        for line in arg_lines[:-1]:
                            printer('%s,', line)
                        printer('%s)', arg_lines[-1])

    def __WriteProtoServiceDeclaration(self, printer, name, method_info_map):
        """Write a single service declaration to a proto file."""
        printer()
        printer('service %s {', self.__GetServiceClassName(name))
        with printer.Indent():
            for method_name, method_info in method_info_map.items():
                # Leave room for the '// ' comment prefix when wrapping.
                for line in textwrap.wrap(method_info.description,
                                          printer.CalculateWidth() - 3):
                    printer('// %s', line)
                printer('rpc %s (%s) returns (%s);',
                        method_name,
                        method_info.request_type_name,
                        method_info.response_type_name)
        printer('}')

    def WriteProtoFile(self, printer):
        """Write the services in this registry to out as proto."""
        self.Validate()
        client_info = self.__client_info
        printer('// Generated services for %s version %s.',
                client_info.package, client_info.version)
        printer()
        printer('syntax = "proto2";')
        printer('package %s;', self.__package)
        printer('import "%s";', client_info.messages_proto_file_name)
        printer()
        for name, method_info_map in self.__service_method_info_map.items():
            self.__WriteProtoServiceDeclaration(printer, name, method_info_map)

    def WriteFile(self, printer):
        """Write the services in this registry to out as Python source."""
        self.Validate()
        client_info = self.__client_info
        printer('"""Generated client library for %s version %s."""',
                client_info.package, client_info.version)
        printer('# NOTE: This file is autogenerated and should not be edited '
                'by hand.')
        printer('from %s import base_api', self.__base_files_package)
        printer('import %s as messages', client_info.messages_rule_name)
        printer()
        printer()
        printer('class %s(base_api.BaseApiClient):',
                client_info.client_class_name)
        with printer.Indent():
            printer(
                '"""Generated client library for service %s version %s."""',
                client_info.package, client_info.version)
            printer()
            printer('MESSAGES_MODULE = messages')
            printer()
            # pylint:disable=protected-access
            client_info_items = client_info._asdict().items()
            # pylint:enable=protected-access
            for attr, val in client_info_items:
                if attr == 'scopes' and not val:
                    val = ['https://www.googleapis.com/auth/userinfo.email']
                # Let the printer do the formatting, as everywhere else in
                # this file, rather than handing it an already-formatted line.
                printer('_%s = %r', attr.upper(), val)
            printer()
            printer("def __init__(self, url='', credentials=None,")
            # Align continuation lines with the opening parenthesis.
            with printer.Indent(indent=' ' * len('def __init__(')):
                printer('get_credentials=True, http=None, model=None,')
                printer('log_request=False, log_response=False,')
                printer('credentials_args=None, default_global_params=None,')
                printer('additional_http_headers=None):')
            with printer.Indent():
                printer('"""Create a new %s handle."""', client_info.package)
                printer('url = url or %r', self.__base_url)
                printer(
                    'super(%s, self).__init__(', client_info.client_class_name)
                printer('    url, credentials=credentials,')
                printer('    get_credentials=get_credentials, http=http, '
                        'model=model,')
                printer('    log_request=log_request, '
                        'log_response=log_response,')
                printer('    credentials_args=credentials_args,')
                printer('    default_global_params=default_global_params,')
                printer('    additional_http_headers=additional_http_headers)')
                for name in self.__service_method_info_map:
                    printer('self.%s = self.%s(self)',
                            name, self.__GetServiceClassName(name))
            for name, method_info in self.__service_method_info_map.items():
                self.__WriteSingleService(
                    printer, name, method_info, client_info.client_class_name)

    def __RegisterService(self, service_name, method_info_map):
        """Record method_info_map under service_name.

        Raises:
          ValueError: if service_name is already registered.
        """
        if service_name in self.__service_method_info_map:
            raise ValueError(
                'Attempt to re-register descriptor %s' % service_name)
        self.__service_method_info_map[service_name] = method_info_map

    def __CreateRequestType(self, method_description, body_type=None):
        """Create a request type for this method.

        The message gets one field per method parameter (those listed in
        parameterOrder first, then the rest) plus, when body_type is given,
        a field holding the request body.

        Returns:
          The id of the new message type.

        Raises:
          ValueError: if a parameter has no 'type', or the body field name
              collides with a parameter field.
        """
        parameters = method_description.get('parameters', {})
        schema = {}
        schema['id'] = self.__names.ClassName('%sRequest' % (
            self.__names.ClassName(method_description['id'], separator='.'),))
        schema['type'] = 'object'
        schema['properties'] = collections.OrderedDict()
        if 'parameterOrder' not in method_description:
            ordered_parameters = list(parameters)
        else:
            ordered_parameters = method_description['parameterOrder'][:]
            for k in parameters:
                if k not in ordered_parameters:
                    ordered_parameters.append(k)
        for parameter_name in ordered_parameters:
            field_name = self.__names.CleanName(parameter_name)
            field = dict(parameters[parameter_name])
            if 'type' not in field:
                raise ValueError('No type found in parameter %s' % field)
            schema['properties'][field_name] = field
        if body_type is not None:
            body_field_name = self.__GetRequestField(
                method_description, body_type)
            if body_field_name in schema['properties']:
                raise ValueError('Failed to normalize request resource name')
            # Copy so that adding a description does not modify the caller's
            # discovery document.
            body_type = dict(body_type)
            if 'description' not in body_type:
                body_type['description'] = (
                    'A %s resource to be passed as the request body.' % (
                        self.__GetRequestType(body_type),))
            schema['properties'][body_field_name] = body_type
        self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
        return schema['id']

    def __CreateVoidResponseType(self, method_description):
        """Create an empty response type and return its id."""
        schema = {}
        method_name = self.__names.ClassName(
            method_description['id'], separator='.')
        schema['id'] = self.__names.ClassName('%sResponse' % method_name)
        schema['type'] = 'object'
        schema['description'] = 'An empty %s response.' % method_name
        self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
        return schema['id']

    def __NeedRequestType(self, method_description, request_type):
        """Determine if this method needs a new request type created.

        The existing body type can be reused as the request only when every
        method parameter is a path parameter that is already a field of that
        message, and the method is not listed as unelidable.
        """
        if not request_type:
            return True
        method_id = method_description.get('id', '')
        if method_id in self.__unelidable_request_methods:
            return True
        message = self.__message_registry.LookupDescriptorOrDie(request_type)
        if message is None:
            return True
        field_names = set(x.name for x in message.fields)
        parameters = method_description.get('parameters', {})
        return not all(
            param_info.get('location') == 'path' and
            self.__names.CleanName(param_name) in field_names
            for param_name, param_info in parameters.items())

    def __MaxSizeToInt(self, max_size):
        """Convert max_size to an int.

        max_size is a decimal count with an optional unit suffix (B, KB, MB,
        GB or TB, each a power-of-two multiple), e.g. '10MB' -> 10 * 2**20.

        Raises:
          ValueError: if max_size cannot be parsed or the unit is unknown.
        """
        # The unit is at most one letter followed by 'B'. A bare '.B' would
        # let the last size digit be swallowed into the unit ('100B' ->
        # size 10, unit '0B').
        size_groups = re.match(
            r'(?P<size>\d+)(?P<unit>[A-Za-z]?B)?$', max_size)
        if size_groups is None:
            raise ValueError('Could not parse maxSize')
        size, unit = size_groups.group('size', 'unit')
        shift = 0
        if unit is not None:
            unit_dict = {'B': 0, 'KB': 10, 'MB': 20, 'GB': 30, 'TB': 40}
            shift = unit_dict.get(unit.upper())
            if shift is None:
                raise ValueError('Unknown unit %s' % unit)
        return int(size) * (1 << shift)

    def __ComputeUploadConfig(self, media_upload_config, method_id):
        """Fill out the upload config for this method.

        Args:
          media_upload_config: The method's 'mediaUpload' section; may be None
              when the method description has none.
          method_id: Method id, used only in log messages.

        Returns:
          A populated base_api.ApiUploadInfo.
        """
        if media_upload_config is None:
            media_upload_config = {}
        config = base_api.ApiUploadInfo()
        if 'maxSize' in media_upload_config:
            config.max_size = self.__MaxSizeToInt(
                media_upload_config['maxSize'])
        if 'accept' not in media_upload_config:
            logging.warning(
                'No accept types found for upload configuration in '
                'method %s, using */*', method_id)
        # The default must be a list: iterating the bare string '*/*' would
        # produce the single characters '*', '/', '*'.
        config.accept.extend([
            str(a) for a in media_upload_config.get('accept', ['*/*'])])

        for accept_pattern in config.accept:
            if not _MIME_PATTERN_RE.match(accept_pattern):
                logging.warning('Unexpected MIME type: %s', accept_pattern)
        protocols = media_upload_config.get('protocols', {})
        for protocol in ('simple', 'resumable'):
            media = protocols.get(protocol, {})
            for attr in ('multipart', 'path'):
                if attr in media:
                    # e.g. simple_path, resumable_multipart
                    setattr(config, '%s_%s' % (protocol, attr), media[attr])
        return config

    def __ComputeMethodInfo(self, method_description, request, response,
                            request_field):
        """Compute the base_api.ApiMethodInfo for this method.

        Also merges the method's scopes into the registry-wide scope set.

        Raises:
          ValueError: if a parameter has a location other than 'query' or
              'path'.
        """
        relative_path = self.__names.NormalizeRelativePath(
            ''.join((self.__base_path, method_description['path'])))
        method_id = method_description['id']
        parameters = method_description.get('parameters', {})
        # Only required parameters from parameterOrder become positional.
        ordered_params = [
            param_name
            for param_name in method_description.get('parameterOrder', [])
            if parameters[param_name].get('required', False)]
        method_info = base_api.ApiMethodInfo(
            relative_path=relative_path,
            method_id=method_id,
            http_method=method_description['httpMethod'],
            description=util.CleanDescription(
                method_description.get('description', '')),
            query_params=[],
            path_params=[],
            ordered_params=ordered_params,
            request_type_name=self.__names.ClassName(request),
            response_type_name=self.__names.ClassName(response),
            request_field=request_field,
        )
        if method_description.get('supportsMediaUpload', False):
            method_info.upload_config = self.__ComputeUploadConfig(
                method_description.get('mediaUpload'), method_id)
        method_info.supports_download = method_description.get(
            'supportsMediaDownload', False)
        self.__all_scopes.update(method_description.get('scopes', ()))
        for param, desc in parameters.items():
            param = self.__names.CleanName(param)
            location = desc['location']
            if location == 'query':
                method_info.query_params.append(param)
            elif location == 'path':
                method_info.path_params.append(param)
            else:
                raise ValueError(
                    'Unknown parameter location %s for parameter %s' % (
                        location, param))
        method_info.path_params.sort()
        method_info.query_params.sort()
        return method_info

    def __BodyFieldName(self, body_type):
        """Field name derived from the body's '$ref', or '' if no body."""
        if body_type is None:
            return ''
        return self.__names.FieldName(body_type['$ref'])

    def __GetRequestType(self, body_type):
        """Class name of the body type referenced by '$ref'."""
        return self.__names.ClassName(body_type.get('$ref'))

    def __GetRequestField(self, method_description, body_type):
        """Determine the request field for this method.

        Starts from the body type's field name. On a clash with a method
        parameter it appends '_resource', then '_body' until the name is
        unique.
        """
        parameters = method_description.get('parameters', {})
        body_field_name = self.__BodyFieldName(body_type)
        if body_field_name in parameters:
            body_field_name = self.__names.FieldName(
                '%s_resource' % body_field_name)
        # It's exceedingly unlikely that we'd get two name collisions, which
        # means it's bound to happen at some point.
        while body_field_name in parameters:
            body_field_name = self.__names.FieldName(
                '%s_body' % body_field_name)
        return body_field_name

    def AddServiceFromResource(self, service_name, methods):
        """Add a new service named service_name with the given methods.

        Nested 'resources' are added recursively as services named
        '<service_name>_<subresource>', before this service is registered.

        Args:
          service_name: Name for the service.
          methods: Resource dict; its 'methods' and 'resources' entries are
              read.

        Raises:
          ValueError: if a service with the same name is already registered.
        """
        method_descriptions = methods.get('methods', {})
        method_info_map = collections.OrderedDict()
        items = sorted(method_descriptions.items())
        for method_name, method_description in items:
            method_name = self.__names.MethodName(method_name)

            # NOTE: According to the discovery document, if the request or
            # response is present, it will simply contain a `$ref`.
            body_type = method_description.get('request')
            if body_type is None:
                request_type = None
            else:
                request_type = self.__GetRequestType(body_type)
            if self.__NeedRequestType(method_description, request_type):
                request = self.__CreateRequestType(
                    method_description, body_type=body_type)
                request_field = self.__GetRequestField(
                    method_description, body_type)
            else:
                request = request_type
                request_field = base_api.REQUEST_IS_BODY

            if 'response' in method_description:
                response = method_description['response']['$ref']
            else:
                response = self.__CreateVoidResponseType(method_description)

            method_info_map[method_name] = self.__ComputeMethodInfo(
                method_description, request, response, request_field)
            self.__command_registry.AddCommandForMethod(
                service_name, method_name, method_info_map[method_name],
                request, response)

        nested_services = methods.get('resources', {})
        services = sorted(nested_services.items())
        for subservice_name, submethods in services:
            new_service_name = '%s_%s' % (service_name, subservice_name)
            self.AddServiceFromResource(new_service_name, submethods)

        self.__RegisterService(service_name, method_info_map)