1#!/usr/bin/env python
2"""Assorted utilities shared between parts of apitools."""
3from __future__ import print_function
4
5import collections
6import contextlib
7import json
8import keyword
9import logging
10import os
11import re
12
13import six
14import six.moves.urllib.error as urllib_error
15import six.moves.urllib.request as urllib_request
16
17
class Error(Exception):
    """Root of the exception hierarchy for apitools generation."""


class CommunicationError(Error):
    """Raised when a network operation (e.g. a discovery fetch) fails."""
26
27
28def _SortLengthFirstKey(a):
29    return -len(a), a
30
31
class Names(object):

    """Utility class for cleaning and normalizing names in a fixed style."""
    DEFAULT_NAME_CONVENTION = 'LOWER_CAMEL'
    NAME_CONVENTIONS = ['LOWER_CAMEL', 'LOWER_WITH_UNDER', 'NONE']

    def __init__(self, strip_prefixes,
                 name_convention=None,
                 capitalize_enums=False):
        """Store the naming configuration.

        Args:
          strip_prefixes: Iterable of prefixes that class and field names
            are stripped of. They are kept longest-first, so the longest
            matching prefix is the one removed.
          name_convention: Convention applied by FieldName; expected to be
            one of NAME_CONVENTIONS (not validated here). A falsy value
            selects DEFAULT_NAME_CONVENTION.
          capitalize_enums: If true, NormalizeEnumName upper-cases names
            before cleaning them.
        """
        self.__strip_prefixes = sorted(strip_prefixes, key=_SortLengthFirstKey)
        self.__name_convention = (
            name_convention or self.DEFAULT_NAME_CONVENTION)
        self.__capitalize_enums = capitalize_enums

    @staticmethod
    def __FromCamel(name, separator='_'):
        """Split camelCase at case boundaries with separator, lower-cased."""
        name = re.sub(r'([a-z0-9])([A-Z])', r'\1%s\2' % separator, name)
        return name.lower()

    @staticmethod
    def __ToCamel(name, separator='_'):
        """Capitalize each separator-delimited word and concatenate them."""
        # TODO(craigcitro): Consider what to do about leading or trailing
        # underscores (such as `_refValue` in discovery).
        return ''.join(s[0:1].upper() + s[1:] for s in name.split(separator))

    @staticmethod
    def __ToLowerCamel(name, separator='_'):
        """Like __ToCamel, but with the first character lower-cased."""
        name = Names.__ToCamel(name, separator=separator)
        # Slice rather than index so an empty name yields '' (matching
        # __ToCamel) instead of raising IndexError.
        return name[:1].lower() + name[1:]

    def __StripName(self, name):
        """Strip the longest matching strip_prefix entry from name."""
        if not name:
            return name
        for prefix in self.__strip_prefixes:
            if name.startswith(prefix):
                return name[len(prefix):]
        return name

    @staticmethod
    def CleanName(name):
        """Perform generic name cleaning.

        Every character outside [_A-Za-z0-9] becomes '_'; a leading digit
        gets a '_' prefix; '_' is appended while the result is a Python
        keyword; and names starting with '__' get an 'f' prefix. An empty
        name is returned unchanged.
        """
        name = re.sub('[^_A-Za-z0-9]', '_', name)
        # Slice rather than index so an empty name does not raise IndexError.
        if name[:1].isdigit():
            name = '_%s' % name
        while keyword.iskeyword(name):
            name = '%s_' % name
        # If we end up with __ as a prefix, we'll run afoul of python
        # field renaming, so we manually correct for it.
        if name.startswith('__'):
            name = 'f%s' % name
        return name

    @staticmethod
    def NormalizeRelativePath(path):
        """Clean the names of `{param}` placeholders in a '/'-separated path.

        Only components that are entirely a `{...}` placeholder made of
        [A-Za-z0-9_] are rewritten (via CleanName); others are unchanged.
        """
        path_components = path.split('/')
        normalized_components = []
        for component in path_components:
            if re.match(r'{[A-Za-z0-9_]+}$', component):
                normalized_components.append(
                    '{%s}' % Names.CleanName(component[1:-1]))
            else:
                normalized_components.append(component)
        return '/'.join(normalized_components)

    def NormalizeEnumName(self, enum_name):
        """Clean an enum value name, upper-casing it if so configured."""
        if self.__capitalize_enums:
            enum_name = enum_name.upper()
        return self.CleanName(enum_name)

    def ClassName(self, name, separator='_'):
        """Generate a valid class name from name.

        Returns None for None, and passes names starting with 'protorpc.'
        or 'message_types.' through untouched.
        """
        # TODO(craigcitro): Get rid of this case here and in MethodName.
        if name is None:
            return name
        # TODO(craigcitro): This is a hack to handle the case of specific
        # protorpc class names; clean this up.
        if name.startswith('protorpc.') or name.startswith('message_types.'):
            return name
        name = self.__StripName(name)
        name = self.__ToCamel(name, separator=separator)
        return self.CleanName(name)

    def MethodName(self, name, separator='_'):
        """Generate a valid method name (UpperCamel) from name."""
        if name is None:
            return None
        name = Names.__ToCamel(name, separator=separator)
        return Names.CleanName(name)

    def FieldName(self, name):
        """Generate a valid field name from name.

        Strips a configured prefix, applies the configured name convention
        ('NONE' leaves the casing alone), then cleans the result.
        """
        # TODO(craigcitro): We shouldn't need to strip this name, but some
        # of the service names here are excessive. Fix the API and then
        # remove this.
        name = self.__StripName(name)
        if self.__name_convention == 'LOWER_CAMEL':
            name = Names.__ToLowerCamel(name)
        elif self.__name_convention == 'LOWER_WITH_UNDER':
            name = Names.__FromCamel(name)
        return Names.CleanName(name)
