#!/usr/bin/env python
"""Assorted utilities shared between parts of apitools."""
from __future__ import print_function

import collections
import contextlib
import json
import keyword
import logging
import os
import re

import six
import six.moves.urllib.error as urllib_error
import six.moves.urllib.request as urllib_request


class Error(Exception):

    """Base error for apitools generation."""


class CommunicationError(Error):

    """Error in network communication."""


def _SortLengthFirstKey(a):
    """Return a sort key ordering longer values first, ties by value."""
    return -len(a), a


class Names(object):

    """Utility class for cleaning and normalizing names in a fixed style."""
    DEFAULT_NAME_CONVENTION = 'LOWER_CAMEL'
    NAME_CONVENTIONS = ['LOWER_CAMEL', 'LOWER_WITH_UNDER', 'NONE']

    def __init__(self, strip_prefixes,
                 name_convention=None,
                 capitalize_enums=False):
        """Set up the naming rules.

        Args:
          strip_prefixes: Iterable of string prefixes; the longest matching
              one is removed from class and field names.
          name_convention: Convention applied by FieldName. Falls back to
              DEFAULT_NAME_CONVENTION when falsy. The value is not validated
              against NAME_CONVENTIONS here.
          capitalize_enums: If true, NormalizeEnumName upper-cases names.
        """
        # Longest prefixes first, so the most specific prefix wins.
        self.__strip_prefixes = sorted(strip_prefixes, key=_SortLengthFirstKey)
        self.__name_convention = (
            name_convention or self.DEFAULT_NAME_CONVENTION)
        self.__capitalize_enums = capitalize_enums

    @staticmethod
    def __FromCamel(name, separator='_'):
        """Convert camelCase to lower case words joined by separator."""
        name = re.sub(r'([a-z0-9])([A-Z])', r'\1%s\2' % separator, name)
        return name.lower()

    @staticmethod
    def __ToCamel(name, separator='_'):
        """Capitalize each separator-delimited piece and join the pieces."""
        # TODO(craigcitro): Consider what to do about leading or trailing
        # underscores (such as `_refValue` in discovery).
        return ''.join(s[0:1].upper() + s[1:] for s in name.split(separator))

    @staticmethod
    def __ToLowerCamel(name, separator='_'):
        """Like __ToCamel, but with the first character lower-cased."""
        name = Names.__ToCamel(name, separator=separator)
        return name[0].lower() + name[1:]

    def __StripName(self, name):
        """Strip strip_prefix entries from name."""
        if not name:
            return name
        # Only the first (i.e. longest) matching prefix is removed.
        for prefix in self.__strip_prefixes:
            if name.startswith(prefix):
                return name[len(prefix):]
        return name

    @staticmethod
    def CleanName(name):
        """Perform generic name cleaning.

        Replaces every character outside [_A-Za-z0-9] with an underscore,
        prefixes an underscore when the name starts with a digit, appends
        underscores while the name is a Python keyword, and prefixes 'f'
        when the result starts with a double underscore.
        """
        name = re.sub('[^_A-Za-z0-9]', '_', name)
        if name[0].isdigit():
            name = '_%s' % name
        while keyword.iskeyword(name):
            name = '%s_' % name
        # If we end up with __ as a prefix, we'll run afoul of python
        # field renaming, so we manually correct for it.
        if name.startswith('__'):
            name = 'f%s' % name
        return name

    @staticmethod
    def NormalizeRelativePath(path):
        """Normalize camelCase entries in path.

        Each '/'-separated component of the form '{identifier}' has its
        identifier passed through CleanName; other components are kept.
        """
        path_components = path.split('/')
        normalized_components = []
        for component in path_components:
            if re.match(r'{[A-Za-z0-9_]+}$', component):
                normalized_components.append(
                    '{%s}' % Names.CleanName(component[1:-1]))
            else:
                normalized_components.append(component)
        return '/'.join(normalized_components)

    def NormalizeEnumName(self, enum_name):
        """Clean enum_name, upper-casing it first if so configured."""
        if self.__capitalize_enums:
            enum_name = enum_name.upper()
        return self.CleanName(enum_name)

    def ClassName(self, name, separator='_'):
        """Generate a valid class name from name."""
        # TODO(craigcitro): Get rid of this case here and in MethodName.
        if name is None:
            return name
        # TODO(craigcitro): This is a hack to handle the case of specific
        # protorpc class names; clean this up.
        if name.startswith(('protorpc.', 'message_types.')):
            return name
        name = self.__StripName(name)
        name = self.__ToCamel(name, separator=separator)
        return self.CleanName(name)

    def MethodName(self, name, separator='_'):
        """Generate a valid method name from name."""
        if name is None:
            return None
        name = Names.__ToCamel(name, separator=separator)
        return Names.CleanName(name)

    def FieldName(self, name):
        """Generate a valid field name from name."""
        # TODO(craigcitro): We shouldn't need to strip this name, but some
        # of the service names here are excessive. Fix the API and then
        # remove this.
        name = self.__StripName(name)
        # Any other convention value leaves the casing untouched.
        if self.__name_convention == 'LOWER_CAMEL':
            name = Names.__ToLowerCamel(name)
        elif self.__name_convention == 'LOWER_WITH_UNDER':
            name = Names.__FromCamel(name)
        return Names.CleanName(name)


