1#!/usr/bin/env python
2#
3# Copyright 2015 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Assorted utilities shared between parts of apitools."""
18from __future__ import print_function
19from __future__ import unicode_literals
20
import collections
import contextlib
import gzip
import io
import json
import keyword
import logging
import os
import re
import tempfile

import six
from six.moves import urllib_parse
import six.moves.urllib.error as urllib_error
import six.moves.urllib.request as urllib_request
35
36
class Error(Exception):

    """Root of the exception hierarchy used by apitools generation."""


class CommunicationError(Error):

    """Raised when talking to the network fails (e.g. fetching a doc)."""
45
46
def _SortLengthFirstKey(a):
    """Sort key placing longer values first, ties broken by natural order."""
    length = len(a)
    return (-length, a)
49
50
class Names(object):

    """Utility class for cleaning and normalizing names in a fixed style.

    Constructor args:
      strip_prefixes: iterable of string prefixes; the longest one matching
        a name is removed by ClassName and FieldName.
      name_convention: field-name style used by FieldName, one of
        NAME_CONVENTIONS (not validated here); defaults to
        DEFAULT_NAME_CONVENTION when falsy.
      capitalize_enums: if true, NormalizeEnumName upper-cases enum names.
    """
    DEFAULT_NAME_CONVENTION = 'LOWER_CAMEL'
    NAME_CONVENTIONS = ['LOWER_CAMEL', 'LOWER_WITH_UNDER', 'NONE']

    def __init__(self, strip_prefixes,
                 name_convention=None,
                 capitalize_enums=False):
        # Longest prefixes first so the most specific prefix wins in
        # __StripName; equal-length prefixes are ordered alphabetically.
        self.__strip_prefixes = sorted(strip_prefixes, key=_SortLengthFirstKey)
        self.__name_convention = (
            name_convention or self.DEFAULT_NAME_CONVENTION)
        self.__capitalize_enums = capitalize_enums

    @staticmethod
    def __FromCamel(name, separator='_'):
        """Convert camelCase to lower-case words, e.g. 'fooBar' -> 'foo_bar'."""
        name = re.sub(r'([a-z0-9])([A-Z])', r'\1%s\2' % separator, name)
        return name.lower()

    @staticmethod
    def __ToCamel(name, separator='_'):
        """Join separator-delimited words, capitalizing each: 'a_b' -> 'AB'."""
        # TODO(craigcitro): Consider what to do about leading or trailing
        # underscores (such as `_refValue` in discovery).
        return ''.join(s[0:1].upper() + s[1:] for s in name.split(separator))

    @staticmethod
    def __ToLowerCamel(name, separator='_'):
        """Like __ToCamel, but with the first character lower-cased."""
        name = Names.__ToCamel(name, separator=separator)
        # Slice rather than index so an empty name yields '' instead of
        # raising IndexError.
        return name[:1].lower() + name[1:]

    def __StripName(self, name):
        """Strip the longest matching strip_prefix entry from name.

        A prefix that makes up the entire name is not stripped, since an
        empty result cannot be turned into a valid identifier.
        """
        if not name:
            return name
        for prefix in self.__strip_prefixes:
            if name.startswith(prefix):
                return name[len(prefix):] or name
        return name

    @staticmethod
    def CleanName(name):
        """Perform generic name cleaning.

        Characters outside [_A-Za-z0-9] become '_', a leading digit gets a
        '_' prefix, keywords (and 'exec') get '_' appended, and a leading
        '__' gets an 'f' prefix. An empty name is returned unchanged.
        """
        name = re.sub('[^_A-Za-z0-9]', '_', name)
        # name[:1] rather than name[0]: tolerate an empty name.
        if name[:1].isdigit():
            name = '_%s' % name
        # 'exec' is special-cased presumably because it is a statement in
        # Python 2 but not a keyword in Python 3.
        while keyword.iskeyword(name) or name == 'exec':
            name = '%s_' % name
        # If we end up with __ as a prefix, we'll run afoul of python
        # field renaming, so we manually correct for it.
        if name.startswith('__'):
            name = 'f%s' % name
        return name

    @staticmethod
    def NormalizeRelativePath(path):
        """Run CleanName on each '{param}' component of a '/'-separated path."""
        path_components = path.split('/')
        normalized_components = []
        for component in path_components:
            if re.match(r'{[A-Za-z0-9_]+}$', component):
                normalized_components.append(
                    '{%s}' % Names.CleanName(component[1:-1]))
            else:
                normalized_components.append(component)
        return '/'.join(normalized_components)

    def NormalizeEnumName(self, enum_name):
        """Clean an enum value name, upper-casing it if capitalize_enums."""
        if self.__capitalize_enums:
            enum_name = enum_name.upper()
        return self.CleanName(enum_name)

    def ClassName(self, name, separator='_'):
        """Generate a valid class name from name.

        Returns None for None, and passes protorpc/message_types qualified
        names through untouched.
        """
        # TODO(craigcitro): Get rid of this case here and in MethodName.
        if name is None:
            return name
        # TODO(craigcitro): This is a hack to handle the case of specific
        # protorpc class names; clean this up.
        if name.startswith(('protorpc.', 'message_types.',
                            'apitools.base.protorpclite.',
                            'apitools.base.protorpclite.message_types.')):
            return name
        name = self.__StripName(name)
        name = self.__ToCamel(name, separator=separator)
        return self.CleanName(name)

    def MethodName(self, name, separator='_'):
        """Generate a valid CamelCase method name from name (no prefix strip)."""
        if name is None:
            return None
        name = Names.__ToCamel(name, separator=separator)
        return Names.CleanName(name)

    def FieldName(self, name):
        """Generate a valid field name from name.

        After prefix stripping, 'LOWER_CAMEL' converts to lowerCamel,
        'LOWER_WITH_UNDER' converts camelCase to lower_with_under, and any
        other convention (e.g. 'NONE') leaves the casing as-is.
        """
        # TODO(craigcitro): We shouldn't need to strip this name, but some
        # of the service names here are excessive. Fix the API and then
        # remove this.
        name = self.__StripName(name)
        if self.__name_convention == 'LOWER_CAMEL':
            name = Names.__ToLowerCamel(name)
        elif self.__name_convention == 'LOWER_WITH_UNDER':
            name = Names.__FromCamel(name)
        return Names.CleanName(name)
