#!/usr/bin/env python
"""Message registry for apitools."""

import collections
import contextlib
import json

from protorpc import descriptor
from protorpc import messages
import six

from apitools.gen import extended_descriptor
from apitools.gen import util

# Pairing of a generated type name with the protorpc variant used for it.
TypeInfo = collections.namedtuple('TypeInfo', ('type_name', 'variant'))


class MessageRegistry(object):

    """Registry for message types.

    This closely mirrors a messages.FileDescriptor, but adds additional
    attributes (such as message and field descriptions) and some extra
    code for validation and cycle detection.

    Descriptors are created from discovery-style schema dicts via
    AddDescriptorFromSchema / AddEnumDescriptor. While a descriptor is
    being built its full name is held in a "nascent" set; references to
    types that have not been defined yet are held in an "unknown" set.
    Validate() fails while either set is nonempty.
    """

    # Type information from these two maps comes from here:
    #  https://developers.google.com/discovery/v1/type-format
    # Maps a schema `type` value to its TypeInfo.
    PRIMITIVE_TYPE_INFO_MAP = {
        'string': TypeInfo(type_name='string',
                           variant=messages.StringField.DEFAULT_VARIANT),
        'integer': TypeInfo(type_name='integer',
                            variant=messages.IntegerField.DEFAULT_VARIANT),
        'boolean': TypeInfo(type_name='boolean',
                            variant=messages.BooleanField.DEFAULT_VARIANT),
        'number': TypeInfo(type_name='number',
                           variant=messages.FloatField.DEFAULT_VARIANT),
        'any': TypeInfo(type_name='extra_types.JsonValue',
                        variant=messages.Variant.MESSAGE),
    }

    # Maps a schema `format` value to its TypeInfo; consulted before
    # PRIMITIVE_TYPE_INFO_MAP whenever a schema carries a `format`.
    PRIMITIVE_FORMAT_MAP = {
        'int32': TypeInfo(type_name='integer',
                          variant=messages.Variant.INT32),
        'uint32': TypeInfo(type_name='integer',
                           variant=messages.Variant.UINT32),
        'int64': TypeInfo(type_name='string',
                          variant=messages.Variant.INT64),
        'uint64': TypeInfo(type_name='string',
                           variant=messages.Variant.UINT64),
        'double': TypeInfo(type_name='number',
                           variant=messages.Variant.DOUBLE),
        'float': TypeInfo(type_name='number',
                          variant=messages.Variant.FLOAT),
        'byte': TypeInfo(type_name='byte',
                         variant=messages.BytesField.DEFAULT_VARIANT),
        'date': TypeInfo(type_name='extra_types.DateField',
                         variant=messages.Variant.STRING),
        'date-time': TypeInfo(
            type_name='protorpc.message_types.DateTimeMessage',
            variant=messages.Variant.MESSAGE),
    }

    def __init__(self, client_info, names, description,
                 root_package_dir, base_files_package):
        """Create an empty registry.

        Args:
          client_info: Object providing at least `package` and `version`
              attributes (both are read by this class).
          names: Naming helper; this class calls its ClassName,
              CleanName and NormalizeEnumName methods.
          description: Description for the generated file; passed
              through util.CleanDescription.
          root_package_dir: Stored on the instance; not otherwise used
              by the code in this class.
          base_files_package: Package name interpolated into the
              `from <package> import encoding` / `extra_types` import
              lines added to the generated file.
        """
        self.__names = names
        self.__client_info = client_info
        self.__package = client_info.package
        self.__description = util.CleanDescription(description)
        self.__root_package_dir = root_package_dir
        self.__base_files_package = base_files_package
        self.__file_descriptor = extended_descriptor.ExtendedFileDescriptor(
            package=self.__package, description=self.__description)
        # Add required imports
        self.__file_descriptor.additional_imports = [
            'from protorpc import messages as _messages',
        ]
        # Map from scoped names (i.e. Foo.Bar) to MessageDescriptors.
        self.__message_registry = collections.OrderedDict()
        # A set of types that we're currently adding (for cycle detection).
        self.__nascent_types = set()
        # A set of types for which we've seen a reference but no
        # definition; if this set is nonempty, validation fails.
        self.__unknown_types = set()
        # Used for tracking paths during message creation
        self.__current_path = []
        # Where to register created messages
        self.__current_env = self.__file_descriptor
        # TODO(craigcitro): Add a `Finalize` method.

    @property
    def file_descriptor(self):
        """The underlying file descriptor; validates the registry first."""
        self.Validate()
        return self.__file_descriptor

    def WriteProtoFile(self, printer):
        """Write the messages file to out as proto."""
        self.Validate()
        extended_descriptor.WriteMessagesFile(
            self.__file_descriptor, self.__package, self.__client_info.version,
            printer)

    def WriteFile(self, printer):
        """Write the messages file to out."""
        self.Validate()
        extended_descriptor.WritePythonFile(
            self.__file_descriptor, self.__package, self.__client_info.version,
            printer)

    def Validate(self):
        """Raise ValueError if any type is half-built or undefined."""
        mysteries = self.__nascent_types or self.__unknown_types
        if mysteries:
            raise ValueError('Malformed MessageRegistry: %s' % mysteries)

    def __ComputeFullName(self, name):
        """Return name qualified by the current nesting path (Foo.Bar)."""
        return '.'.join(map(six.text_type, self.__current_path + [name]))

    def __AddImport(self, new_import):
        """Add an import line to the generated file, skipping duplicates."""
        if new_import not in self.__file_descriptor.additional_imports:
            self.__file_descriptor.additional_imports.append(new_import)

    def __DeclareDescriptor(self, name):
        """Mark name (in the current scope) as currently being created."""
        self.__nascent_types.add(self.__ComputeFullName(name))

    def __RegisterDescriptor(self, new_descriptor):
        """Register the given descriptor in this registry.

        The descriptor must have been announced with __DeclareDescriptor
        in the same scope beforehand.

        Raises:
          ValueError: if the descriptor is of an unsupported type, is
              already registered, or was never declared.
        """
        if not isinstance(new_descriptor, (
                extended_descriptor.ExtendedMessageDescriptor,
                extended_descriptor.ExtendedEnumDescriptor)):
            raise ValueError('Cannot add descriptor of type %s' % (
                type(new_descriptor),))
        full_name = self.__ComputeFullName(new_descriptor.name)
        if full_name in self.__message_registry:
            raise ValueError(
                'Attempt to re-register descriptor %s' % full_name)
        if full_name not in self.__nascent_types:
            raise ValueError('Directly adding types is not supported')
        new_descriptor.full_name = full_name
        self.__message_registry[full_name] = new_descriptor
        # Attach the descriptor to whatever scope is current: the file
        # descriptor at top level, or the enclosing message when nested.
        if isinstance(new_descriptor,
                      extended_descriptor.ExtendedMessageDescriptor):
            self.__current_env.message_types.append(new_descriptor)
        elif isinstance(new_descriptor,
                        extended_descriptor.ExtendedEnumDescriptor):
            self.__current_env.enum_types.append(new_descriptor)
        self.__unknown_types.discard(full_name)
        self.__nascent_types.remove(full_name)

    def LookupDescriptor(self, name):
        """Return the descriptor registered under name, or None."""
        return self.__GetDescriptorByName(name)

    def LookupDescriptorOrDie(self, name):
        """Return the descriptor registered under name.

        Raises:
          ValueError: if no descriptor with that name is registered.
        """
        message_descriptor = self.LookupDescriptor(name)
        if message_descriptor is None:
            # Interpolate the name; passing it as a second argument to
            # ValueError would leave the '%s' placeholder unformatted.
            raise ValueError('No message descriptor named "%s"' % name)
        return message_descriptor

    def __GetDescriptor(self, name):
        """Look up name relative to the current nesting path."""
        return self.__GetDescriptorByName(self.__ComputeFullName(name))

    def __GetDescriptorByName(self, name):
        """Look up a descriptor by its full name.

        Returns:
          The registered descriptor, or None if the name is unknown.

        Raises:
          ValueError: if the named type is still being created.
        """
        if name in self.__message_registry:
            return self.__message_registry[name]
        if name in self.__nascent_types:
            raise ValueError(
                'Cannot retrieve type currently being created: %s' % name)
        return None

    @contextlib.contextmanager
    def __DescriptorEnv(self, message_descriptor):
        """Temporarily make message_descriptor the current scope.

        Inside the `with` body, newly registered descriptors are nested
        under message_descriptor and full names are prefixed with its
        name. The previous scope is restored on exit, including when the
        body raises.
        """
        # TODO(craigcitro): Typecheck?
        previous_env = self.__current_env
        self.__current_path.append(message_descriptor.name)
        self.__current_env = message_descriptor
        try:
            yield
        finally:
            # Always unwind, so a failure while building one message does
            # not leave the registry stuck inside that message's scope.
            self.__current_path.pop()
            self.__current_env = previous_env

    def AddEnumDescriptor(self, name, description,
                          enum_values, enum_descriptions):
        """Add a new EnumDescriptor named name with the given enum values.

        Args:
          name: Enum name; converted with names.ClassName.
          description: Description of the enum as a whole.
          enum_values: Sequence of JSON enum value names.
          enum_descriptions: Sequence of descriptions, paired
              positionally with enum_values.
        """
        message = extended_descriptor.ExtendedEnumDescriptor()
        message.name = self.__names.ClassName(name)
        message.description = util.CleanDescription(description)
        self.__DeclareDescriptor(message.name)
        for index, (enum_name, enum_description) in enumerate(
                zip(enum_values, enum_descriptions)):
            enum_value = extended_descriptor.ExtendedEnumValueDescriptor()
            enum_value.name = self.__names.NormalizeEnumName(enum_name)
            if enum_value.name != enum_name:
                # The Python name differs from the JSON name, so record
                # the mapping and import the module used to register it.
                message.enum_mappings.append(
                    extended_descriptor.ExtendedEnumDescriptor.JsonEnumMapping(
                        python_name=enum_value.name, json_name=enum_name))
                self.__AddImport('from %s import encoding' %
                                 self.__base_files_package)
            enum_value.number = index
            enum_value.description = util.CleanDescription(
                enum_description or '<no description>')
            message.values.append(enum_value)
        self.__RegisterDescriptor(message)

    def __DeclareMessageAlias(self, schema, alias_for):
        """Declare schema as an alias for alias_for."""
        # TODO(craigcitro): This is a hack. Remove it.
        message = extended_descriptor.ExtendedMessageDescriptor()
        message.name = self.__names.ClassName(schema['id'])
        message.alias_for = alias_for
        self.__DeclareDescriptor(message.name)
        self.__AddImport('from %s import extra_types' %
                         self.__base_files_package)
        self.__RegisterDescriptor(message)

    def __AddAdditionalProperties(self, message, schema, properties):
        """Add an additionalProperties field to message.

        Creates a nested AdditionalProperty message type and appends a
        repeated `additionalProperties` field of that type, numbered
        after the message's declared properties.
        """
        additional_properties_info = schema['additionalProperties']
        entries_type_name = self.__AddAdditionalPropertyType(
            message.name, additional_properties_info)
        description = util.CleanDescription(
            additional_properties_info.get('description'))
        if description is None:
            description = 'Additional properties of type %s' % message.name
        attrs = {
            'items': {
                '$ref': entries_type_name,
            },
            'description': description,
            'type': 'array',
        }
        field_name = 'additionalProperties'
        message.fields.append(self.__FieldDescriptorFromProperties(
            field_name, len(properties) + 1, attrs))
        self.__AddImport('from %s import encoding' % self.__base_files_package)
        message.decorators.append(
            'encoding.MapUnrecognizedFields(%r)' % field_name)

    def AddDescriptorFromSchema(self, schema_name, schema):
        """Add a new MessageDescriptor named schema_name based on schema.

        Does nothing if a descriptor with that name already exists in
        the current scope. Schemas with an `enum` become enums, schemas
        of type `any` become aliases for extra_types.JsonValue, and
        schemas of type `object` become messages.

        Raises:
          ValueError: if the schema is of any other type.
        """
        # TODO(craigcitro): Is schema_name redundant?
        if self.__GetDescriptor(schema_name):
            return
        if schema.get('enum'):
            self.__DeclareEnum(schema_name, schema)
            return
        if schema.get('type') == 'any':
            self.__DeclareMessageAlias(schema, 'extra_types.JsonValue')
            return
        if schema.get('type') != 'object':
            # Interpolate the type so the message is actually formatted.
            raise ValueError('Cannot create message descriptors for type %s'
                             % schema.get('type'))
        message = extended_descriptor.ExtendedMessageDescriptor()
        message.name = self.__names.ClassName(schema['id'])
        message.description = util.CleanDescription(schema.get(
            'description', 'A %s object.' % message.name))
        self.__DeclareDescriptor(message.name)
        with self.__DescriptorEnv(message):
            properties = schema.get('properties', {})
            # Sort by property name so field numbers are deterministic.
            for index, (name, attrs) in enumerate(sorted(properties.items())):
                field = self.__FieldDescriptorFromProperties(
                    name, index + 1, attrs)
                message.fields.append(field)
                if field.name != name:
                    message.field_mappings.append(
                        type(message).JsonFieldMapping(
                            python_name=field.name, json_name=name))
                    self.__AddImport(
                        'from %s import encoding' % self.__base_files_package)
            if 'additionalProperties' in schema:
                self.__AddAdditionalProperties(message, schema, properties)
        self.__RegisterDescriptor(message)

    def __AddAdditionalPropertyType(self, name, property_schema):
        """Add a new nested AdditionalProperty message."""
        new_type_name = 'AdditionalProperty'
        # Work on a copy so the caller's schema is left untouched.
        property_schema = dict(property_schema)
        # We drop the description here on purpose, so the resulting
        # messages are less repetitive.
        property_schema.pop('description', None)
        description = 'An additional property for a %s object.' % name
        schema = {
            'id': new_type_name,
            'type': 'object',
            'description': description,
            'properties': {
                'key': {
                    'type': 'string',
                    'description': 'Name of the additional property.',
                },
                'value': property_schema,
            },
        }
        self.AddDescriptorFromSchema(new_type_name, schema)
        return new_type_name

    def __AddEntryType(self, entry_type_name, entry_schema, parent_name):
        """Add a type for a list entry.

        Wraps entry_schema in a new message type with a single repeated
        `entry` field, and returns the new type's name.
        """
        # Copy before dropping the description, so the caller's schema
        # dict is not mutated (mirrors __AddAdditionalPropertyType).
        entry_schema = dict(entry_schema)
        entry_schema.pop('description', None)
        description = 'Single entry in a %s.' % parent_name
        schema = {
            'id': entry_type_name,
            'type': 'object',
            'description': description,
            'properties': {
                'entry': {
                    'type': 'array',
                    'items': entry_schema,
                },
            },
        }
        self.AddDescriptorFromSchema(entry_type_name, schema)
        return entry_type_name

    def __FieldDescriptorFromProperties(self, name, index, attrs):
        """Create a field descriptor for these attrs.

        Args:
          name: JSON property name; cleaned with names.CleanName.
          index: Field number to assign.
          attrs: Schema dict describing the property.

        Returns:
          An ExtendedFieldDescriptor wrapping the new FieldDescriptor.
        """
        field = descriptor.FieldDescriptor()
        field.name = self.__names.CleanName(name)
        field.number = index
        field.label = self.__ComputeLabel(attrs)
        new_type_name_hint = self.__names.ClassName(
            '%sValue' % self.__names.ClassName(name))
        type_info = self.__GetTypeInfo(attrs, new_type_name_hint)
        field.type_name = type_info.type_name
        field.variant = type_info.variant
        if 'default' in attrs:
            # TODO(craigcitro): Correctly handle non-primitive default values.
            default = attrs['default']
            # Strings and enum names are used as-is; anything else is
            # parsed as JSON and rendered with str().
            if not (field.type_name == 'string' or
                    field.variant == messages.Variant.ENUM):
                default = str(json.loads(default))
            if field.variant == messages.Variant.ENUM:
                default = self.__names.NormalizeEnumName(default)
            field.default_value = default
        extended_field = extended_descriptor.ExtendedFieldDescriptor()
        extended_field.name = field.name
        extended_field.description = util.CleanDescription(
            attrs.get('description', 'A %s attribute.' % field.type_name))
        extended_field.field_descriptor = field
        return extended_field

    @staticmethod
    def __ComputeLabel(attrs):
        """Return the field label for attrs.

        `required` takes precedence; otherwise arrays and schemas marked
        `repeated` are REPEATED, and everything else is OPTIONAL.
        """
        if attrs.get('required', False):
            return descriptor.FieldDescriptor.Label.REQUIRED
        elif attrs.get('type') == 'array':
            return descriptor.FieldDescriptor.Label.REPEATED
        elif attrs.get('repeated'):
            return descriptor.FieldDescriptor.Label.REPEATED
        return descriptor.FieldDescriptor.Label.OPTIONAL

    def __DeclareEnum(self, enum_name, attrs):
        """Register an enum described by attrs and return its TypeInfo."""
        description = util.CleanDescription(attrs.get('description', ''))
        enum_values = attrs['enum']
        # Without enumDescriptions, pair every value with an empty one.
        enum_descriptions = attrs.get(
            'enumDescriptions', [''] * len(enum_values))
        self.AddEnumDescriptor(enum_name, description,
                               enum_values, enum_descriptions)
        self.__AddIfUnknown(enum_name)
        return TypeInfo(type_name=enum_name, variant=messages.Variant.ENUM)

    def __AddIfUnknown(self, type_name):
        """Record type_name as unknown unless it is already registered.

        The name counts as registered if it is present either qualified
        by the current path or unqualified.
        """
        type_name = self.__names.ClassName(type_name)
        full_type_name = self.__ComputeFullName(type_name)
        if (full_type_name not in self.__message_registry and
                type_name not in self.__message_registry):
            self.__unknown_types.add(type_name)

    def __GetTypeInfo(self, attrs, name_hint):
        """Return a TypeInfo object for attrs, creating one if needed.

        Args:
          attrs: Schema dict; must carry a `$ref` or a `type`.
          name_hint: Name to use for any type created on the fly
              (inline enums, anonymous objects, list entries).

        Raises:
          ValueError: if attrs has no type, names an unknown type or
              type/format combination, is an array without items, or is
              an object with no usable name hint.
        """

        type_ref = self.__names.ClassName(attrs.get('$ref'))
        type_name = attrs.get('type')
        if not (type_ref or type_name):
            raise ValueError('No type found for %s' % attrs)

        if type_ref:
            self.__AddIfUnknown(type_ref)
            # We don't actually know this is a message -- it might be an
            # enum. However, we can't check that until we've created all the
            # types, so we come back and fix this up later.
            return TypeInfo(
                type_name=type_ref, variant=messages.Variant.MESSAGE)

        if 'enum' in attrs:
            enum_name = '%sValuesEnum' % name_hint
            return self.__DeclareEnum(enum_name, attrs)

        if 'format' in attrs:
            type_info = self.PRIMITIVE_FORMAT_MAP.get(attrs['format'])
            if type_info is None:
                # If we don't recognize the format, the spec says we fall back
                # to just using the type name.
                if type_name in self.PRIMITIVE_TYPE_INFO_MAP:
                    return self.PRIMITIVE_TYPE_INFO_MAP[type_name]
                # Arguments follow the order of the "type/format" label.
                raise ValueError('Unknown type/format "%s"/"%s"' % (
                    type_name, attrs['format']))
            if type_info.type_name.startswith(
                    ('protorpc.message_types.', 'message_types.')):
                self.__AddImport(
                    'from protorpc import message_types as _message_types')
            if type_info.type_name.startswith('extra_types.'):
                self.__AddImport(
                    'from %s import extra_types' % self.__base_files_package)
            return type_info

        if type_name in self.PRIMITIVE_TYPE_INFO_MAP:
            type_info = self.PRIMITIVE_TYPE_INFO_MAP[type_name]
            return type_info

        if type_name == 'array':
            items = attrs.get('items')
            if not items:
                raise ValueError('Array type with no item type: %s' % attrs)
            entry_name_hint = self.__names.ClassName(
                items.get('title') or '%sListEntry' % name_hint)
            entry_label = self.__ComputeLabel(items)
            if entry_label == descriptor.FieldDescriptor.Label.REPEATED:
                # The items are themselves repeated, so wrap the inner
                # items in a dedicated entry message type.
                parent_name = self.__names.ClassName(
                    items.get('title') or name_hint)
                entry_type_name = self.__AddEntryType(
                    entry_name_hint, items.get('items'), parent_name)
                return TypeInfo(type_name=entry_type_name,
                                variant=messages.Variant.MESSAGE)
            else:
                return self.__GetTypeInfo(items, entry_name_hint)
        elif type_name == 'any':
            self.__AddImport('from %s import extra_types' %
                             self.__base_files_package)
            return self.PRIMITIVE_TYPE_INFO_MAP['any']
        elif type_name == 'object':
            # TODO(craigcitro): Think of a better way to come up with names.
            if not name_hint:
                raise ValueError(
                    'Cannot create subtype without some name hint')
            # Copy so that adding the synthetic id leaves attrs untouched.
            schema = dict(attrs)
            schema['id'] = name_hint
            self.AddDescriptorFromSchema(name_hint, schema)
            self.__AddIfUnknown(name_hint)
            return TypeInfo(
                type_name=name_hint, variant=messages.Variant.MESSAGE)

        raise ValueError('Unknown type: %s' % type_name)

    def FixupMessageFields(self):
        """Correct the variant of fields whose $ref turned out to be an enum.

        Validates the registry first (via the file_descriptor property).
        """
        for message_type in self.file_descriptor.message_types:
            self._FixupMessage(message_type)

    def _FixupMessage(self, message_type):
        """Fix enum-typed fields in message_type and its nested messages."""
        with self.__DescriptorEnv(message_type):
            for field in message_type.fields:
                # References are recorded as MESSAGE when first seen;
                # switch to ENUM where the target is really an enum.
                if field.field_descriptor.variant == messages.Variant.MESSAGE:
                    field_type_name = field.field_descriptor.type_name
                    field_type = self.LookupDescriptor(field_type_name)
                    if isinstance(field_type,
                                  extended_descriptor.ExtendedEnumDescriptor):
                        field.field_descriptor.variant = messages.Variant.ENUM
            for submessage_type in message_type.message_types:
                self._FixupMessage(submessage_type)