1#!/usr/bin/env python
2"""Extended protorpc descriptors.
3
4This takes existing protorpc Descriptor classes and adds extra
5properties not directly supported in proto itself, notably field and
6message descriptions. We need this in order to generate protorpc
7message files with comments.
8
9Note that for most of these classes, we can't simply wrap the existing
10message, since we need to change the type of the subfields. We could
11have a "plain" descriptor attached, but that seems like unnecessary
12bookkeeping. Where possible, we purposely reuse existing tag numbers;
13for new fields, we start numbering at 100.
14"""
15import abc
16import operator
17import textwrap
18
19from protorpc import descriptor as protorpc_descriptor
20from protorpc import message_types
21from protorpc import messages
22import six
23
24import apitools.base.py as apitools_base
25
26
class ExtendedEnumValueDescriptor(messages.Message):

    """Enum value descriptor with additional fields.

    Fields:
      name: Name of enumeration value.
      number: Number of enumeration value.
      description: Description of this enum value.
    """
    # Per the module docstring, tags below 100 are meant to reuse the tag
    # numbers of the existing protorpc descriptor fields; fields added by
    # this extension start at 100.
    name = messages.StringField(1)
    number = messages.IntegerField(2, variant=messages.Variant.INT32)

    # Extension field (not part of the plain protorpc descriptor).
    description = messages.StringField(100)
40
41
class ExtendedEnumDescriptor(messages.Message):

    """Enum class descriptor with additional fields.

    Fields:
      name: Name of Enum without any qualification.
      values: Values defined by Enum class.
      description: Description of this enum class.
      full_name: Fully qualified name of this enum class.
      enum_mappings: Mappings from python to JSON names for enum values.
    """

    class JsonEnumMapping(messages.Message):

        """Mapping from a python name to the wire name for an enum."""
        python_name = messages.StringField(1)
        json_name = messages.StringField(2)

    # Tags below 100 follow the existing protorpc descriptor numbering
    # (see module docstring).
    name = messages.StringField(1)
    values = messages.MessageField(
        ExtendedEnumValueDescriptor, 2, repeated=True)

    # Extension fields start at 100.
    description = messages.StringField(100)
    full_name = messages.StringField(101)
    # The nested mapping type is referenced by its name string. Each entry is
    # turned into an encoding.AddCustomJsonEnumMapping(...) call by the
    # custom-mapping printing code in this module.
    enum_mappings = messages.MessageField(
        'JsonEnumMapping', 102, repeated=True)
68
69
class ExtendedFieldDescriptor(messages.Message):

    """Field descriptor with additional fields.

    Fields:
      field_descriptor: The underlying field descriptor.
      name: The name of this field.
      description: Description of this field.
    """
    # Unlike the other extended descriptors, this one wraps the plain protorpc
    # FieldDescriptor instead of copying its fields, so every field here is an
    # extension and numbering starts at 100.
    field_descriptor = messages.MessageField(
        protorpc_descriptor.FieldDescriptor, 100)
    # We duplicate the names for easier bookkeeping.
    name = messages.StringField(101)
    description = messages.StringField(102)
84
85
class ExtendedMessageDescriptor(messages.Message):

    """Message descriptor with additional fields.

    Fields:
      name: Name of Message without any qualification.
      fields: Fields defined for message.
      message_types: Nested Message classes defined on message.
      enum_types: Nested Enum classes defined on message.
      description: Description of this message.
      full_name: Fully qualified name of this message.
      decorators: Decorators to include in the definition when printing.
          Printed in the given order from top to bottom (so the last entry
          is the innermost decorator).
      alias_for: This type is just an alias for the named type.
      field_mappings: Mappings from python to json field names.
    """

    class JsonFieldMapping(messages.Message):

        """Mapping from a python name to the wire name for a field."""
        python_name = messages.StringField(1)
        json_name = messages.StringField(2)

    # Tags below 100 follow the existing protorpc descriptor numbering
    # (see module docstring).
    name = messages.StringField(1)
    fields = messages.MessageField(ExtendedFieldDescriptor, 2, repeated=True)
    # Self-referential: the class object does not exist yet while its body is
    # executing, so the nested message type must be named by string.
    message_types = messages.MessageField(
        'extended_descriptor.ExtendedMessageDescriptor', 3, repeated=True)
    enum_types = messages.MessageField(
        ExtendedEnumDescriptor, 4, repeated=True)

    # Extension fields start at 100.
    description = messages.StringField(100)
    full_name = messages.StringField(101)
    decorators = messages.StringField(102, repeated=True)
    alias_for = messages.StringField(103)
    # Each entry becomes an encoding.AddCustomJsonFieldMapping(...) call in
    # the generated output.
    field_mappings = messages.MessageField(
        'JsonFieldMapping', 104, repeated=True)