155
156
@contextlib.contextmanager
def Chdir(dirname, create=True):
    """Temporarily make dirname the current working directory.

    Args:
      dirname: directory to switch into for the duration of the block.
      create: when dirname does not exist, create it (a single level, via
        os.mkdir) if true, otherwise raise OSError.

    The previous working directory is restored on exit, even on error.
    """
    missing = not os.path.exists(dirname)
    if missing and not create:
        raise OSError('Cannot find directory %s' % dirname)
    if missing:
        os.mkdir(dirname)
    saved_cwd = os.getcwd()
    try:
        os.chdir(dirname)
        yield
    finally:
        os.chdir(saved_cwd)
170
171
def NormalizeVersion(version):
    """Make an API version string usable inside Python identifiers.

    Only '.' is rewritten (to '_'), e.g. 'v1.2' becomes 'v1_2'.
    """
    # Currently, '.' is the only character that might cause us trouble.
    return '_'.join(version.split('.'))
175
176
def _ComputePaths(package, version, root_url, service_path):
    """Compute the base url and base path.

    Args:
      package: name field of the discovery doc, e.g. 'storage' for the
        storage service.
      version: version of the service, e.g. 'v1'.
      root_url: root url of the service, e.g. 'https://www.googleapis.com/'.
      service_path: path of the service under the root url, e.g.
        'storage/v1/'.

    Returns:
      A (base_url, base_path) tuple. base_url is the joined url up to and
      including the last '<package>/<version>/' segment, e.g.
      'https://www.googleapis.com/storage/v1/' for the storage service;
      base_path is whatever follows that segment. If the segment does not
      occur, base_url is the whole joined url and base_path is ''.
    """
    joined = urllib_parse.urljoin(root_url, service_path)
    marker = '/'.join((package, version, ''))
    head, found, tail = joined.rpartition(marker)
    if not found:
        return joined, ''
    return head + marker, tail
