1#!/usr/bin/env python
2#
3# Copyright 2015 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Extended protorpc descriptors.
18
19This takes existing protorpc Descriptor classes and adds extra
20properties not directly supported in proto itself, notably field and
21message descriptions. We need this in order to generate protorpc
22message files with comments.
23
24Note that for most of these classes, we can't simply wrap the existing
25message, since we need to change the type of the subfields. We could
26have a "plain" descriptor attached, but that seems like unnecessary
27bookkeeping. Where possible, we purposely reuse existing tag numbers;
28for new fields, we start numbering at 100.
29"""
30import abc
31import operator
32import textwrap
33
34import six
35
36from apitools.base.protorpclite import descriptor as protorpc_descriptor
37from apitools.base.protorpclite import message_types
38from apitools.base.protorpclite import messages
39import apitools.base.py as apitools_base
40
41
class ExtendedEnumValueDescriptor(messages.Message):

    """Enum value descriptor with additional fields.

    Per the module docstring, existing protorpc tag numbers are reused where
    possible and fields added by this module start at tag 100.

    Fields:
      name: Name of enumeration value.
      number: Number of enumeration value.
      description: Description of this enum value (rendered as a comment or
          docstring entry by the printers below).
    """
    name = messages.StringField(1)
    number = messages.IntegerField(2, variant=messages.Variant.INT32)

    # Extension field: tags >= 100 are reserved for additions made here.
    description = messages.StringField(100)
55
56
class ExtendedEnumDescriptor(messages.Message):

    """Enum class descriptor with additional fields.

    Fields:
      name: Name of Enum without any qualification.
      values: Values defined by Enum class.
      description: Description of this enum class.
      full_name: Fully qualified name of this enum class.
      enum_mappings: Mappings from python to JSON names for enum values.
          Each one is turned into an encoding.AddCustomJsonEnumMapping(...)
          call by _FetchCustomMappings.
    """

    class JsonEnumMapping(messages.Message):

        """Mapping from a python name to the wire name for an enum."""
        python_name = messages.StringField(1)
        json_name = messages.StringField(2)

    name = messages.StringField(1)
    values = messages.MessageField(
        ExtendedEnumValueDescriptor, 2, repeated=True)

    # Extension fields: tags >= 100 are reserved for additions made here.
    description = messages.StringField(100)
    full_name = messages.StringField(101)
    # Referenced by name; presumably protorpc resolves 'JsonEnumMapping'
    # relative to this class, i.e. to the nested class defined above.
    enum_mappings = messages.MessageField(
        'JsonEnumMapping', 102, repeated=True)
83
84
class ExtendedFieldDescriptor(messages.Message):

    """Field descriptor with additional fields.

    Unlike the other extended descriptors in this module, this one wraps the
    plain protorpc FieldDescriptor instead of copying its fields, so all of
    its own tags start at 100.

    Fields:
      field_descriptor: The underlying field descriptor (number, label,
          variant, type_name, default_value, ... are read from here).
      name: The name of this field.
      description: Description of this field.
    """
    field_descriptor = messages.MessageField(
        protorpc_descriptor.FieldDescriptor, 100)
    # We duplicate the names for easier bookkeeping.
    name = messages.StringField(101)
    description = messages.StringField(102)
99
100
class ExtendedMessageDescriptor(messages.Message):

    """Message descriptor with additional fields.

    Fields:
      name: Name of Message without any qualification.
      fields: Fields defined for message.
      message_types: Nested Message classes defined on message.
      enum_types: Nested Enum classes defined on message.
      description: Description of this message.
      full_name: Full qualified name of this message.
      decorators: Decorators to include in the definition when printing.
          Printed in the given order from top to bottom (so the last entry
          is the innermost decorator). Only the ProtoRPC printer emits them.
      alias_for: This type is just an alias for the named type. When set, the
          ProtoRPC printer emits "<name> = <alias_for>" instead of a class.
      field_mappings: Mappings from python to json field names. Each one is
          turned into an encoding.AddCustomJsonFieldMapping(...) call by
          _FetchCustomMappings.
    """

    class JsonFieldMapping(messages.Message):

        """Mapping from a python name to the wire name for a field."""
        python_name = messages.StringField(1)
        json_name = messages.StringField(2)

    name = messages.StringField(1)
    fields = messages.MessageField(ExtendedFieldDescriptor, 2, repeated=True)
    # Self-reference by qualified name: the class object does not exist yet
    # while its own body is executing.
    message_types = messages.MessageField(
        'extended_descriptor.ExtendedMessageDescriptor', 3, repeated=True)
    enum_types = messages.MessageField(
        ExtendedEnumDescriptor, 4, repeated=True)

    # Extension fields: tags >= 100 are reserved for additions made here.
    description = messages.StringField(100)
    full_name = messages.StringField(101)
    decorators = messages.StringField(102, repeated=True)
    alias_for = messages.StringField(103)
    # Referenced by name; presumably resolved relative to this class, i.e.
    # to the nested JsonFieldMapping defined above.
    field_mappings = messages.MessageField(
        'JsonFieldMapping', 104, repeated=True)
138
139
class ExtendedFileDescriptor(messages.Message):

    """File descriptor with additional fields.

    Fields:
      package: Fully qualified name of package that definitions belong to.
      message_types: Message definitions contained in file.
      enum_types: Enum definitions contained in file.
      description: Description of this file.
      additional_imports: Extra imports used in this package. The ProtoRPC
          printer writes these lines verbatim, sorted, with any line
          containing 'google' grouped after the others.
    """
    # Low tag numbers presumably mirror the plain protorpc FileDescriptor
    # (per the module docstring's tag-reuse policy).
    package = messages.StringField(2)

    message_types = messages.MessageField(
        ExtendedMessageDescriptor, 4, repeated=True)
    enum_types = messages.MessageField(
        ExtendedEnumDescriptor, 5, repeated=True)

    # Extension fields: tags >= 100 are reserved for additions made here.
    description = messages.StringField(100)
    additional_imports = messages.StringField(101, repeated=True)
160
161
def _WriteFile(file_descriptor, package, version, proto_printer):
    """Emit an entire generated file through proto_printer.

    Output order: the preamble, every top-level enum, every top-level
    message, then one custom JSON mapping entry per mapping found (enum
    mappings first, then those collected from messages).

    Args:
      file_descriptor: ExtendedFileDescriptor describing the file.
      package: API package name, forwarded to PrintPreamble.
      version: API version string, forwarded to PrintPreamble.
      proto_printer: ProtoPrinter implementation doing the formatting.
    """
    proto_printer.PrintPreamble(package, version, file_descriptor)
    _PrintEnums(proto_printer, file_descriptor.enum_types)
    _PrintMessages(proto_printer, file_descriptor.message_types)

    file_package = file_descriptor.package
    all_mappings = (
        _FetchCustomMappings(file_descriptor.enum_types, file_package) +
        _FetchCustomMappings(file_descriptor.message_types, file_package))
    for formatted_mapping in all_mappings:
        proto_printer.PrintCustomJsonMapping(formatted_mapping)
174
175
def WriteMessagesFile(file_descriptor, package, version, printer):
    """Write file_descriptor to printer as proto2 message definitions.

    Args:
      file_descriptor: ExtendedFileDescriptor to render.
      package: API package name used in the header comment.
      version: API version string used in the header comment.
      printer: Underlying pretty printer receiving the output lines.

    Raises:
      NotImplementedError: if the file contains custom JSON mappings, which
          the proto2 printer does not support.
    """
    proto2_printer = _Proto2Printer(printer)
    _WriteFile(file_descriptor, package, version, proto2_printer)
180
181
def WritePythonFile(file_descriptor, package, version, printer):
    """Write file_descriptor to printer as a ProtoRPC Python module.

    Args:
      file_descriptor: ExtendedFileDescriptor to render.
      package: API package name used in the module docstring.
      version: API version string used in the module docstring.
      printer: Underlying pretty printer receiving the output lines.
    """
    protorpc_printer = _ProtoRpcPrinter(printer)
    _WriteFile(file_descriptor, package, version, protorpc_printer)
186
187
def PrintIndentedDescriptions(printer, ls, name, prefix=''):
    """Print a titled, word-wrapped list of "name: description" entries.

    Prints nothing when ls is empty. Otherwise prints the following inside
    an indent of `prefix` and the printer's comment context:
      - a blank line;
      - the heading "<name>:";
      - each item as "  <item.name>: <item.description>", wrapped with
        continuation lines indented by four spaces.

    Args:
      printer: Pretty printer providing Indent, CommentContext,
          CalculateWidth and being callable with a line.
      ls: Sequence of items exposing .name and .description.
      name: Heading printed above the entries.
      prefix: Indent string applied to every line (e.g. '// ').
    """
    if not ls:
        return
    with printer.Indent(indent=prefix), printer.CommentContext():
        wrap_width = printer.CalculateWidth() - len(prefix)
        printer()
        printer(name + ':')
        for item in ls:
            entry = '%s: %s' % (item.name, item.description)
            wrapped_lines = textwrap.wrap(
                entry, wrap_width,
                initial_indent='  ', subsequent_indent='    ')
            for wrapped in wrapped_lines:
                printer(wrapped)
201
202
def _FetchCustomMappings(descriptor_ls, package):
    """Find and return all custom mappings for descriptors in descriptor_ls.

    Enum descriptors contribute one entry per enum mapping. Message
    descriptors contribute one entry per field mapping, followed by the
    entries of their nested enums and then of their nested messages
    (recursively). Descriptors of any other type are ignored.

    Args:
      descriptor_ls: Iterable of ExtendedEnumDescriptor and/or
          ExtendedMessageDescriptor.
      package: Package name forwarded to _FormatCustomJsonMapping.

    Returns:
      A list of formatted mapping strings, in the order described above.
    """
    collected = []
    for desc in descriptor_ls:
        if isinstance(desc, ExtendedEnumDescriptor):
            collected += [
                _FormatCustomJsonMapping('Enum', enum_mapping, desc, package)
                for enum_mapping in desc.enum_mappings]
        elif isinstance(desc, ExtendedMessageDescriptor):
            collected += [
                _FormatCustomJsonMapping('Field', field_mapping, desc,
                                         package)
                for field_mapping in desc.field_mappings]
            collected += _FetchCustomMappings(desc.enum_types, package)
            collected += _FetchCustomMappings(desc.message_types, package)
    return collected
220
221
def _FormatCustomJsonMapping(mapping_type, mapping, descriptor, package):
    """Render one encoding.AddCustomJson<Type>Mapping(...) call as source.

    Args:
      mapping_type: Suffix selecting the encoding helper ('Enum' or 'Field'
          at the call sites in this module).
      mapping: Object with python_name and json_name attributes.
      descriptor: Descriptor whose full_name is emitted unquoted, as a
          Python expression naming the generated class.
      package: Package name, emitted via repr() as the package= keyword.

    Returns:
      A three-line string with no trailing newline.
    """
    call_line = 'encoding.AddCustomJson%sMapping(' % mapping_type
    args_line = "    %s, '%s', '%s'," % (
        descriptor.full_name, mapping.python_name, mapping.json_name)
    package_line = '    package=%r)' % package
    return '\n'.join([call_line, args_line, package_line])
229
230
def _EmptyMessage(message_type):
    """Return True iff message_type declares no enums, messages or fields."""
    has_members = (message_type.enum_types or
                   message_type.message_types or
                   message_type.fields)
    return not has_members
235
236
class ProtoPrinter(six.with_metaclass(abc.ABCMeta, object)):

    """Interface for proto printers.

    _WriteFile drives an implementation in this order:
      - PrintPreamble, called once;
      - PrintEnum and PrintMessage for the top-level definitions (sorted by
        name);
      - PrintCustomJsonMapping, once for each formatted custom JSON mapping.
    """

    @abc.abstractmethod
    def PrintPreamble(self, package, version, file_descriptor):
        """Print the file docstring and import lines."""

    @abc.abstractmethod
    def PrintEnum(self, enum_type):
        """Print the given enum declaration."""

    @abc.abstractmethod
    def PrintMessage(self, message_type):
        """Print the given message declaration."""

    # Deliberately not abstract: subclasses written before this method was
    # declared must remain instantiable.
    def PrintCustomJsonMapping(self, mapping):
        """Print one pre-formatted custom JSON mapping registration.

        Args:
          mapping: String as produced by _FormatCustomJsonMapping.

        Raises:
          NotImplementedError: if this printer does not support custom JSON
              mappings (the default).
        """
        raise NotImplementedError(
            '%s does not support custom JSON mappings' % type(self).__name__)
252
253
class _Proto2Printer(ProtoPrinter):

    """Printer for proto2 definitions.

    Produces .proto source (syntax = "proto2"). How descriptions are handled:
      - they are rendered as '// ' line comments;
      - free text is always passed to the printer as a '%s' argument, never
        as the format string itself;
      - a missing (None) description is skipped rather than handed to
        textwrap.
    """

    def __init__(self, printer):
        self.__printer = printer

    def __PrintWrappedComment(self, text):
        """Print text word-wrapped as '// ' comment lines; no-op if empty."""
        if not text:
            return
        # Leave room for the 3-character '// ' prefix.
        width = self.__printer.CalculateWidth() - 3
        for line in textwrap.wrap(text, width):
            self.__printer('// %s', line)

    def __PrintEnumCommentLines(self, enum_type):
        """Print the comment block that precedes an enum declaration."""
        description = enum_type.description or '%s enum type.' % enum_type.name
        self.__PrintWrappedComment(description)
        PrintIndentedDescriptions(self.__printer, enum_type.values, 'Values',
                                  prefix='// ')

    def __PrintEnumValueCommentLines(self, enum_value):
        """Print an enum value's description, if it has one."""
        self.__PrintWrappedComment(enum_value.description)

    def PrintEnum(self, enum_type):
        """Print enum_type as a proto2 enum, values ordered by number."""
        self.__PrintEnumCommentLines(enum_type)
        self.__printer('enum %s {', enum_type.name)
        with self.__printer.Indent():
            enum_values = sorted(
                enum_type.values, key=operator.attrgetter('number'))
            for enum_value in enum_values:
                self.__printer()
                self.__PrintEnumValueCommentLines(enum_value)
                self.__printer('%s = %s;', enum_value.name, enum_value.number)
        self.__printer('}')
        self.__printer()

    def PrintPreamble(self, package, version, file_descriptor):
        """Print the header comment plus the syntax and package lines."""
        self.__printer('// Generated message classes for %s version %s.',
                       package, version)
        self.__printer('// NOTE: This file is autogenerated and should not be '
                       'edited by hand.')
        # A missing file description produces no description lines.
        description_lines = textwrap.wrap(
            file_descriptor.description or '', 75)
        if description_lines:
            self.__printer('//')
            for line in description_lines:
                self.__printer('// %s', line)
        self.__printer()
        self.__printer('syntax = "proto2";')
        self.__printer('package %s;', file_descriptor.package)

    def __PrintMessageCommentLines(self, message_type):
        """Print the description of this message."""
        description = message_type.description or '%s message type.' % (
            message_type.name)
        self.__PrintWrappedComment(description)
        PrintIndentedDescriptions(self.__printer, message_type.enum_types,
                                  'Enums', prefix='// ')
        PrintIndentedDescriptions(self.__printer, message_type.message_types,
                                  'Messages', prefix='// ')
        PrintIndentedDescriptions(self.__printer, message_type.fields,
                                  'Fields', prefix='// ')

    def __PrintFieldDescription(self, description):
        """Print a field's description as comments; skipped if None/empty."""
        self.__PrintWrappedComment(description)

    def __PrintFields(self, fields):
        """Print one proto2 field line per ExtendedFieldDescriptor."""
        for extended_field in fields:
            field = extended_field.field_descriptor
            field_type = messages.Field.lookup_field_type_by_variant(
                field.variant)
            self.__printer()
            self.__PrintFieldDescription(extended_field.description)
            label = str(field.label).lower()
            # Enum and message fields are declared by their type name; all
            # other fields by their (lower-cased) variant.
            if field_type in (messages.EnumField, messages.MessageField):
                proto_type = field.type_name
            else:
                proto_type = str(field.variant).lower()
            default_statement = ''
            if field.default_value:
                if field_type in [messages.BytesField, messages.StringField]:
                    default_value = '"%s"' % field.default_value
                elif field_type is messages.BooleanField:
                    default_value = str(field.default_value).lower()
                else:
                    default_value = str(field.default_value)

                default_statement = ' [default = %s]' % default_value
            self.__printer(
                '%s %s %s = %d%s;',
                label, proto_type, field.name, field.number, default_statement)

    def PrintMessage(self, message_type):
        """Print message_type, its nested enums/messages, and its fields."""
        self.__printer()
        self.__PrintMessageCommentLines(message_type)
        if _EmptyMessage(message_type):
            self.__printer('message %s {}', message_type.name)
            return
        self.__printer('message %s {', message_type.name)
        with self.__printer.Indent():
            _PrintEnums(self, message_type.enum_types)
            _PrintMessages(self, message_type.message_types)
            self.__PrintFields(message_type.fields)
        self.__printer('}')

    def PrintCustomJsonMapping(self, mapping_lines):
        """Always raises: proto2 output cannot express custom JSON names."""
        raise NotImplementedError(
            'Custom JSON encoding not supported for proto2')