123
124
class ExtendedFileDescriptor(messages.Message):

    """File descriptor with additional fields.

    Fields:
      package: Fully qualified name of package that definitions belong to.
      message_types: Message definitions contained in file.
      enum_types: Enum definitions contained in file.
      description: Description of this file.
      additional_imports: Extra imports used in this package.
    """
    # Tags below 100 are intended to match the corresponding protorpc
    # FileDescriptor fields (see module docstring); the gaps are tags of
    # FileDescriptor fields not carried over here.
    package = messages.StringField(2)

    message_types = messages.MessageField(
        ExtendedMessageDescriptor, 4, repeated=True)
    enum_types = messages.MessageField(
        ExtendedEnumDescriptor, 5, repeated=True)

    # Extension fields start at 100. description is optional and may be
    # unset (None).
    description = messages.StringField(100)
    # Full import statements, printed verbatim by the ProtoRPC printer.
    additional_imports = messages.StringField(101, repeated=True)
145
146
def _WriteFile(file_descriptor, package, version, proto_printer):
    """Render an extended file descriptor through a ProtoPrinter.

    The preamble comes first, then all top-level enums, then all top-level
    messages, and finally every custom JSON name mapping (enum mappings
    before message/field mappings).

    Args:
      file_descriptor: ExtendedFileDescriptor to render.
      package: API package name passed to the preamble.
      version: API version passed to the preamble.
      proto_printer: ProtoPrinter implementation that produces the output.
    """
    proto_printer.PrintPreamble(package, version, file_descriptor)
    _PrintEnums(proto_printer, file_descriptor.enum_types)
    _PrintMessages(proto_printer, file_descriptor.message_types)

    owning_package = file_descriptor.package
    enum_mappings = _FetchCustomMappings(
        file_descriptor.enum_types, owning_package)
    message_mappings = _FetchCustomMappings(
        file_descriptor.message_types, owning_package)
    for mapping_source in enum_mappings + message_mappings:
        proto_printer.PrintCustomJsonMapping(mapping_source)
159
160
def WriteMessagesFile(file_descriptor, package, version, printer):
    """Write an extended file descriptor as a proto2 message file.

    Args:
      file_descriptor: ExtendedFileDescriptor to write.
      package: API package name used in the generated header.
      version: API version used in the generated header.
      printer: Line printer that receives the output.
    """
    proto2_printer = _Proto2Printer(printer)
    _WriteFile(file_descriptor, package, version, proto2_printer)
165
166
def WritePythonFile(file_descriptor, package, version, printer):
    """Write an extended file descriptor as ProtoRPC Python source.

    Args:
      file_descriptor: ExtendedFileDescriptor to write.
      package: API package name used in the generated module docstring.
      version: API version used in the generated module docstring.
      printer: Line printer that receives the output.
    """
    protorpc_printer = _ProtoRpcPrinter(printer)
    _WriteFile(file_descriptor, package, version, protorpc_printer)