197
198
class ClientInfo(collections.namedtuple('ClientInfo', (
        'package', 'scopes', 'version', 'client_id', 'client_secret',
        'user_agent', 'client_class_name', 'url_version', 'api_key',
        'base_url', 'base_path', 'mtls_base_url'))):

    """Immutable bundle of client-related info and derived file/rule names."""

    @classmethod
    def Create(cls, discovery_doc,
               scope_ls, client_id, client_secret, user_agent, names, api_key):
        """Build a ClientInfo from a parsed discovery document.

        Args:
          discovery_doc: dict-like discovery document; 'name', 'version',
            'rootUrl' and 'servicePath' are required, 'auth' and
            'mtlsRootUrl' are optional.
          scope_ls: iterable of extra scopes merged with the declared ones.
          client_id, client_secret, user_agent, api_key: stored verbatim.
          names: object whose ClassName() builds client_class_name.

        Returns:
          A new ClientInfo with sorted, de-duplicated scopes.
        """
        declared = (discovery_doc.get('auth', {})
                    .get('oauth2', {})
                    .get('scopes', {}))
        all_scopes = set(declared)
        all_scopes.update(scope_ls)

        package = discovery_doc['name']
        url_version = discovery_doc['version']
        base_url, base_path = _ComputePaths(
            package, url_version,
            discovery_doc['rootUrl'], discovery_doc['servicePath'])

        mtls_root = discovery_doc.get('mtlsRootUrl', '')
        if mtls_root:
            mtls_base_url = _ComputePaths(
                package, url_version, mtls_root,
                discovery_doc['servicePath'])[0]
        else:
            mtls_base_url = ''

        version = NormalizeVersion(url_version)
        sorted_scopes = sorted(all_scopes)
        class_name = '%s%s' % (names.ClassName(package),
                               names.ClassName(version))
        return cls(
            package=package,
            scopes=sorted_scopes,
            version=version,
            client_id=client_id,
            client_secret=client_secret,
            user_agent=user_agent,
            client_class_name=class_name,
            url_version=url_version,
            api_key=api_key,
            base_url=base_url,
            base_path=base_path,
            mtls_base_url=mtls_base_url,
        )

    @property
    def default_directory(self):
        """Output directory name: the package name."""
        return self.package

    @property
    def client_rule_name(self):
        """'<package>_<version>_client'."""
        return '{}_{}_client'.format(self.package, self.version)

    @property
    def client_file_name(self):
        """client_rule_name with a '.py' suffix."""
        return '{}.py'.format(self.client_rule_name)

    @property
    def messages_rule_name(self):
        """'<package>_<version>_messages'."""
        return '{}_{}_messages'.format(self.package, self.version)

    @property
    def services_rule_name(self):
        """'<package>_<version>_services'."""
        return '{}_{}_services'.format(self.package, self.version)

    @property
    def messages_file_name(self):
        """messages_rule_name with a '.py' suffix."""
        return '{}.py'.format(self.messages_rule_name)

    @property
    def messages_proto_file_name(self):
        """messages_rule_name with a '.proto' suffix."""
        return '{}.proto'.format(self.messages_rule_name)

    @property
    def services_proto_file_name(self):
        """services_rule_name with a '.proto' suffix."""
        return '{}.proto'.format(self.services_rule_name)
276
277
# Non-ASCII characters and the ASCII text that replaces them in generated
# output. Built once at import time instead of on every call, since
# ReplaceHomoglyphs runs for every printed line.
_HOMOGLYPHS = {
    '\u00a0': ' ',  # NO-BREAK SPACE
    '\u00e3': '',  # TODO(gsfowler) drop after .proto spurious char elided
    '\u00a9': '(C)',  # COPYRIGHT SIGN (would you believe "asciiglyph"?)
    '\u00ae': '(R)',  # REGISTERED SIGN (would you believe "asciiglyph"?)
    '\u2014': '-',  # EM DASH
    '\u2018': "'",  # LEFT SINGLE QUOTATION MARK
    '\u2019': "'",  # RIGHT SINGLE QUOTATION MARK
    '\u201c': '"',  # LEFT DOUBLE QUOTATION MARK
    '\u201d': '"',  # RIGHT DOUBLE QUOTATION MARK
    '\u2026': '...',  # HORIZONTAL ELLIPSIS
    '\u2e3a': '-',  # TWO-EM DASH
}