363
364
class _ProtoRpcPrinter(ProtoPrinter):

    """Printer for ProtoRPC definitions.

    Produces a Python module that declares _messages.Enum and
    _messages.Message subclasses. Free text taken from the API
    (descriptions) is always printed inside the printer's comment context.
    No %-interpolation happens there, so a literal '%' in a description
    cannot break generation.
    """

    def __init__(self, printer):
        self.__printer = printer

    def __PrintClassSeparator(self):
        """Print blank lines after a class definition.

        Prints two blank lines when the printer has no indent (top level),
        otherwise one.
        """
        self.__printer()
        if not self.__printer.indent:
            self.__printer()

    def __PrintEnumDocstringLines(self, enum_type):
        """Print the docstring of an enum class, listing its values."""
        description = enum_type.description or '%s enum type.' % enum_type.name
        # Comment context: description text is printed verbatim, so a '%'
        # in it is not taken as a format directive.
        with self.__printer.CommentContext():
            for line in textwrap.wrap('"""%s' % description,
                                      self.__printer.CalculateWidth()):
                self.__printer(line)
            PrintIndentedDescriptions(self.__printer, enum_type.values,
                                      'Values')
            self.__printer('"""')

    def PrintEnum(self, enum_type):
        """Print enum_type as a _messages.Enum subclass."""
        self.__printer('class %s(_messages.Enum):', enum_type.name)
        with self.__printer.Indent():
            self.__PrintEnumDocstringLines(enum_type)
            enum_values = sorted(
                enum_type.values, key=operator.attrgetter('number'))
            for enum_value in enum_values:
                self.__printer('%s = %s', enum_value.name, enum_value.number)
            if not enum_type.values:
                self.__printer('pass')
        self.__PrintClassSeparator()

    def __PrintAdditionalImports(self, imports):
        """Print additional imports needed for protorpc."""
        google_imports = [x for x in imports if 'google' in x]
        other_imports = [x for x in imports if 'google' not in x]
        if other_imports:
            for import_ in sorted(other_imports):
                self.__printer(import_)
            self.__printer()
        # Note: If we ever were going to add imports from this package, we'd
        # need to sort those out and put them at the end.
        if google_imports:
            for import_ in sorted(google_imports):
                self.__printer(import_)
            self.__printer()

    def PrintPreamble(self, package, version, file_descriptor):
        """Print the module docstring, imports and package constant."""
        self.__printer('"""Generated message classes for %s version %s.',
                       package, version)
        self.__printer()
        # A missing description yields no lines; the text is printed in
        # comment context so any '%' characters are emitted verbatim.
        with self.__printer.CommentContext():
            for line in textwrap.wrap(file_descriptor.description or '', 78):
                self.__printer(line)
        self.__printer('"""')
        self.__printer('# NOTE: This file is autogenerated and should not be '
                       'edited by hand.')
        self.__printer()
        self.__PrintAdditionalImports(file_descriptor.additional_imports)
        self.__printer()
        self.__printer("package = '%s'", file_descriptor.package)
        self.__printer()
        self.__printer()

    def __PrintMessageDocstringLines(self, message_type):
        """Print the docstring for this message."""
        description = message_type.description or '%s message type.' % (
            message_type.name)
        # A member-less message whose description fits on one line (leaving
        # room for the two triple quotes) gets a one-line docstring.
        short_description = (
            _EmptyMessage(message_type) and
            len(description) < (self.__printer.CalculateWidth() - 6))
        with self.__printer.CommentContext():
            if short_description:
                # Note that we use explicit string interpolation here since
                # we're in comment context.
                self.__printer('"""%s"""' % description)
                return
            for line in textwrap.wrap('"""%s' % description,
                                      self.__printer.CalculateWidth()):
                self.__printer(line)

            PrintIndentedDescriptions(self.__printer, message_type.enum_types,
                                      'Enums')
            PrintIndentedDescriptions(
                self.__printer, message_type.message_types, 'Messages')
            PrintIndentedDescriptions(
                self.__printer, message_type.fields, 'Fields')
            self.__printer('"""')
            self.__printer()

    def PrintMessage(self, message_type):
        """Print message_type as a _messages.Message subclass or an alias."""
        if message_type.alias_for:
            self.__printer(
                '%s = %s', message_type.name, message_type.alias_for)
            self.__PrintClassSeparator()
            return
        for decorator in message_type.decorators:
            self.__printer('@%s', decorator)
        self.__printer('class %s(_messages.Message):', message_type.name)
        with self.__printer.Indent():
            self.__PrintMessageDocstringLines(message_type)
            _PrintEnums(self, message_type.enum_types)
            _PrintMessages(self, message_type.message_types)
            _PrintFields(message_type.fields, self.__printer)
        self.__PrintClassSeparator()

    def PrintCustomJsonMapping(self, mapping):
        """Print a pre-formatted custom JSON mapping registration verbatim."""
        self.__printer(mapping)