171
172
def PrintIndentedDescriptions(printer, ls, name, prefix=''):
    """Print a titled, wrapped list of ``name: description`` entries.

    Nothing is printed when ``ls`` is empty. Otherwise, inside the printer's
    ``prefix`` indent and comment context, a blank line and a ``<name>:``
    heading are printed, followed by one wrapped entry per item. The first
    line of an entry is indented two spaces and continuation lines four.

    Args:
      printer: Line printer providing Indent, CommentContext and
          CalculateWidth.
      ls: Items exposing ``name`` and ``description`` attributes.
      name: Heading printed above the entries.
      prefix: Indent string handed to printer.Indent; also subtracted from
          the available wrap width.
    """
    if not ls:
        return
    with printer.Indent(indent=prefix), printer.CommentContext():
        # Width is queried inside both contexts, since they may affect it.
        wrap_width = printer.CalculateWidth() - len(prefix)
        printer()
        printer(name + ':')
        for item in ls:
            entry = '%s: %s' % (item.name, item.description)
            wrapped_lines = textwrap.wrap(
                entry, wrap_width,
                initial_indent='  ', subsequent_indent='    ')
            for text_line in wrapped_lines:
                printer(text_line)
186
187
def _FetchCustomMappings(descriptor_ls, package):
    """Collect source text for every custom JSON mapping in descriptor_ls.

    Enum descriptors contribute their enum_mappings. Message descriptors
    contribute their field_mappings, then (recursively) those of their
    nested enums, then those of their nested messages. Descriptors of any
    other type are ignored.

    Args:
      descriptor_ls: Iterable of extended enum/message descriptors.
      package: Package name embedded in each generated mapping call.

    Returns:
      A list of formatted mapping-registration source strings.
    """
    collected = []
    for desc in descriptor_ls:
        if isinstance(desc, ExtendedEnumDescriptor):
            collected += [
                _FormatCustomJsonMapping('Enum', enum_mapping, desc, package)
                for enum_mapping in desc.enum_mappings]
            continue
        if not isinstance(desc, ExtendedMessageDescriptor):
            continue
        collected += [
            _FormatCustomJsonMapping('Field', field_mapping, desc, package)
            for field_mapping in desc.field_mappings]
        collected += _FetchCustomMappings(desc.enum_types, package)
        collected += _FetchCustomMappings(desc.message_types, package)
    return collected
205
206
207def _FormatCustomJsonMapping(mapping_type, mapping, descriptor, package):
208    return '\n'.join((
209        'encoding.AddCustomJson%sMapping(' % mapping_type,
210        "    %s, '%s', '%s'," % (descriptor.full_name, mapping.python_name,
211                                 mapping.json_name),
212        '    package=%r)' % package,
213    ))
214
215
216def _EmptyMessage(message_type):
217    return not any((message_type.enum_types,
218                    message_type.message_types,
219                    message_type.fields))
220
221
class ProtoPrinter(six.with_metaclass(abc.ABCMeta, object)):

    """Interface for proto printers.

    _WriteFile calls PrintPreamble once, then PrintEnum / PrintMessage for
    definitions, and finally PrintCustomJsonMapping for each custom JSON
    name mapping found in the file.
    """

    @abc.abstractmethod
    def PrintPreamble(self, package, version, file_descriptor):
        """Print the file docstring and import lines."""

    @abc.abstractmethod
    def PrintEnum(self, enum_type):
        """Print the given enum declaration."""

    @abc.abstractmethod
    def PrintMessage(self, message_type):
        """Print the given message declaration."""

    def PrintCustomJsonMapping(self, mapping):
        """Print one custom JSON mapping registration.

        Deliberately not abstract, so existing subclasses that never emit
        custom mappings remain instantiable. They fail only if a mapping
        is actually requested.

        Args:
          mapping: Source text of the mapping-registration call.

        Raises:
          NotImplementedError: If this printer does not support custom JSON
              mappings.
        """
        raise NotImplementedError(
            '%s does not support custom JSON mappings' % type(self).__name__)