@contextlib.contextmanager
def Chdir(dirname, create=True):
    """Context manager that runs its body with dirname as the cwd.

    Args:
      dirname: Directory to change into.
      create: If true, a missing directory is created with os.mkdir.

    Raises:
      OSError: If dirname does not exist and create is false.
    """
    if not os.path.exists(dirname):
        if not create:
            raise OSError('Cannot find directory %s' % dirname)
        os.mkdir(dirname)
    previous_directory = os.getcwd()
    os.chdir(dirname)
    try:
        yield
    finally:
        # Restore the original directory even if the body raised.
        os.chdir(previous_directory)


def NormalizeVersion(version):
    """Return version with every '.' replaced by '_'."""
    # Currently, '.' is the only character that might cause us trouble.
    return version.replace('.', '_')


class ClientInfo(collections.namedtuple('ClientInfo', (
        'package', 'scopes', 'version', 'client_id', 'client_secret',
        'user_agent', 'client_class_name', 'url_version', 'api_key'))):

    """Container for client-related info and names."""

    @classmethod
    def Create(cls, discovery_doc,
               scope_ls, client_id, client_secret, user_agent, names, api_key):
        """Create a new ClientInfo object from a discovery document.

        Args:
          discovery_doc: Mapping with 'name' and 'version' entries and,
              optionally, scopes under auth -> oauth2 -> scopes.
          scope_ls: Iterable of extra scopes merged with the document's.
          client_id: Stored unchanged on the result.
          client_secret: Stored unchanged on the result.
          user_agent: Stored unchanged on the result.
          names: Object whose ClassName method builds the client class name.
          api_key: Stored unchanged on the result.

        Returns:
          A ClientInfo whose scopes are a sorted, de-duplicated list.
        """
        scopes = set(
            discovery_doc.get('auth', {}).get('oauth2', {}).get('scopes', {}))
        scopes.update(scope_ls)
        client_info = {
            'package': discovery_doc['name'],
            'version': NormalizeVersion(discovery_doc['version']),
            'url_version': discovery_doc['version'],
            'scopes': sorted(scopes),
            'client_id': client_id,
            'client_secret': client_secret,
            'user_agent': user_agent,
            'api_key': api_key,
        }
        client_class_name = '%s%s' % (
            names.ClassName(client_info['package']),
            names.ClassName(client_info['version']))
        client_info['client_class_name'] = client_class_name
        return cls(**client_info)

    @property
    def default_directory(self):
        return self.package

    @property
    def cli_rule_name(self):
        return '%s_%s' % (self.package, self.version)

    @property
    def cli_file_name(self):
        return '%s.py' % self.cli_rule_name

    @property
    def client_rule_name(self):
        return '%s_%s_client' % (self.package, self.version)

    @property
    def client_file_name(self):
        return '%s.py' % self.client_rule_name

    @property
    def messages_rule_name(self):
        return '%s_%s_messages' % (self.package, self.version)

    @property
    def services_rule_name(self):
        return '%s_%s_services' % (self.package, self.version)

    @property
    def messages_file_name(self):
        return '%s.py' % self.messages_rule_name

    @property
    def messages_proto_file_name(self):
        return '%s.proto' % self.messages_rule_name

    @property
    def services_proto_file_name(self):
        return '%s.proto' % self.services_rule_name


def GetPackage(path):
    """Turn an os.path.sep-separated path into a dotted package name."""
    path_components = path.split(os.path.sep)
    return '.'.join(path_components)


