1#!/usr/bin/env python
2"""Message registry for apitools."""
3
4import collections
5import contextlib
6import json
7
8from protorpc import descriptor
9from protorpc import messages
10import six
11
12from apitools.gen import extended_descriptor
13from apitools.gen import util
14
# Lightweight record pairing the name of a field's type with the variant
# value used for that field.
TypeInfo = collections.namedtuple('TypeInfo', ['type_name', 'variant'])
16
17
class MessageRegistry(object):

    """Registry for message types.

    This closely mirrors a messages.FileDescriptor, but adds additional
    attributes (such as message and field descriptions) and some extra
    code for validation and cycle detection.

    While a type is being built it is tracked as "nascent"; a type that has
    been referenced but not (yet) defined is tracked as "unknown".
    Validate() fails while either set is non-empty.
    """

    # Type information from these two maps comes from here:
    #  https://developers.google.com/discovery/v1/type-format
    #
    # Keyed by the schema's `type` value.
    PRIMITIVE_TYPE_INFO_MAP = {
        'string': TypeInfo(type_name='string',
                           variant=messages.StringField.DEFAULT_VARIANT),
        'integer': TypeInfo(type_name='integer',
                            variant=messages.IntegerField.DEFAULT_VARIANT),
        'boolean': TypeInfo(type_name='boolean',
                            variant=messages.BooleanField.DEFAULT_VARIANT),
        'number': TypeInfo(type_name='number',
                           variant=messages.FloatField.DEFAULT_VARIANT),
        'any': TypeInfo(type_name='extra_types.JsonValue',
                        variant=messages.Variant.MESSAGE),
    }

    # Keyed by the schema's `format` value; consulted before the plain
    # `type` lookup in __GetTypeInfo.
    PRIMITIVE_FORMAT_MAP = {
        'int32': TypeInfo(type_name='integer',
                          variant=messages.Variant.INT32),
        'uint32': TypeInfo(type_name='integer',
                           variant=messages.Variant.UINT32),
        'int64': TypeInfo(type_name='string',
                          variant=messages.Variant.INT64),
        'uint64': TypeInfo(type_name='string',
                           variant=messages.Variant.UINT64),
        'double': TypeInfo(type_name='number',
                           variant=messages.Variant.DOUBLE),
        'float': TypeInfo(type_name='number',
                          variant=messages.Variant.FLOAT),
        'byte': TypeInfo(type_name='byte',
                         variant=messages.BytesField.DEFAULT_VARIANT),
        'date': TypeInfo(type_name='extra_types.DateField',
                         variant=messages.Variant.STRING),
        'date-time': TypeInfo(
            type_name='protorpc.message_types.DateTimeMessage',
            variant=messages.Variant.MESSAGE),
    }

    def __init__(self, client_info, names, description,
                 root_package_dir, base_files_package):
        """Create an empty registry.

        Args:
          client_info: Object whose `package` and `version` attributes are
              used for the file descriptor and when writing output.
          names: Naming helper providing ClassName, CleanName and
              NormalizeEnumName.
          description: Description for the generated messages file.
          root_package_dir: Stored on the registry; not otherwise used by
              this class.
          base_files_package: Package name used to build the
              `from <package> import ...` lines added to the output.
        """
        self.__names = names
        self.__client_info = client_info
        self.__package = client_info.package
        self.__description = util.CleanDescription(description)
        self.__root_package_dir = root_package_dir
        self.__base_files_package = base_files_package
        self.__file_descriptor = extended_descriptor.ExtendedFileDescriptor(
            package=self.__package, description=self.__description)
        # Add required imports
        self.__file_descriptor.additional_imports = [
            'from protorpc import messages as _messages',
        ]
        # Map from scoped names (i.e. Foo.Bar) to MessageDescriptors.
        self.__message_registry = collections.OrderedDict()
        # A set of types that we're currently adding (for cycle detection).
        self.__nascent_types = set()
        # A set of types for which we've seen a reference but no
        # definition; if this set is nonempty, validation fails.
        self.__unknown_types = set()
        # Used for tracking paths during message creation
        self.__current_path = []
        # Where to register created messages
        self.__current_env = self.__file_descriptor
        # TODO(craigcitro): Add a `Finalize` method.

    @property
    def file_descriptor(self):
        """The underlying file descriptor; validates the registry first."""
        self.Validate()
        return self.__file_descriptor

    def WriteProtoFile(self, printer):
        """Write the messages file to out as proto."""
        self.Validate()
        extended_descriptor.WriteMessagesFile(
            self.__file_descriptor, self.__package, self.__client_info.version,
            printer)

    def WriteFile(self, printer):
        """Write the messages file to out."""
        self.Validate()
        extended_descriptor.WritePythonFile(
            self.__file_descriptor, self.__package, self.__client_info.version,
            printer)

    def Validate(self):
        """Raise ValueError if any type is still nascent or unknown."""
        mysteries = self.__nascent_types or self.__unknown_types
        if mysteries:
            raise ValueError('Malformed MessageRegistry: %s' % mysteries)

    def __ComputeFullName(self, name):
        """Return name qualified by the current path, joined with dots."""
        # List concatenation already builds a new list, so the tracked
        # path itself is never modified here.
        return '.'.join(map(six.text_type, self.__current_path + [name]))

    def __AddImport(self, new_import):
        """Add an import line to the output file, skipping duplicates."""
        if new_import not in self.__file_descriptor.additional_imports:
            self.__file_descriptor.additional_imports.append(new_import)

    def __AddImportsForTypeInfo(self, type_info):
        """Add the imports that type_info.type_name refers to, if any."""
        if type_info.type_name.startswith(
                ('protorpc.message_types.', 'message_types.')):
            self.__AddImport(
                'from protorpc import message_types as _message_types')
        if type_info.type_name.startswith('extra_types.'):
            self.__AddImport(
                'from %s import extra_types' % self.__base_files_package)

    def __DeclareDescriptor(self, name):
        """Mark name (in the current scope) as a type under construction."""
        self.__nascent_types.add(self.__ComputeFullName(name))

    def __RegisterDescriptor(self, new_descriptor):
        """Register the given descriptor in this registry.

        The descriptor must previously have been declared with
        __DeclareDescriptor in the same scope.

        Raises:
          ValueError: if the descriptor is not a message or enum descriptor,
              is already registered, or was never declared.
        """
        if not isinstance(new_descriptor, (
                extended_descriptor.ExtendedMessageDescriptor,
                extended_descriptor.ExtendedEnumDescriptor)):
            raise ValueError('Cannot add descriptor of type %s' % (
                type(new_descriptor),))
        full_name = self.__ComputeFullName(new_descriptor.name)
        if full_name in self.__message_registry:
            raise ValueError(
                'Attempt to re-register descriptor %s' % full_name)
        if full_name not in self.__nascent_types:
            raise ValueError('Directly adding types is not supported')
        new_descriptor.full_name = full_name
        self.__message_registry[full_name] = new_descriptor
        if isinstance(new_descriptor,
                      extended_descriptor.ExtendedMessageDescriptor):
            self.__current_env.message_types.append(new_descriptor)
        else:
            # The type check above leaves enum descriptors as the only
            # other possibility.
            self.__current_env.enum_types.append(new_descriptor)
        self.__unknown_types.discard(full_name)
        self.__nascent_types.remove(full_name)

    def LookupDescriptor(self, name):
        """Return the descriptor registered as name, or None if absent.

        Raises:
          ValueError: if name refers to a type currently being created.
        """
        return self.__GetDescriptorByName(name)

    def LookupDescriptorOrDie(self, name):
        """Return the descriptor registered as name.

        Raises:
          ValueError: if there is no such descriptor, or if name refers to
              a type currently being created.
        """
        message_descriptor = self.LookupDescriptor(name)
        if message_descriptor is None:
            raise ValueError('No message descriptor named "%s"' % name)
        return message_descriptor

    def __GetDescriptor(self, name):
        """Look up name relative to the current scope."""
        return self.__GetDescriptorByName(self.__ComputeFullName(name))

    def __GetDescriptorByName(self, name):
        """Look up a fully-scoped name; None if it has not been registered."""
        if name in self.__message_registry:
            return self.__message_registry[name]
        if name in self.__nascent_types:
            raise ValueError(
                'Cannot retrieve type currently being created: %s' % name)
        return None

    @contextlib.contextmanager
    def __DescriptorEnv(self, message_descriptor):
        """Context manager that makes message_descriptor the current scope.

        Inside the `with` body, new descriptors are registered under
        message_descriptor and names are qualified with its name.
        """
        # TODO(craigcitro): Typecheck?
        previous_env = self.__current_env
        self.__current_path.append(message_descriptor.name)
        self.__current_env = message_descriptor
        try:
            yield
        finally:
            # Restore the enclosing scope even when the body raises, so a
            # failure does not leave the path/env pointing at a dead scope.
            self.__current_path.pop()
            self.__current_env = previous_env

    def AddEnumDescriptor(self, name, description,
                          enum_values, enum_descriptions):
        """Add a new EnumDescriptor named name with the given enum values.

        Args:
          name: Name of the enum; converted with names.ClassName.
          description: Description of the enum as a whole.
          enum_values: Iterable of enum value names, in numbering order.
          enum_descriptions: Iterable of per-value descriptions, matched
              positionally with enum_values.
        """
        message = extended_descriptor.ExtendedEnumDescriptor()
        message.name = self.__names.ClassName(name)
        message.description = util.CleanDescription(description)
        self.__DeclareDescriptor(message.name)
        value_names = list(enum_values)
        value_descriptions = list(enum_descriptions)
        # zip() stops at the shorter input; pad the descriptions so that a
        # short description list cannot silently drop enum values.
        missing = len(value_names) - len(value_descriptions)
        if missing > 0:
            value_descriptions.extend([''] * missing)
        for index, (enum_name, enum_description) in enumerate(
                zip(value_names, value_descriptions)):
            enum_value = extended_descriptor.ExtendedEnumValueDescriptor()
            enum_value.name = self.__names.NormalizeEnumName(enum_name)
            if enum_value.name != enum_name:
                # The normalized name differs from the original name, so
                # record the mapping between the two.
                message.enum_mappings.append(
                    extended_descriptor.ExtendedEnumDescriptor.JsonEnumMapping(
                        python_name=enum_value.name, json_name=enum_name))
                self.__AddImport('from %s import encoding' %
                                 self.__base_files_package)
            enum_value.number = index
            enum_value.description = util.CleanDescription(
                enum_description or '<no description>')
            message.values.append(enum_value)
        self.__RegisterDescriptor(message)

    def __DeclareMessageAlias(self, schema, alias_for):
        """Declare schema as an alias for alias_for."""
        # TODO(craigcitro): This is a hack. Remove it.
        message = extended_descriptor.ExtendedMessageDescriptor()
        message.name = self.__names.ClassName(schema['id'])
        message.alias_for = alias_for
        self.__DeclareDescriptor(message.name)
        self.__AddImport('from %s import extra_types' %
                         self.__base_files_package)
        self.__RegisterDescriptor(message)

    def __AddAdditionalProperties(self, message, schema, properties):
        """Add an additionalProperties field to message.

        Args:
          message: The message descriptor being built.
          schema: The schema for message; must have 'additionalProperties'.
          properties: The declared properties of message; the new field is
              numbered directly after them.
        """
        additional_properties_info = schema['additionalProperties']
        entries_type_name = self.__AddAdditionalPropertyType(
            message.name, additional_properties_info)
        description = util.CleanDescription(
            additional_properties_info.get('description'))
        if description is None:
            description = 'Additional properties of type %s' % message.name
        attrs = {
            'items': {
                '$ref': entries_type_name,
            },
            'description': description,
            'type': 'array',
        }
        field_name = 'additionalProperties'
        message.fields.append(self.__FieldDescriptorFromProperties(
            field_name, len(properties) + 1, attrs))
        self.__AddImport('from %s import encoding' % self.__base_files_package)
        message.decorators.append(
            'encoding.MapUnrecognizedFields(%r)' % field_name)

    def AddDescriptorFromSchema(self, schema_name, schema):
        """Add a new MessageDescriptor named schema_name based on schema.

        Does nothing if schema_name is already registered in the current
        scope. Schemas with an 'enum' become enums, schemas of type 'any'
        become aliases for extra_types.JsonValue, and 'object' schemas
        become messages with one field per property.

        Raises:
          ValueError: if the schema is of any other type.
        """
        # TODO(craigcitro): Is schema_name redundant?
        if self.__GetDescriptor(schema_name):
            return
        if schema.get('enum'):
            self.__DeclareEnum(schema_name, schema)
            return
        if schema.get('type') == 'any':
            self.__DeclareMessageAlias(schema, 'extra_types.JsonValue')
            return
        if schema.get('type') != 'object':
            raise ValueError('Cannot create message descriptors for type %s' %
                             schema.get('type'))
        message = extended_descriptor.ExtendedMessageDescriptor()
        message.name = self.__names.ClassName(schema['id'])
        message.description = util.CleanDescription(schema.get(
            'description', 'A %s object.' % message.name))
        self.__DeclareDescriptor(message.name)
        with self.__DescriptorEnv(message):
            properties = schema.get('properties', {})
            # Sorted by property name so field numbers are deterministic.
            for index, (name, attrs) in enumerate(sorted(properties.items())):
                field = self.__FieldDescriptorFromProperties(
                    name, index + 1, attrs)
                message.fields.append(field)
                if field.name != name:
                    message.field_mappings.append(
                        type(message).JsonFieldMapping(
                            python_name=field.name, json_name=name))
                    self.__AddImport(
                        'from %s import encoding' % self.__base_files_package)
            if 'additionalProperties' in schema:
                self.__AddAdditionalProperties(message, schema, properties)
        self.__RegisterDescriptor(message)

    def __AddAdditionalPropertyType(self, name, property_schema):
        """Add a new nested AdditionalProperty message.

        Args:
          name: Name of the enclosing message, used in the description.
          property_schema: Schema of the additional property values; it is
              copied, not modified.

        Returns:
          The name of the new type.
        """
        new_type_name = 'AdditionalProperty'
        property_schema = dict(property_schema)
        # We drop the description here on purpose, so the resulting
        # messages are less repetitive.
        property_schema.pop('description', None)
        description = 'An additional property for a %s object.' % name
        schema = {
            'id': new_type_name,
            'type': 'object',
            'description': description,
            'properties': {
                'key': {
                    'type': 'string',
                    'description': 'Name of the additional property.',
                },
                'value': property_schema,
            },
        }
        self.AddDescriptorFromSchema(new_type_name, schema)
        return new_type_name

    def __AddEntryType(self, entry_type_name, entry_schema, parent_name):
        """Add a type for a list entry.

        Args:
          entry_type_name: Name of the message to create.
          entry_schema: Schema of the items held by each entry; it is
              copied, not modified.
          parent_name: Name of the containing list, used in the description.

        Returns:
          entry_type_name.

        Raises:
          ValueError: if entry_schema is None.
        """
        if entry_schema is None:
            raise ValueError(
                'Repeated entry type with no item type: %s' % entry_type_name)
        # Work on a copy so the caller's schema keeps its description.
        entry_schema = dict(entry_schema)
        entry_schema.pop('description', None)
        description = 'Single entry in a %s.' % parent_name
        schema = {
            'id': entry_type_name,
            'type': 'object',
            'description': description,
            'properties': {
                'entry': {
                    'type': 'array',
                    'items': entry_schema,
                },
            },
        }
        self.AddDescriptorFromSchema(entry_type_name, schema)
        return entry_type_name

    def __FieldDescriptorFromProperties(self, name, index, attrs):
        """Create a field descriptor for these attrs.

        Args:
          name: Property name; cleaned with names.CleanName.
          index: Field number to assign.
          attrs: Schema of the property.

        Returns:
          An ExtendedFieldDescriptor wrapping the new FieldDescriptor.
        """
        field = descriptor.FieldDescriptor()
        field.name = self.__names.CleanName(name)
        field.number = index
        field.label = self.__ComputeLabel(attrs)
        new_type_name_hint = self.__names.ClassName(
            '%sValue' % self.__names.ClassName(name))
        type_info = self.__GetTypeInfo(attrs, new_type_name_hint)
        field.type_name = type_info.type_name
        field.variant = type_info.variant
        if 'default' in attrs:
            # TODO(craigcitro): Correctly handle non-primitive default values.
            default = attrs['default']
            # Strings and enum names are used as given; anything else is
            # parsed as JSON and converted back to a string.
            if not (field.type_name == 'string' or
                    field.variant == messages.Variant.ENUM):
                default = str(json.loads(default))
            if field.variant == messages.Variant.ENUM:
                default = self.__names.NormalizeEnumName(default)
            field.default_value = default
        extended_field = extended_descriptor.ExtendedFieldDescriptor()
        extended_field.name = field.name
        extended_field.description = util.CleanDescription(
            attrs.get('description', 'A %s attribute.' % field.type_name))
        extended_field.field_descriptor = field
        return extended_field

    @staticmethod
    def __ComputeLabel(attrs):
        """Return the field label for attrs.

        'required' wins over repetition; arrays and attrs flagged as
        'repeated' are REPEATED; everything else is OPTIONAL.
        """
        if attrs.get('required', False):
            return descriptor.FieldDescriptor.Label.REQUIRED
        elif attrs.get('type') == 'array':
            return descriptor.FieldDescriptor.Label.REPEATED
        elif attrs.get('repeated'):
            return descriptor.FieldDescriptor.Label.REPEATED
        return descriptor.FieldDescriptor.Label.OPTIONAL

    def __DeclareEnum(self, enum_name, attrs):
        """Register an enum built from attrs and return its TypeInfo."""
        description = util.CleanDescription(attrs.get('description', ''))
        enum_values = attrs['enum']
        enum_descriptions = attrs.get(
            'enumDescriptions', [''] * len(enum_values))
        self.AddEnumDescriptor(enum_name, description,
                               enum_values, enum_descriptions)
        self.__AddIfUnknown(enum_name)
        return TypeInfo(type_name=enum_name, variant=messages.Variant.ENUM)

    def __AddIfUnknown(self, type_name):
        """Record type_name as unknown unless it is already registered.

        The name counts as registered if either its scoped or its bare
        form is present in the registry.
        """
        type_name = self.__names.ClassName(type_name)
        full_type_name = self.__ComputeFullName(type_name)
        if (full_type_name not in self.__message_registry and
                type_name not in self.__message_registry):
            self.__unknown_types.add(type_name)

    def __GetTypeInfo(self, attrs, name_hint):
        """Return a TypeInfo object for attrs, creating one if needed.

        Args:
          attrs: Schema of the property whose type is wanted.
          name_hint: Name to use for any new type this needs to create.

        Raises:
          ValueError: if attrs has neither '$ref' nor 'type', has a type or
              type/format combination that is not recognized, or describes
              an array without items.
        """

        type_ref = self.__names.ClassName(attrs.get('$ref'))
        type_name = attrs.get('type')
        if not (type_ref or type_name):
            raise ValueError('No type found for %s' % attrs)

        if type_ref:
            self.__AddIfUnknown(type_ref)
            # We don't actually know this is a message -- it might be an
            # enum. However, we can't check that until we've created all the
            # types, so we come back and fix this up later.
            return TypeInfo(
                type_name=type_ref, variant=messages.Variant.MESSAGE)

        if 'enum' in attrs:
            enum_name = '%sValuesEnum' % name_hint
            return self.__DeclareEnum(enum_name, attrs)

        if 'format' in attrs:
            type_info = self.PRIMITIVE_FORMAT_MAP.get(attrs['format'])
            if type_info is None:
                # If we don't recognize the format, the spec says we fall back
                # to just using the type name.
                type_info = self.PRIMITIVE_TYPE_INFO_MAP.get(type_name)
                if type_info is None:
                    raise ValueError('Unknown type/format "%s"/"%s"' % (
                        type_name, attrs['format']))
            self.__AddImportsForTypeInfo(type_info)
            return type_info

        if type_name in self.PRIMITIVE_TYPE_INFO_MAP:
            type_info = self.PRIMITIVE_TYPE_INFO_MAP[type_name]
            # 'any' maps to extra_types.JsonValue, which needs an import.
            self.__AddImportsForTypeInfo(type_info)
            return type_info

        if type_name == 'array':
            items = attrs.get('items')
            if not items:
                raise ValueError('Array type with no item type: %s' % attrs)
            entry_name_hint = self.__names.ClassName(
                items.get('title') or '%sListEntry' % name_hint)
            entry_label = self.__ComputeLabel(items)
            if entry_label == descriptor.FieldDescriptor.Label.REPEATED:
                # The items are themselves repeated, so wrap them in a new
                # message type holding a single repeated `entry` field.
                parent_name = self.__names.ClassName(
                    items.get('title') or name_hint)
                entry_type_name = self.__AddEntryType(
                    entry_name_hint, items.get('items'), parent_name)
                return TypeInfo(type_name=entry_type_name,
                                variant=messages.Variant.MESSAGE)
            return self.__GetTypeInfo(items, entry_name_hint)

        if type_name == 'object':
            # TODO(craigcitro): Think of a better way to come up with names.
            if not name_hint:
                raise ValueError(
                    'Cannot create subtype without some name hint')
            schema = dict(attrs)
            schema['id'] = name_hint
            self.AddDescriptorFromSchema(name_hint, schema)
            self.__AddIfUnknown(name_hint)
            return TypeInfo(
                type_name=name_hint, variant=messages.Variant.MESSAGE)

        raise ValueError('Unknown type: %s' % type_name)

    def FixupMessageFields(self):
        """Correct the variant of message fields whose type is an enum.

        Walks every top-level message (and its nested messages). Accessing
        file_descriptor validates the registry first.
        """
        for message_type in self.file_descriptor.message_types:
            self._FixupMessage(message_type)

    def _FixupMessage(self, message_type):
        """Apply the enum-variant fixup to message_type and its submessages."""
        with self.__DescriptorEnv(message_type):
            for field in message_type.fields:
                if field.field_descriptor.variant == messages.Variant.MESSAGE:
                    field_type_name = field.field_descriptor.type_name
                    field_type = self.LookupDescriptor(field_type_name)
                    # $ref fields are provisionally marked MESSAGE; switch
                    # to ENUM once the target is known to be an enum.
                    if isinstance(field_type,
                                  extended_descriptor.ExtendedEnumDescriptor):
                        field.field_descriptor.variant = messages.Variant.ENUM
            for submessage_type in message_type.message_types:
                self._FixupMessage(submessage_type)
457