237
238
class _Proto2Printer(ProtoPrinter):

    """Printer for proto2 definitions.

    Output goes through the wrapped printer, which is called with a format
    string followed by its arguments.
    """

    def __init__(self, printer):
        self.__printer = printer

    def __PrintEnumCommentLines(self, enum_type):
        """Print the // comment block that precedes an enum declaration."""
        description = enum_type.description or '%s enum type.' % enum_type.name
        # Leave room for the 3-character '// ' prefix.
        for line in textwrap.wrap(description,
                                  self.__printer.CalculateWidth() - 3):
            self.__printer('// %s', line)
        PrintIndentedDescriptions(self.__printer, enum_type.values, 'Values',
                                  prefix='// ')

    def __PrintEnumValueCommentLines(self, enum_value):
        """Print an enum value's description as // comments, if it has one."""
        if enum_value.description:
            width = self.__printer.CalculateWidth() - 3
            for line in textwrap.wrap(enum_value.description, width):
                self.__printer('// %s', line)

    def PrintEnum(self, enum_type):
        """Print a proto2 enum whose values are ordered by number."""
        self.__PrintEnumCommentLines(enum_type)
        self.__printer('enum %s {', enum_type.name)
        with self.__printer.Indent():
            enum_values = sorted(
                enum_type.values, key=operator.attrgetter('number'))
            for enum_value in enum_values:
                self.__printer()
                self.__PrintEnumValueCommentLines(enum_value)
                self.__printer('%s = %s;', enum_value.name, enum_value.number)
        self.__printer('}')
        self.__printer()

    def PrintPreamble(self, package, version, file_descriptor):
        """Print the header comment plus the syntax and package statements."""
        self.__printer('// Generated message classes for %s version %s.',
                       package, version)
        self.__printer('// NOTE: This file is autogenerated and should not be '
                       'edited by hand.')
        # description is an optional field and may be None, which
        # textwrap.wrap cannot handle.
        description_lines = textwrap.wrap(
            file_descriptor.description or '', 75)
        if description_lines:
            self.__printer('//')
            for line in description_lines:
                self.__printer('// %s', line)
        self.__printer()
        self.__printer('syntax = "proto2";')
        self.__printer('package %s;', file_descriptor.package)

    def __PrintMessageCommentLines(self, message_type):
        """Print the description of this message."""
        description = message_type.description or '%s message type.' % (
            message_type.name)
        width = self.__printer.CalculateWidth() - 3
        for line in textwrap.wrap(description, width):
            self.__printer('// %s', line)
        PrintIndentedDescriptions(self.__printer, message_type.enum_types,
                                  'Enums', prefix='// ')
        PrintIndentedDescriptions(self.__printer, message_type.message_types,
                                  'Messages', prefix='// ')
        PrintIndentedDescriptions(self.__printer, message_type.fields,
                                  'Fields', prefix='// ')

    def __PrintFieldDescription(self, description):
        """Print a field description as // comments; no-op when empty.

        As with enum values, an undocumented field gets no comment.
        textwrap.wrap would otherwise fail on a None description.
        """
        if not description:
            return
        for line in textwrap.wrap(description,
                                  self.__printer.CalculateWidth() - 3):
            self.__printer('// %s', line)

    def __PrintFields(self, fields):
        """Print one proto2 field declaration per extended field."""
        for extended_field in fields:
            field = extended_field.field_descriptor
            field_type = messages.Field.lookup_field_type_by_variant(
                field.variant)
            self.__printer()
            self.__PrintFieldDescription(extended_field.description)
            label = str(field.label).lower()
            if field_type in (messages.EnumField, messages.MessageField):
                proto_type = field.type_name
            else:
                proto_type = str(field.variant).lower()
            default_statement = ''
            if field.default_value:
                if field_type in [messages.BytesField, messages.StringField]:
                    # Escape backslashes and double quotes so the default is
                    # emitted as a valid proto2 string literal.
                    escaped = field.default_value.replace(
                        '\\', '\\\\').replace('"', '\\"')
                    default_value = '"%s"' % escaped
                elif field_type is messages.BooleanField:
                    default_value = str(field.default_value).lower()
                else:
                    default_value = str(field.default_value)

                default_statement = ' [default = %s]' % default_value
            self.__printer(
                '%s %s %s = %d%s;',
                label, proto_type, field.name, field.number, default_statement)

    def PrintMessage(self, message_type):
        """Print a proto2 message with its nested enums, messages and fields."""
        self.__printer()
        self.__PrintMessageCommentLines(message_type)
        if _EmptyMessage(message_type):
            self.__printer('message %s {}', message_type.name)
            return
        self.__printer('message %s {', message_type.name)
        with self.__printer.Indent():
            _PrintEnums(self, message_type.enum_types)
            _PrintMessages(self, message_type.message_types)
            self.__PrintFields(message_type.fields)
        self.__printer('}')

    def PrintCustomJsonMapping(self, mapping_lines):
        """Custom JSON mappings have no proto2 representation."""
        raise NotImplementedError(
            'Custom JSON encoding not supported for proto2')
