1#!/usr/bin/env python
2#
3# Copyright 2015 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Extended protorpc descriptors.
18
19This takes existing protorpc Descriptor classes and adds extra
20properties not directly supported in proto itself, notably field and
21message descriptions. We need this in order to generate protorpc
22message files with comments.
23
24Note that for most of these classes, we can't simply wrap the existing
25message, since we need to change the type of the subfields. We could
26have a "plain" descriptor attached, but that seems like unnecessary
27bookkeeping. Where possible, we purposely reuse existing tag numbers;
28for new fields, we start numbering at 100.
29"""
30import abc
31import operator
32import textwrap
33
34import six
35
36from apitools.base.protorpclite import descriptor as protorpc_descriptor
37from apitools.base.protorpclite import message_types
38from apitools.base.protorpclite import messages
39from apitools.base.py import extra_types
40
41
class ExtendedEnumValueDescriptor(messages.Message):

    """Enum value descriptor with additional fields.

    Fields:
      name: Name of enumeration value.
      number: Number of enumeration value.
      description: Description of this enum value.
    """
    # Tags below 100 are intended to reuse the corresponding protorpc
    # descriptor tag numbers; fields added by this module start at 100
    # (see the module docstring).
    name = messages.StringField(1)
    number = messages.IntegerField(2, variant=messages.Variant.INT32)

    # Extension field: free-text description used for generated comments.
    description = messages.StringField(100)
55
56
class ExtendedEnumDescriptor(messages.Message):

    """Enum class descriptor with additional fields.

    Fields:
      name: Name of Enum without any qualification.
      values: Values defined by Enum class.
      description: Description of this enum class.
      full_name: Fully qualified name of this enum class.
      enum_mappings: Mappings from python to JSON names for enum values.
    """

    class JsonEnumMapping(messages.Message):

        """Mapping from a python name to the wire name for an enum."""
        python_name = messages.StringField(1)
        json_name = messages.StringField(2)

    # Tags below 100 are intended to reuse the protorpc descriptor numbering.
    name = messages.StringField(1)
    values = messages.MessageField(
        ExtendedEnumValueDescriptor, 2, repeated=True)

    # Extension fields (numbered from 100).
    description = messages.StringField(100)
    # Used as the first argument of the generated
    # encoding.AddCustomJsonEnumMapping(...) call.
    full_name = messages.StringField(101)
    # The nested JsonEnumMapping type is referenced by name as a string.
    enum_mappings = messages.MessageField(
        'JsonEnumMapping', 102, repeated=True)
83
84
class ExtendedFieldDescriptor(messages.Message):

    """Field descriptor with additional fields.

    Rather than re-declaring every FieldDescriptor field, this message
    embeds the plain protorpc FieldDescriptor whole, so every field declared
    here is an extension numbered from 100.

    Fields:
      field_descriptor: The underlying field descriptor.
      name: The name of this field.
      description: Description of this field.
    """
    field_descriptor = messages.MessageField(
        protorpc_descriptor.FieldDescriptor, 100)
    # We duplicate the names for easier bookkeeping.
    name = messages.StringField(101)
    # Optional free-text description; may be unset.
    description = messages.StringField(102)
99
100
class ExtendedMessageDescriptor(messages.Message):

    """Message descriptor with additional fields.

    Fields:
      name: Name of Message without any qualification.
      fields: Fields defined for message.
      message_types: Nested Message classes defined on message.
      enum_types: Nested Enum classes defined on message.
      description: Description of this message.
      full_name: Fully qualified name of this message.
      decorators: Decorators to include in the definition when printing.
          Printed in the given order from top to bottom (so the last entry
          is the innermost decorator).
      alias_for: This type is just an alias for the named type.
      field_mappings: Mappings from python to json field names.
    """

    class JsonFieldMapping(messages.Message):

        """Mapping from a python name to the wire name for a field."""
        python_name = messages.StringField(1)
        json_name = messages.StringField(2)

    # Tags below 100 are intended to reuse the protorpc descriptor numbering.
    name = messages.StringField(1)
    fields = messages.MessageField(ExtendedFieldDescriptor, 2, repeated=True)
    # Self-referential field: the class object does not exist yet while its
    # body executes, so it must be named by string.
    message_types = messages.MessageField(
        'extended_descriptor.ExtendedMessageDescriptor', 3, repeated=True)
    enum_types = messages.MessageField(
        ExtendedEnumDescriptor, 4, repeated=True)

    # Extension fields (numbered from 100).
    description = messages.StringField(100)
    full_name = messages.StringField(101)
    decorators = messages.StringField(102, repeated=True)
    # When set, the ProtoRPC printer emits "name = alias_for" instead of a
    # class definition.
    alias_for = messages.StringField(103)
    # The nested JsonFieldMapping type is referenced by name as a string.
    field_mappings = messages.MessageField(
        'JsonFieldMapping', 104, repeated=True)
138
139
class ExtendedFileDescriptor(messages.Message):

    """File descriptor with additional fields.

    Fields:
      package: Fully qualified name of package that definitions belong to.
      message_types: Message definitions contained in file.
      enum_types: Enum definitions contained in file.
      description: Description of this file.
      additional_imports: Extra imports used in this package.
    """
    # Tags below 100 are intended to reuse the protorpc FileDescriptor
    # numbering, which is presumably why 1 and 3 are skipped.
    package = messages.StringField(2)

    message_types = messages.MessageField(
        ExtendedMessageDescriptor, 4, repeated=True)
    enum_types = messages.MessageField(
        ExtendedEnumDescriptor, 5, repeated=True)

    # Extension fields (numbered from 100). description is optional and
    # may be unset.
    description = messages.StringField(100)
    # Full import statements, printed verbatim by the ProtoRPC printer.
    additional_imports = messages.StringField(101, repeated=True)
160
161
def _WriteFile(file_descriptor, package, version, proto_printer):
    """Emit a complete generated file for file_descriptor.

    Output order: the preamble, top-level enums, top-level messages, and
    finally the custom JSON name mappings (those gathered from enums first,
    then those gathered from messages).

    Args:
      file_descriptor: The ExtendedFileDescriptor to render.
      package: Package name forwarded to the preamble.
      version: Version string forwarded to the preamble.
      proto_printer: ProtoPrinter implementation that renders each piece.
    """
    proto_printer.PrintPreamble(package, version, file_descriptor)
    _PrintEnums(proto_printer, file_descriptor.enum_types)
    _PrintMessages(proto_printer, file_descriptor.message_types)
    enum_mappings = _FetchCustomMappings(file_descriptor.enum_types)
    message_mappings = _FetchCustomMappings(file_descriptor.message_types)
    for mapping_text in enum_mappings + message_mappings:
        proto_printer.PrintCustomJsonMapping(mapping_text)
172
173
def WriteMessagesFile(file_descriptor, package, version, printer):
    """Render file_descriptor as a proto2 message definition file.

    Args:
      file_descriptor: The ExtendedFileDescriptor to render.
      package: Package name used in the header comment.
      version: Version string used in the header comment.
      printer: Low-level line printer; wrapped in a _Proto2Printer.
    """
    proto2_printer = _Proto2Printer(printer)
    _WriteFile(file_descriptor, package, version, proto2_printer)
178
179
def WritePythonFile(file_descriptor, package, version, printer):
    """Render file_descriptor as a Python (ProtoRPC) messages module.

    Args:
      file_descriptor: The ExtendedFileDescriptor to render.
      package: Package name used in the module docstring.
      version: Version string used in the module docstring.
      printer: Low-level line printer; wrapped in a _ProtoRpcPrinter.
    """
    protorpc_printer = _ProtoRpcPrinter(printer)
    _WriteFile(file_descriptor, package, version, protorpc_printer)
184
185
def PrintIndentedDescriptions(printer, ls, name, prefix=''):
    """Print a titled, wrapped list of "name: description" entries.

    Nothing is printed when ls is empty. Otherwise, inside printer's indent
    (using prefix) and comment context, this prints a blank line, the
    heading "<name>:", and then one wrapped entry per item. The first line
    of an entry is indented two spaces and continuation lines four.

    Args:
      printer: Line printer providing Indent, CommentContext,
          CalculateWidth and being callable with a line.
      ls: Sequence of items exposing .name and .description.
      name: Heading printed above the entries.
      prefix: Indent string applied to each line (e.g. '// '); its length
          is subtracted from the available wrap width.
    """
    if not ls:
        return
    with printer.Indent(indent=prefix), printer.CommentContext():
        wrap_width = printer.CalculateWidth() - len(prefix)
        printer()
        printer('%s:' % name)
        for item in ls:
            entry = '%s: %s' % (item.name, item.description)
            wrapped = textwrap.wrap(entry, wrap_width,
                                    initial_indent='  ',
                                    subsequent_indent='    ')
            for text_line in wrapped:
                printer(text_line)
199
200
201def _FetchCustomMappings(descriptor_ls):
202    """Find and return all custom mappings for descriptors in descriptor_ls."""
203    custom_mappings = []
204    for descriptor in descriptor_ls:
205        if isinstance(descriptor, ExtendedEnumDescriptor):
206            custom_mappings.extend(
207                _FormatCustomJsonMapping('Enum', m, descriptor)
208                for m in descriptor.enum_mappings)
209        elif isinstance(descriptor, ExtendedMessageDescriptor):
210            custom_mappings.extend(
211                _FormatCustomJsonMapping('Field', m, descriptor)
212                for m in descriptor.field_mappings)
213            custom_mappings.extend(
214                _FetchCustomMappings(descriptor.enum_types))
215            custom_mappings.extend(
216                _FetchCustomMappings(descriptor.message_types))
217    return custom_mappings
218
219
220def _FormatCustomJsonMapping(mapping_type, mapping, descriptor):
221    return '\n'.join((
222        'encoding.AddCustomJson%sMapping(' % mapping_type,
223        "    %s, '%s', '%s')" % (descriptor.full_name, mapping.python_name,
224                                 mapping.json_name),
225    ))
226
227
228def _EmptyMessage(message_type):
229    return not any((message_type.enum_types,
230                    message_type.message_types,
231                    message_type.fields))
232
233
class ProtoPrinter(six.with_metaclass(abc.ABCMeta, object)):

    """Interface for proto printers.

    _WriteFile drives an implementation in this order: PrintPreamble once,
    PrintEnum / PrintMessage for each top-level definition, then
    PrintCustomJsonMapping for every custom JSON name mapping collected
    from the file's enums and messages.
    """

    @abc.abstractmethod
    def PrintPreamble(self, package, version, file_descriptor):
        """Print the file docstring and import lines."""

    @abc.abstractmethod
    def PrintEnum(self, enum_type):
        """Print the given enum declaration."""

    @abc.abstractmethod
    def PrintMessage(self, message_type):
        """Print the given message declaration."""

    def PrintCustomJsonMapping(self, mapping):
        """Print one pre-formatted custom JSON mapping statement.

        This is deliberately not abstract. Making it abstract would stop
        existing subclasses that do not define it from being instantiated.
        Printers whose output format cannot express custom mappings can
        rely on this default.

        Args:
          mapping: The statement text produced for one mapping.

        Raises:
          NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError(
            '%s does not support custom JSON mappings' %
            type(self).__name__)