def ReplaceHomoglyphs(s):
    """Returns s with unicode homoglyphs replaced by ascii equivalents.

    Characters listed in _HOMOGLYPHS are replaced by their mapped text,
    ASCII characters are kept, and any other character is replaced by its
    'unicode-escape' spelling (or '?' if that encoding fails).
    """

    def _ReplaceOne(c):
        """Returns the homoglyph or escaped replacement for c."""
        equiv = _HOMOGLYPHS.get(c)
        if equiv is not None:
            return equiv
        try:
            c.encode('ascii')
            return c
        except UnicodeError:
            pass
        try:
            return c.encode('unicode-escape').decode('ascii')
        except UnicodeError:
            return '?'

    return ''.join([_ReplaceOne(c) for c in s])
311
312
def CleanDescription(description):
    """Return a version of description safe for printing in a docstring.

    Non-string values are returned unchanged. For strings, backslashes that
    would begin an escape sequence requiring further input are doubled (see
    the comments below), homoglyphs are replaced via ReplaceHomoglyphs, and
    any run of three double quotes is broken up with spaces so it cannot
    terminate the docstring early.
    """
    if not isinstance(description, six.string_types):
        return description
    # In the generated (non-raw) literal, \x needs two hex digits in every
    # Python version, and on Python 3 \N needs a {name} and \u / \U need hex
    # digits. Any other use is a SyntaxError in the generated module.
    # https://docs.python.org/3/reference/lexical_analysis.html#index-18
    escape_chars = 'xNuU' if six.PY3 else 'x'

    def _EscapeRun(match):
        backslashes = match.group(1)
        # In an odd-length run the last backslash escapes the following
        # character, so add one to make it a literal backslash. Even-length
        # runs are already pairs of escaped backslashes; a plain
        # str.replace would wrongly turn one of them into a new escape.
        if len(backslashes) % 2:
            backslashes += '\\'
        return backslashes + match.group(2)

    description = re.sub(r'(\\+)([%s])' % escape_chars, _EscapeRun,
                         description)
    description = ReplaceHomoglyphs(description)
    return description.replace('"""', '" " "')
324
325
class SimplePrettyPrinter(object):

    """Simple pretty-printer that supports an indent contextmanager.

    Each call writes one line to `out`, prefixed with the current indent.
    """

    def __init__(self, out):
        self.__out = out
        self.__indent = ''
        # NOTE(review): __skip is never read within this class.
        self.__skip = False
        self.__comment_context = False

    @property
    def indent(self):
        """The current indentation prefix."""
        return self.__indent

    def CalculateWidth(self, max_width=78):
        """Return the columns left on a max_width line after the indent."""
        return max_width - len(self.indent)

    @contextlib.contextmanager
    def Indent(self, indent='  '):
        """Append `indent` to the current indent for the enclosed block."""
        previous_indent = self.__indent
        self.__indent = '%s%s' % (previous_indent, indent)
        try:
            yield
        finally:
            # Restore even if the body raised, so subsequent output is not
            # left permanently over-indented.
            self.__indent = previous_indent

    @contextlib.contextmanager
    def CommentContext(self):
        """Print without any argument formatting."""
        old_context = self.__comment_context
        self.__comment_context = True
        try:
            yield
        finally:
            # Restore even if the body raised.
            self.__comment_context = old_context

    def __call__(self, *args):
        """Write one line to the output stream.

        Args:
          *args: a format string followed by its %-interpolation arguments.
            With no args, or a falsy format string, a blank line is written.
            Otherwise trailing whitespace is stripped, homoglyphs are
            replaced, and the current indent is prepended.

        Raises:
          Error: if interpolation arguments are given inside CommentContext.
        """
        if self.__comment_context and args[1:]:
            raise Error('Cannot do string interpolation in comment context')
        if args and args[0]:
            if not self.__comment_context:
                line = (args[0] % args[1:]).rstrip()
            else:
                line = args[0].rstrip()
            line = ReplaceHomoglyphs(line)
            try:
                print('%s%s' % (self.__indent, line), file=self.__out)
            except UnicodeEncodeError:
                # The output stream cannot encode some character; fall back
                # to ASCII with backslash escapes.
                line = line.encode('ascii', 'backslashreplace').decode('ascii')
                print('%s%s' % (self.__indent, line), file=self.__out)
        else:
            print('', file=self.__out)