348
349
class _ProtoRpcPrinter(ProtoPrinter):

    """Printer for ProtoRPC definitions.

    Emits Python source declaring _messages.Enum and _messages.Message
    classes.
    """

    def __init__(self, printer):
        self.__printer = printer

    def __PrintClassSeparator(self):
        """Print the blank line(s) that follow a class or alias."""
        self.__printer()
        # When the printer is not indented, add a second blank line.
        if not self.__printer.indent:
            self.__printer()

    def __PrintEnumDocstringLines(self, enum_type):
        """Print the docstring of a generated enum class."""
        description = enum_type.description or '%s enum type.' % enum_type.name
        for line in textwrap.wrap('"""%s' % description,
                                  self.__printer.CalculateWidth()):
            self.__printer(line)
        PrintIndentedDescriptions(self.__printer, enum_type.values, 'Values')
        self.__printer('"""')

    def PrintEnum(self, enum_type):
        """Print an Enum class whose values are ordered by number."""
        self.__printer('class %s(_messages.Enum):', enum_type.name)
        with self.__printer.Indent():
            self.__PrintEnumDocstringLines(enum_type)
            enum_values = sorted(
                enum_type.values, key=operator.attrgetter('number'))
            for enum_value in enum_values:
                self.__printer('%s = %s', enum_value.name, enum_value.number)
            if not enum_type.values:
                self.__printer('pass')
        self.__PrintClassSeparator()

    def __PrintAdditionalImports(self, imports):
        """Print additional imports needed for protorpc."""
        google_imports = [x for x in imports if 'google' in x]
        other_imports = [x for x in imports if 'google' not in x]
        if other_imports:
            for import_ in sorted(other_imports):
                self.__printer(import_)
            self.__printer()
        # Note: If we ever were going to add imports from this package, we'd
        # need to sort those out and put them at the end.
        if google_imports:
            for import_ in sorted(google_imports):
                self.__printer(import_)
            self.__printer()

    def PrintPreamble(self, package, version, file_descriptor):
        """Print the module docstring, imports and package constant."""
        self.__printer('"""Generated message classes for %s version %s.',
                       package, version)
        self.__printer()
        # description is an optional field and may be None, which
        # textwrap.wrap cannot handle.
        for line in textwrap.wrap(file_descriptor.description or '', 78):
            self.__printer(line)
        self.__printer('"""')
        self.__printer('# NOTE: This file is autogenerated and should not be '
                       'edited by hand.')
        self.__printer()
        self.__PrintAdditionalImports(file_descriptor.additional_imports)
        self.__printer()
        self.__printer("package = '%s'", file_descriptor.package)
        self.__printer()
        self.__printer()

    def __PrintMessageDocstringLines(self, message_type):
        """Print the docstring for this message."""
        description = message_type.description or '%s message type.' % (
            message_type.name)
        # A memberless message with a description that fits on one line
        # (6 columns reserved for the quotes) gets a one-line docstring.
        short_description = (
            _EmptyMessage(message_type) and
            len(description) < (self.__printer.CalculateWidth() - 6))
        with self.__printer.CommentContext():
            if short_description:
                # Note that we use explicit string interpolation here since
                # we're in comment context.
                self.__printer('"""%s"""' % description)
                return
            for line in textwrap.wrap('"""%s' % description,
                                      self.__printer.CalculateWidth()):
                self.__printer(line)

            PrintIndentedDescriptions(self.__printer, message_type.enum_types,
                                      'Enums')
            PrintIndentedDescriptions(
                self.__printer, message_type.message_types, 'Messages')
            PrintIndentedDescriptions(
                self.__printer, message_type.fields, 'Fields')
            self.__printer('"""')
            self.__printer()

    def PrintMessage(self, message_type):
        """Print a Message class, or a plain assignment for an alias."""
        if message_type.alias_for:
            self.__printer(
                '%s = %s', message_type.name, message_type.alias_for)
            self.__PrintClassSeparator()
            return
        for decorator in message_type.decorators:
            self.__printer('@%s', decorator)
        self.__printer('class %s(_messages.Message):', message_type.name)
        with self.__printer.Indent():
            self.__PrintMessageDocstringLines(message_type)
            _PrintEnums(self, message_type.enum_types)
            _PrintMessages(self, message_type.message_types)
            _PrintFields(message_type.fields, self.__printer)
        self.__PrintClassSeparator()

    def PrintCustomJsonMapping(self, mapping):
        """Print the pre-formatted mapping-registration source verbatim."""
        self.__printer(mapping)