472
473
def _PrintEnums(proto_printer, enum_types):
    """Hand each enum to proto_printer.PrintEnum, in order of name."""
    for enum_type in sorted(enum_types, key=lambda enum: enum.name):
        proto_printer.PrintEnum(enum_type)
479
480
def _PrintMessages(proto_printer, message_list):
    """Hand each message to proto_printer.PrintMessage, in order of name."""
    for message_type in sorted(message_list, key=lambda msg: msg.name):
        proto_printer.PrintMessage(message_type)
485
486
# Message types that have a dedicated protorpc field class. _PrintFields
# looks a field's type_name up here and, on a hit, declares the field with
# that class (referenced via _message_types) instead of a plain MessageField.
_MESSAGE_FIELD_MAP = dict([
    (message_types.DateTimeMessage.definition_name(),
     message_types.DateTimeField),
])
491
492
def _PrintFields(fields, printer):
    """Print a ProtoRPC field declaration for each ExtendedFieldDescriptor.

    Each field becomes one line of the form

      name = module.FieldClass(['TypeName', ]number[, required=True |
          , repeated=True][, variant=...][, default=...])

    Optional pieces are only emitted when they are needed.

    Args:
      fields: Iterable of ExtendedFieldDescriptor.
      printer: Printer receiving each finished declaration as a single,
          already-formatted string argument.
    """
    # Loop-invariant template; each optional piece is '' when unused.
    line_template = (
        '%(name)s = %(module)s.%(type_name)s('
        '%(type_format)s%(number)s%(label_format)s'
        '%(variant_format)s%(default_format)s)')
    label_enum = protorpc_descriptor.FieldDescriptor.Label

    for extended_field in fields:
        descriptor = extended_field.field_descriptor

        # Choose the field class and the module it is referenced through.
        special_field_class = _MESSAGE_FIELD_MAP.get(descriptor.type_name)
        if special_field_class:
            module_name = '_message_types'
            field_class = special_field_class
        elif descriptor.type_name == 'extra_types.DateField':
            module_name = 'extra_types'
            field_class = apitools_base.DateField
        else:
            module_name = '_messages'
            field_class = messages.Field.lookup_field_type_by_variant(
                descriptor.variant)

        # Enum and message fields take their type name as first argument.
        type_format = ''
        if field_class in (messages.EnumField, messages.MessageField):
            type_format = "'%s', " % descriptor.type_name

        label_format = ''
        if descriptor.label == label_enum.REQUIRED:
            label_format = ', required=True'
        elif descriptor.label == label_enum.REPEATED:
            label_format = ', repeated=True'

        # Only spell out the variant when it differs from the class default.
        variant_format = ''
        if field_class.DEFAULT_VARIANT != descriptor.variant:
            variant_format = (
                ', variant=_messages.Variant.%s' % descriptor.variant)

        default_format = ''
        raw_default = descriptor.default_value
        if raw_default:
            if field_class in [messages.BytesField, messages.StringField]:
                rendered_default = repr(raw_default)
            elif field_class is messages.EnumField:
                # Numeric enum defaults are emitted as ints, names as
                # quoted strings.
                try:
                    rendered_default = str(int(raw_default))
                except ValueError:
                    rendered_default = repr(raw_default)
            else:
                rendered_default = raw_default
            default_format = ', default=%s' % (rendered_default,)

        printer(line_template % {
            'name': descriptor.name,
            'module': module_name,
            'type_name': field_class.__name__,
            'type_format': type_format,
            'number': descriptor.number,
            'label_format': label_format,
            'variant_format': variant_format,
            'default_format': default_format,
        })
553