374
375
def _NormalizeDiscoveryUrls(discovery_url):
    """Expands a few abbreviations into full discovery urls.

    Args:
      discovery_url: either a URL (anything starting with 'http'), returned
        as-is, or an '<api_name>.<api_version>' abbreviation.

    Returns:
      A list of candidate discovery URLs, in the order they should be tried.

    Raises:
      ValueError: if discovery_url is neither a URL nor contains a '.'.
    """
    if discovery_url.startswith('http'):
        return [discovery_url]
    elif '.' not in discovery_url:
        raise ValueError(
            'Unrecognized value "%s" for discovery url' % discovery_url)
    # Split on the first '.', so 'foo.v1.2' -> ('foo', 'v1.2').
    api_name, _, api_version = discovery_url.partition('.')
    return [
        'https://www.googleapis.com/discovery/v1/apis/%s/%s/rest' % (
            api_name, api_version),
        'https://%s.googleapis.com/$discovery/rest?version=%s' % (
            api_name, api_version),
    ]
389
390
def _Gunzip(gzipped_content):
    """Returns gunzipped content from gzipped contents.

    Decompression happens entirely in memory; no temporary file is written.

    Args:
      gzipped_content: bytes, gzip-compressed data (one or more members).

    Returns:
      bytes, the decompressed data.
    """
    with gzip.GzipFile(fileobj=io.BytesIO(gzipped_content), mode='rb') as h:
        return h.read()
402
403
def _GetURLContent(url):
    """Download and return the content of URL.

    Args:
      url: string, the URL to fetch.

    Returns:
      The response body as bytes, gunzipped when the response declares
      'Content-Encoding: gzip'.
    """
    # closing() releases the connection even if reading the body raises.
    with contextlib.closing(urllib_request.urlopen(url)) as response:
        encoding = response.info().get('Content-Encoding')
        content = response.read()
    if encoding == 'gzip':
        content = _Gunzip(content)
    return content
413
414
def FetchDiscoveryDoc(discovery_url, retries=5):
    """Fetch the discovery document at the given url.

    Args:
      discovery_url: a full URL or an '<api>.<version>' abbreviation, as
        accepted by _NormalizeDiscoveryUrls.
      retries: number of attempts made against each candidate URL.

    Returns:
      The parsed, non-empty discovery document.

    Raises:
      CommunicationError: if no candidate URL yields a non-empty document.
    """
    discovery_urls = _NormalizeDiscoveryUrls(discovery_url)
    last_exception = None
    for url in discovery_urls:
        for _ in range(retries):
            try:
                content = _GetURLContent(url)
                if isinstance(content, bytes):
                    content = content.decode('utf8')
                discovery_doc = json.loads(content)
                if discovery_doc:
                    return discovery_doc
            except (urllib_error.HTTPError, urllib_error.URLError) as e:
                logging.info(
                    'Attempting to fetch discovery doc again after "%s"', e)
                last_exception = e
    # Every successful fetch returns above, so reaching this point means all
    # attempts failed or produced an empty document. Raise in both cases
    # rather than falling through and returning None.
    raise CommunicationError(
        'Could not find discovery doc at any of %s: %s' % (
            discovery_urls, last_exception))
437