457
458
459def _PrintEnums(proto_printer, enum_types):
460    """Print all enums to the given proto_printer."""
461    enum_types = sorted(enum_types, key=operator.attrgetter('name'))
462    for enum_type in enum_types:
463        proto_printer.PrintEnum(enum_type)
464
465
466def _PrintMessages(proto_printer, message_list):
467    message_list = sorted(message_list, key=operator.attrgetter('name'))
468    for message_type in message_list:
469        proto_printer.PrintMessage(message_type)
470
471
# Message type names that get a dedicated protorpc field class. A field whose
# type_name is a key here is declared with the mapped class, taken from the
# _message_types module, instead of a generic MessageField.
_MESSAGE_FIELD_MAP = {
    message_types.DateTimeMessage.definition_name():
        message_types.DateTimeField,
}
476
477
def _PrintFields(fields, printer):
    """Print one protorpc field declaration per extended field.

    Each declaration has the form
    ``name = module.FieldClass(<'type', >number<, options>)``.

    Args:
      fields: Iterable of ExtendedFieldDescriptor.
      printer: Callable that emits one line of output.
    """
    # The argument layout is the same for every field, so the format string
    # is built once rather than on every iteration.
    args = ''.join('%%(%s)s' % key for key in (
        'type_format',
        'number',
        'label_format',
        'variant_format',
        'default_format'))
    format_str = '%%(name)s = %%(module)s.%%(type_name)s(%s)' % args
    labels = protorpc_descriptor.FieldDescriptor.Label

    for extended_field in fields:
        field = extended_field.field_descriptor
        module, field_type = _ResolveFieldClass(field)
        printed_field_info = {
            'name': field.name,
            'module': module,
            'type_name': field_type.__name__,
            'type_format': '',
            'number': field.number,
            'label_format': '',
            'variant_format': '',
            'default_format': '',
        }

        # Enum and message fields take the referenced type name first.
        if field_type in (messages.EnumField, messages.MessageField):
            printed_field_info['type_format'] = "'%s', " % field.type_name

        if field.label == labels.REQUIRED:
            printed_field_info['label_format'] = ', required=True'
        elif field.label == labels.REPEATED:
            printed_field_info['label_format'] = ', repeated=True'

        # Only spell out the variant when it differs from the class default.
        if field_type.DEFAULT_VARIANT != field.variant:
            printed_field_info['variant_format'] = (
                ', variant=_messages.Variant.%s' % field.variant)

        if field.default_value:
            printed_field_info['default_format'] = ', default=%s' % (
                _FormatFieldDefault(field, field_type),)

        printer(format_str % printed_field_info)


def _ResolveFieldClass(field):
    """Return (module alias, field class) used to declare ``field``."""
    message_field = _MESSAGE_FIELD_MAP.get(field.type_name)
    if message_field:
        return '_message_types', message_field
    if field.type_name == 'extra_types.DateField':
        return 'extra_types', apitools_base.DateField
    return '_messages', messages.Field.lookup_field_type_by_variant(
        field.variant)


def _FormatFieldDefault(field, field_type):
    """Return the Python source text for ``field``'s default value."""
    if field_type in [messages.BytesField, messages.StringField]:
        return repr(field.default_value)
    if field_type is messages.EnumField:
        # Numeric defaults are emitted as ints, names as quoted strings.
        try:
            return str(int(field.default_value))
        except ValueError:
            return repr(field.default_value)
    return field.default_value
538