249
250
class _Proto2Printer(ProtoPrinter):

    """Printer for proto2 definitions.

    Renders descriptors as proto2 source, emitting descriptions as '//'
    comments. Custom JSON mappings are not supported.
    """

    def __init__(self, printer):
        # Low-level line printer, called as printer(fmt, *args).
        self.__printer = printer

    def __PrintEnumCommentLines(self, enum_type):
        """Print the comment block (description and values) for an enum."""
        description = enum_type.description or '%s enum type.' % enum_type.name
        # Leave room for the 3-character '// ' prefix.
        for line in textwrap.wrap(description,
                                  self.__printer.CalculateWidth() - 3):
            self.__printer('// %s', line)
        PrintIndentedDescriptions(self.__printer, enum_type.values, 'Values',
                                  prefix='// ')

    def __PrintEnumValueCommentLines(self, enum_value):
        """Print an enum value's description as comments, if it has one."""
        if enum_value.description:
            width = self.__printer.CalculateWidth() - 3
            for line in textwrap.wrap(enum_value.description, width):
                self.__printer('// %s', line)

    def PrintEnum(self, enum_type):
        """Print an enum declaration with its values ordered by number."""
        self.__PrintEnumCommentLines(enum_type)
        self.__printer('enum %s {', enum_type.name)
        with self.__printer.Indent():
            enum_values = sorted(
                enum_type.values, key=operator.attrgetter('number'))
            for enum_value in enum_values:
                self.__printer()
                self.__PrintEnumValueCommentLines(enum_value)
                self.__printer('%s = %s;', enum_value.name, enum_value.number)
        self.__printer('}')
        self.__printer()

    def PrintPreamble(self, package, version, file_descriptor):
        """Print the header comments plus the syntax and package lines."""
        self.__printer('// Generated message classes for %s version %s.',
                       package, version)
        self.__printer('// NOTE: This file is autogenerated and should not be '
                       'edited by hand.')
        # description is an optional field and may be unset (None), which
        # textwrap.wrap cannot handle, so treat it as empty.
        description_lines = textwrap.wrap(
            file_descriptor.description or '', 75)
        if description_lines:
            self.__printer('//')
            for line in description_lines:
                self.__printer('// %s', line)
        self.__printer()
        self.__printer('syntax = "proto2";')
        self.__printer('package %s;', file_descriptor.package)

    def __PrintMessageCommentLines(self, message_type):
        """Print the description of this message."""
        description = message_type.description or '%s message type.' % (
            message_type.name)
        width = self.__printer.CalculateWidth() - 3
        for line in textwrap.wrap(description, width):
            self.__printer('// %s', line)
        PrintIndentedDescriptions(self.__printer, message_type.enum_types,
                                  'Enums', prefix='// ')
        PrintIndentedDescriptions(self.__printer, message_type.message_types,
                                  'Messages', prefix='// ')
        PrintIndentedDescriptions(self.__printer, message_type.fields,
                                  'Fields', prefix='// ')

    def __PrintFieldDescription(self, description):
        """Print a field description as comment lines; no-op when empty."""
        # Field descriptions are optional. Skip None/'' instead of letting
        # textwrap.wrap fail on None.
        if not description:
            return
        for line in textwrap.wrap(description,
                                  self.__printer.CalculateWidth() - 3):
            self.__printer('// %s', line)

    def __PrintFields(self, fields):
        """Print one proto2 field declaration per extended field."""
        for extended_field in fields:
            field = extended_field.field_descriptor
            field_type = messages.Field.lookup_field_type_by_variant(
                field.variant)
            self.__printer()
            self.__PrintFieldDescription(extended_field.description)
            # The proto2 label keyword is the lower-cased label name.
            label = str(field.label).lower()
            if field_type in (messages.EnumField, messages.MessageField):
                proto_type = field.type_name
            else:
                proto_type = str(field.variant).lower()
            default_statement = ''
            if field.default_value:
                if field_type in [messages.BytesField, messages.StringField]:
                    default_value = '"%s"' % field.default_value
                elif field_type is messages.BooleanField:
                    default_value = str(field.default_value).lower()
                else:
                    default_value = str(field.default_value)

                default_statement = ' [default = %s]' % default_value
            self.__printer(
                '%s %s %s = %d%s;',
                label, proto_type, field.name, field.number, default_statement)

    def PrintMessage(self, message_type):
        """Print a message declaration, including nested definitions."""
        self.__printer()
        self.__PrintMessageCommentLines(message_type)
        if _EmptyMessage(message_type):
            self.__printer('message %s {}', message_type.name)
            return
        self.__printer('message %s {', message_type.name)
        with self.__printer.Indent():
            _PrintEnums(self, message_type.enum_types)
            _PrintMessages(self, message_type.message_types)
            self.__PrintFields(message_type.fields)
        self.__printer('}')

    def PrintCustomJsonMapping(self, mapping_lines):
        """Custom JSON mappings have no proto2 representation."""
        raise NotImplementedError(
            'Custom JSON encoding not supported for proto2')