134
135
@contextlib.contextmanager
def Chdir(dirname, create=True):
    """Temporarily make dirname the current working directory.

    The previous working directory is restored when the `with` block
    exits, including when it exits via an exception.

    Args:
      dirname: Directory to change into.
      create: If true and dirname does not exist, create it with os.mkdir
        (parent directories are not created).

    Raises:
      OSError: if dirname does not exist and create is false.
    """
    if not os.path.exists(dirname):
        if not create:
            raise OSError('Cannot find directory %s' % dirname)
        os.mkdir(dirname)
    previous_directory = os.getcwd()
    os.chdir(dirname)
    try:
        yield
    finally:
        # Always restore the caller's cwd, even if the body raised.
        os.chdir(previous_directory)
147
148
def NormalizeVersion(version):
    """Return version with every '.' replaced by '_'.

    Currently '.' is the only character in a version string that is known
    to cause trouble in generated identifiers.
    """
    pieces = version.split('.')
    return '_'.join(pieces)
152
153
class ClientInfo(collections.namedtuple('ClientInfo', (
        'package', 'scopes', 'version', 'client_id', 'client_secret',
        'user_agent', 'client_class_name', 'url_version', 'api_key'))):

    """Container for client-related info and names.

    The *_rule_name properties combine `package` and `version`; the
    *_file_name properties add a file extension to one of those rule names.
    """

    @classmethod
    def Create(cls, discovery_doc,
               scope_ls, client_id, client_secret, user_agent, names, api_key):
        """Create a new ClientInfo object from a discovery document.

        Args:
          discovery_doc: Parsed discovery document. 'name' and 'version'
            are required; scopes are read from auth.oauth2.scopes if present.
          scope_ls: Extra scopes merged with the document's scopes.
          client_id, client_secret, user_agent, api_key: Stored unchanged.
          names: Object whose ClassName builds client_class_name.

        Returns:
          A new instance of cls with scopes sorted and de-duplicated.
        """
        oauth2 = discovery_doc.get('auth', {}).get('oauth2', {})
        merged_scopes = set(oauth2.get('scopes', {}))
        merged_scopes.update(scope_ls)
        package = discovery_doc['name']
        url_version = discovery_doc['version']
        version = NormalizeVersion(url_version)
        class_name = '%s%s' % (
            names.ClassName(package), names.ClassName(version))
        return cls(
            package=package,
            scopes=sorted(list(merged_scopes)),
            version=version,
            client_id=client_id,
            client_secret=client_secret,
            user_agent=user_agent,
            client_class_name=class_name,
            url_version=url_version,
            api_key=api_key,
        )

    @property
    def default_directory(self):
        """Directory name for generated output: the package name."""
        return self.package

    # Rule names.

    @property
    def cli_rule_name(self):
        """Rule name for the CLI target."""
        return '%s_%s' % (self.package, self.version)

    @property
    def client_rule_name(self):
        """Rule name for the client target."""
        return '%s_%s_client' % (self.package, self.version)

    @property
    def messages_rule_name(self):
        """Rule name for the messages target."""
        return '%s_%s_messages' % (self.package, self.version)

    @property
    def services_rule_name(self):
        """Rule name for the services target."""
        return '%s_%s_services' % (self.package, self.version)

    # File names derived from the rule names above.

    @property
    def cli_file_name(self):
        """Python file name for the CLI."""
        return '%s.py' % self.cli_rule_name

    @property
    def client_file_name(self):
        """Python file name for the client."""
        return '%s.py' % self.client_rule_name

    @property
    def messages_file_name(self):
        """Python file name for the messages module."""
        return '%s.py' % self.messages_rule_name

    @property
    def messages_proto_file_name(self):
        """Proto file name for the messages."""
        return '%s.proto' % self.messages_rule_name

    @property
    def services_proto_file_name(self):
        """Proto file name for the services."""
        return '%s.proto' % self.services_rule_name