def CleanDescription(description):
    """Return a version of description safe for printing in a docstring."""
    if not isinstance(description, six.string_types):
        return description
    return description.replace('"""', '" " "')


class SimplePrettyPrinter(object):

    """Simple pretty-printer that supports an indent contextmanager."""

    def __init__(self, out):
        """Create a printer that writes lines to the file-like object out."""
        self.__out = out
        self.__indent = ''
        self.__skip = False
        self.__comment_context = False

    @property
    def indent(self):
        return self.__indent

    def CalculateWidth(self, max_width=78):
        """Return the width left on a line after the current indent."""
        return max_width - len(self.indent)

    # NOTE(review): confirm the default indent width against callers and
    # previously generated output before relying on it.
    @contextlib.contextmanager
    def Indent(self, indent=' '):
        """Context manager that appends indent to the current indentation."""
        previous_indent = self.__indent
        self.__indent = '%s%s' % (previous_indent, indent)
        try:
            yield
        finally:
            # Restore the indentation even if the body raised.
            self.__indent = previous_indent

    @contextlib.contextmanager
    def CommentContext(self):
        """Print without any argument formatting."""
        old_context = self.__comment_context
        self.__comment_context = True
        try:
            yield
        finally:
            # Restore the previous mode even if the body raised.
            self.__comment_context = old_context

    def __call__(self, *args):
        """Print one line, prefixed with the current indentation.

        Outside a comment context the first argument is a %-format string
        and the remaining arguments are its values. With no arguments, or
        an empty first argument, an empty line is printed.

        Raises:
          Error: If format values are passed while in a comment context.
        """
        if self.__comment_context and args[1:]:
            raise Error('Cannot do string interpolation in comment context')
        if args and args[0]:
            if not self.__comment_context:
                line = (args[0] % args[1:]).rstrip()
            else:
                line = args[0].rstrip()
            line = line.encode('ascii', 'backslashreplace')
            # On Python 3 encode() yields bytes, which would be printed as
            # b'...'; convert back to text so only the escaped line is shown.
            if not isinstance(line, str):
                line = line.decode('ascii')
            print('%s%s' % (self.__indent, line), file=self.__out)
        else:
            print('', file=self.__out)


def NormalizeDiscoveryUrl(discovery_url):
    """Expands a few abbreviations into full discovery urls.

    Values starting with 'http' are returned unchanged. Otherwise the value
    is split at its first '.' into an API name and version, which are
    placed into a googleapis.com discovery url.

    Raises:
      ValueError: If the value neither starts with 'http' nor contains '.'.
    """
    if discovery_url.startswith('http'):
        return discovery_url
    elif '.' not in discovery_url:
        raise ValueError(
            'Unrecognized value "%s" for discovery url' % discovery_url)
    api_name, _, api_version = discovery_url.partition('.')
    return 'https://www.googleapis.com/discovery/v1/apis/%s/%s/rest' % (
        api_name, api_version)


def FetchDiscoveryDoc(discovery_url, retries=5):
    """Fetch the discovery document at the given url.

    Args:
      discovery_url: Url or abbreviation accepted by NormalizeDiscoveryUrl.
      retries: Maximum number of fetch attempts.

    Returns:
      The parsed JSON discovery document.

    Raises:
      CommunicationError: If no attempt produced a document.
    """
    discovery_url = NormalizeDiscoveryUrl(discovery_url)
    discovery_doc = None
    last_exception = None
    for _ in range(retries):
        try:
            with contextlib.closing(
                    urllib_request.urlopen(discovery_url)) as response:
                content = response.read()
            # Decode explicitly so json.loads gets text on every supported
            # interpreter (JSON is assumed to be UTF-8 encoded).
            if isinstance(content, bytes):
                content = content.decode('utf-8')
            discovery_doc = json.loads(content)
            break
        except (urllib_error.HTTPError,
                urllib_error.URLError) as e:
            # Python 3 unbinds the `as` target when the handler exits, so
            # keep the exception under a separate name for the final error.
            last_exception = e
            logging.warning(
                'Attempting to fetch discovery doc again after "%s"',
                last_exception)
    if discovery_doc is None:
        raise CommunicationError(
            'Could not find discovery doc at url "%s": %s' % (
                discovery_url, last_exception))
    return discovery_doc