360
361
class _ProtoRpcPrinter(ProtoPrinter):

    """Printer for ProtoRPC definitions.

    Renders descriptors as Python source declaring _messages.Enum and
    _messages.Message subclasses, with descriptions in docstrings.
    """

    def __init__(self, printer):
        # Low-level line printer, called as printer(fmt, *args).
        self.__printer = printer

    def __PrintClassSeparator(self):
        """Print one blank line, plus a second when the printer has no indent."""
        self.__printer()
        if not self.__printer.indent:
            self.__printer()

    def __PrintEnumDocstringLines(self, enum_type):
        """Print the raw docstring for an enum class, listing its values."""
        description = enum_type.description or '%s enum type.' % enum_type.name
        # Descriptions are free text and may contain '%'. As with the message
        # docstring, print them in comment context so the printer does not
        # apply string interpolation to them.
        with self.__printer.CommentContext():
            for line in textwrap.wrap('r"""%s' % description,
                                      self.__printer.CalculateWidth()):
                self.__printer(line)
            PrintIndentedDescriptions(self.__printer, enum_type.values,
                                      'Values')
            self.__printer('"""')

    def PrintEnum(self, enum_type):
        """Print an Enum class with its values ordered by number."""
        self.__printer('class %s(_messages.Enum):', enum_type.name)
        with self.__printer.Indent():
            self.__PrintEnumDocstringLines(enum_type)
            enum_values = sorted(
                enum_type.values, key=operator.attrgetter('number'))
            for enum_value in enum_values:
                self.__printer('%s = %s', enum_value.name, enum_value.number)
            if not enum_type.values:
                self.__printer('pass')
        self.__PrintClassSeparator()

    def __PrintAdditionalImports(self, imports):
        """Print additional imports needed for protorpc."""
        # Imports mentioning 'google' go in a second, separate group.
        google_imports = [x for x in imports if 'google' in x]
        other_imports = [x for x in imports if 'google' not in x]
        if other_imports:
            for import_ in sorted(other_imports):
                self.__printer(import_)
            self.__printer()
        # Note: If we ever were going to add imports from this package, we'd
        # need to sort those out and put them at the end.
        if google_imports:
            for import_ in sorted(google_imports):
                self.__printer(import_)
            self.__printer()

    def PrintPreamble(self, package, version, file_descriptor):
        """Print the module docstring, imports and package constant."""
        self.__printer('"""Generated message classes for %s version %s.',
                       package, version)
        self.__printer()
        # description is optional (None when unset) and is free text that may
        # contain '%', so it is printed verbatim in comment context.
        with self.__printer.CommentContext():
            for line in textwrap.wrap(file_descriptor.description or '', 78):
                self.__printer(line)
        self.__printer('"""')
        self.__printer('# NOTE: This file is autogenerated and should not be '
                       'edited by hand.')
        self.__printer()
        self.__PrintAdditionalImports(file_descriptor.additional_imports)
        self.__printer()
        self.__printer("package = '%s'", file_descriptor.package)
        self.__printer()
        self.__printer()

    def __PrintMessageDocstringLines(self, message_type):
        """Print the docstring for this message."""
        description = message_type.description or '%s message type.' % (
            message_type.name)
        # A one-line r"""...""" docstring is used when the message has no
        # members and the text fits: 7 quote characters plus the description
        # must fit within the available width.
        short_description = (
            _EmptyMessage(message_type) and
            len(description) < (self.__printer.CalculateWidth() - 6))
        with self.__printer.CommentContext():
            if short_description:
                # Note that we use explicit string interpolation here since
                # we're in comment context.
                self.__printer('r"""%s"""' % description)
                return
            for line in textwrap.wrap('r"""%s' % description,
                                      self.__printer.CalculateWidth()):
                self.__printer(line)

            PrintIndentedDescriptions(self.__printer, message_type.enum_types,
                                      'Enums')
            PrintIndentedDescriptions(
                self.__printer, message_type.message_types, 'Messages')
            PrintIndentedDescriptions(
                self.__printer, message_type.fields, 'Fields')
            self.__printer('"""')
            self.__printer()

    def PrintMessage(self, message_type):
        """Print a Message class, or a plain assignment for an alias."""
        if message_type.alias_for:
            self.__printer(
                '%s = %s', message_type.name, message_type.alias_for)
            self.__PrintClassSeparator()
            return
        for decorator in message_type.decorators:
            self.__printer('@%s', decorator)
        self.__printer('class %s(_messages.Message):', message_type.name)
        with self.__printer.Indent():
            self.__PrintMessageDocstringLines(message_type)
            _PrintEnums(self, message_type.enum_types)
            _PrintMessages(self, message_type.message_types)
            _PrintFields(message_type.fields, self.__printer)
        self.__PrintClassSeparator()

    def PrintCustomJsonMapping(self, mapping):
        """Print one pre-formatted custom JSON mapping statement."""
        # Pass the text as an argument so a '%' in a name is not treated as
        # a format directive by the printer.
        self.__printer('%s', mapping)