222
223
def GetPackage(path):
    """Convert a filesystem path into a dotted package name."""
    # os.path.sep is a single character, so a straight replace is the same
    # as splitting on it and re-joining with '.'.
    return path.replace(os.path.sep, '.')
227
228
def CleanDescription(description):
    """Return a version of description safe for printing in a docstring.

    Strings have any triple double-quote broken up so it cannot terminate
    the surrounding docstring; non-string values are returned unchanged.
    """
    if isinstance(description, six.string_types):
        return description.replace('"""', '" " "')
    return description
234
235
class SimplePrettyPrinter(object):

    """Simple pretty-printer that supports an indent contextmanager."""

    def __init__(self, out):
        """Create a printer that writes lines to the file-like object out."""
        self.__out = out
        self.__indent = ''
        # NOTE(review): not read anywhere in this class.
        self.__skip = False
        self.__comment_context = False

    @property
    def indent(self):
        """The current indentation prefix."""
        return self.__indent

    def CalculateWidth(self, max_width=78):
        """Return the width left on a line after the current indent."""
        return max_width - len(self.indent)

    @contextlib.contextmanager
    def Indent(self, indent='  '):
        """Add indent to the prefix for the duration of the `with` block."""
        previous_indent = self.__indent
        self.__indent = '%s%s' % (previous_indent, indent)
        try:
            yield
        finally:
            # Restore even if the body raised, so later output is not
            # printed at the wrong depth.
            self.__indent = previous_indent

    @contextlib.contextmanager
    def CommentContext(self):
        """Print without any argument formatting."""
        old_context = self.__comment_context
        self.__comment_context = True
        try:
            yield
        finally:
            self.__comment_context = old_context

    def __call__(self, *args):
        """Print one line at the current indent, or a blank line.

        With no arguments (or a falsy first argument) an empty line is
        printed. Otherwise args[0] is %-formatted with args[1:] (except in
        comment context), trailing whitespace is stripped, and non-ASCII
        characters are written as backslash escapes.

        Raises:
          Error: if formatting arguments are given in comment context.
        """
        if self.__comment_context and args[1:]:
            raise Error('Cannot do string interpolation in comment context')
        if args and args[0]:
            if not self.__comment_context:
                line = (args[0] % args[1:]).rstrip()
            else:
                line = args[0].rstrip()
            # Escape non-ASCII characters, then decode back to text: on
            # Python 3 the bare bytes would be printed as a "b'...'" repr.
            line = line.encode('ascii', 'backslashreplace').decode('ascii')
            print('%s%s' % (self.__indent, line), file=self.__out)
        else:
            print('', file=self.__out)
280
281
def NormalizeDiscoveryUrl(discovery_url):
    """Expands a few abbreviations into full discovery urls.

    Args:
      discovery_url: Either a URL (anything starting with 'http', returned
        unchanged) or an abbreviation '<api_name>.<api_version>', e.g.
        'storage.v1'. Only the first '.' separates name from version.

    Returns:
      The full discovery URL.

    Raises:
      ValueError: if discovery_url is neither a URL nor contains a '.'.
    """
    if discovery_url.startswith('http'):
        return discovery_url
    if '.' not in discovery_url:
        # The original message never interpolated the offending value.
        raise ValueError(
            'Unrecognized value "%s" for discovery url' % discovery_url)
    api_name, _, api_version = discovery_url.partition('.')
    return 'https://www.googleapis.com/discovery/v1/apis/%s/%s/rest' % (
        api_name, api_version)
291
292
def FetchDiscoveryDoc(discovery_url, retries=5):
    """Fetch and parse the discovery document at the given url.

    Args:
      discovery_url: A URL or '<api>.<version>' abbreviation, expanded via
        NormalizeDiscoveryUrl.
      retries: Maximum number of fetch attempts.

    Returns:
      The JSON-decoded discovery document.

    Raises:
      CommunicationError: if no attempt produced a document (every attempt
        failed with an HTTP/URL error, or retries was not positive).
    """
    discovery_url = NormalizeDiscoveryUrl(discovery_url)
    discovery_doc = None
    last_exception = None
    for _ in range(retries):
        try:
            # closing() releases the connection whether or not parsing
            # succeeds.
            with contextlib.closing(
                    urllib_request.urlopen(discovery_url)) as response:
                discovery_doc = json.loads(response.read())
            break
        except (urllib_error.HTTPError, urllib_error.URLError) as e:
            # Python 3 unbinds the `as` target when the except clause ends,
            # so keep the error in a variable that outlives the loop.
            last_exception = e
            logging.warning(
                'Attempting to fetch discovery doc again after "%s"', e)
    if discovery_doc is None:
        raise CommunicationError(
            'Could not find discovery doc at url "%s": %s' % (
                discovery_url, last_exception))
    return discovery_doc
313