469
470
471def _PrintEnums(proto_printer, enum_types):
472    """Print all enums to the given proto_printer."""
473    enum_types = sorted(enum_types, key=operator.attrgetter('name'))
474    for enum_type in enum_types:
475        proto_printer.PrintEnum(enum_type)
476
477
478def _PrintMessages(proto_printer, message_list):
479    message_list = sorted(message_list, key=operator.attrgetter('name'))
480    for message_type in message_list:
481        proto_printer.PrintMessage(message_type)
482
483
# Maps the definition name of a well-known message type to the dedicated
# field class used to declare fields of that type. _PrintFields looks fields
# up here by FieldDescriptor.type_name and, on a hit, emits the field from the
# _message_types module instead of _messages.
_MESSAGE_FIELD_MAP = {
    message_types.DateTimeMessage.definition_name(): (
        message_types.DateTimeField),
}
488
489
490def _PrintFields(fields, printer):
491    for extended_field in fields:
492        field = extended_field.field_descriptor
493        printed_field_info = {
494            'name': field.name,
495            'module': '_messages',
496            'type_name': '',
497            'type_format': '',
498            'number': field.number,
499            'label_format': '',
500            'variant_format': '',
501            'default_format': '',
502        }
503
504        message_field = _MESSAGE_FIELD_MAP.get(field.type_name)
505        if message_field:
506            printed_field_info['module'] = '_message_types'
507            field_type = message_field
508        elif field.type_name == 'extra_types.DateField':
509            printed_field_info['module'] = 'extra_types'
510            field_type = extra_types.DateField
511        else:
512            field_type = messages.Field.lookup_field_type_by_variant(
513                field.variant)
514
515        if field_type in (messages.EnumField, messages.MessageField):
516            printed_field_info['type_format'] = "'%s', " % field.type_name
517
518        if field.label == protorpc_descriptor.FieldDescriptor.Label.REQUIRED:
519            printed_field_info['label_format'] = ', required=True'
520        elif field.label == protorpc_descriptor.FieldDescriptor.Label.REPEATED:
521            printed_field_info['label_format'] = ', repeated=True'
522
523        if field_type.DEFAULT_VARIANT != field.variant:
524            printed_field_info['variant_format'] = (
525                ', variant=_messages.Variant.%s' % field.variant)
526
527        if field.default_value:
528            if field_type in [messages.BytesField, messages.StringField]:
529                default_value = repr(field.default_value)
530            elif field_type is messages.EnumField:
531                try:
532                    default_value = str(int(field.default_value))
533                except ValueError:
534                    default_value = repr(field.default_value)
535            else:
536                default_value = field.default_value
537
538            printed_field_info[
539                'default_format'] = ', default=%s' % (default_value,)
540
541        printed_field_info['type_name'] = field_type.__name__
542        args = ''.join('%%(%s)s' % field for field in (
543            'type_format',
544            'number',
545            'label_format',
546            'variant_format',
547            'default_format'))
548        format_str = '%%(name)s = %%(module)s.%%(type_name)s(%s)' % args
549        printer(format_str % printed_field_info)
550