1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Provides DescriptorPool to use as a container for proto2 descriptors.
32
The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
34a collection of protocol buffer descriptors for use when dynamically creating
35message types at runtime.
36
37For most applications protocol buffers should be used via modules generated by
38the protocol buffer compiler tool. This should only be used when the type of
39protocol buffers used in an application or library cannot be predetermined.
40
41Below is a straightforward example on how to use this class:
42
43  pool = DescriptorPool()
44  file_descriptor_protos = [ ... ]
45  for file_descriptor_proto in file_descriptor_protos:
46    pool.Add(file_descriptor_proto)
47  my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')
48
49The message descriptor can be used in conjunction with the message_factory
50module in order to create a protocol buffer class that can be encoded and
51decoded.
52
53If you want to get a Python class for the specified proto, use the
54helper functions inside google.protobuf.message_factory
55directly instead of this class.
56"""
57
58__author__ = 'matthewtoia@google.com (Matt Toia)'
59
60import sys
61
62from google.protobuf import descriptor
63from google.protobuf import descriptor_database
64from google.protobuf import text_encoding
65
66
67def _NormalizeFullyQualifiedName(name):
68  """Remove leading period from fully-qualified type name.
69
70  Due to b/13860351 in descriptor_database.py, types in the root namespace are
71  generated with a leading period. This function removes that prefix.
72
73  Args:
74    name: A str, the fully-qualified symbol name.
75
76  Returns:
77    A str, the normalized fully-qualified symbol name.
78  """
79  return name.lstrip('.')
80
81
82class DescriptorPool(object):
83  """A collection of protobufs dynamically constructed by descriptor protos."""
84
85  def __init__(self, descriptor_db=None):
86    """Initializes a Pool of proto buffs.
87
88    The descriptor_db argument to the constructor is provided to allow
89    specialized file descriptor proto lookup code to be triggered on demand. An
90    example would be an implementation which will read and compile a file
91    specified in a call to FindFileByName() and not require the call to Add()
92    at all. Results from this database will be cached internally here as well.
93
94    Args:
95      descriptor_db: A secondary source of file descriptors.
96    """
97
98    self._internal_db = descriptor_database.DescriptorDatabase()
99    self._descriptor_db = descriptor_db
100    self._descriptors = {}
101    self._enum_descriptors = {}
102    self._file_descriptors = {}
103
104  def Add(self, file_desc_proto):
105    """Adds the FileDescriptorProto and its types to this pool.
106
107    Args:
108      file_desc_proto: The FileDescriptorProto to add.
109    """
110
111    self._internal_db.Add(file_desc_proto)
112
113  def AddDescriptor(self, desc):
114    """Adds a Descriptor to the pool, non-recursively.
115
116    If the Descriptor contains nested messages or enums, the caller must
117    explicitly register them. This method also registers the FileDescriptor
118    associated with the message.
119
120    Args:
121      desc: A Descriptor.
122    """
123    if not isinstance(desc, descriptor.Descriptor):
124      raise TypeError('Expected instance of descriptor.Descriptor.')
125
126    self._descriptors[desc.full_name] = desc
127    self.AddFileDescriptor(desc.file)
128
129  def AddEnumDescriptor(self, enum_desc):
130    """Adds an EnumDescriptor to the pool.
131
132    This method also registers the FileDescriptor associated with the message.
133
134    Args:
135      enum_desc: An EnumDescriptor.
136    """
137
138    if not isinstance(enum_desc, descriptor.EnumDescriptor):
139      raise TypeError('Expected instance of descriptor.EnumDescriptor.')
140
141    self._enum_descriptors[enum_desc.full_name] = enum_desc
142    self.AddFileDescriptor(enum_desc.file)
143
144  def AddFileDescriptor(self, file_desc):
145    """Adds a FileDescriptor to the pool, non-recursively.
146
147    If the FileDescriptor contains messages or enums, the caller must explicitly
148    register them.
149
150    Args:
151      file_desc: A FileDescriptor.
152    """
153
154    if not isinstance(file_desc, descriptor.FileDescriptor):
155      raise TypeError('Expected instance of descriptor.FileDescriptor.')
156    self._file_descriptors[file_desc.name] = file_desc
157
158  def FindFileByName(self, file_name):
159    """Gets a FileDescriptor by file name.
160
161    Args:
162      file_name: The path to the file to get a descriptor for.
163
164    Returns:
165      A FileDescriptor for the named file.
166
167    Raises:
168      KeyError: if the file can not be found in the pool.
169    """
170
171    try:
172      return self._file_descriptors[file_name]
173    except KeyError:
174      pass
175
176    try:
177      file_proto = self._internal_db.FindFileByName(file_name)
178    except KeyError:
179      _, error, _ = sys.exc_info()  #PY25 compatible for GAE.
180      if self._descriptor_db:
181        file_proto = self._descriptor_db.FindFileByName(file_name)
182      else:
183        raise error
184    if not file_proto:
185      raise KeyError('Cannot find a file named %s' % file_name)
186    return self._ConvertFileProtoToFileDescriptor(file_proto)
187
188  def FindFileContainingSymbol(self, symbol):
189    """Gets the FileDescriptor for the file containing the specified symbol.
190
191    Args:
192      symbol: The name of the symbol to search for.
193
194    Returns:
195      A FileDescriptor that contains the specified symbol.
196
197    Raises:
198      KeyError: if the file can not be found in the pool.
199    """
200
201    symbol = _NormalizeFullyQualifiedName(symbol)
202    try:
203      return self._descriptors[symbol].file
204    except KeyError:
205      pass
206
207    try:
208      return self._enum_descriptors[symbol].file
209    except KeyError:
210      pass
211
212    try:
213      file_proto = self._internal_db.FindFileContainingSymbol(symbol)
214    except KeyError:
215      _, error, _ = sys.exc_info()  #PY25 compatible for GAE.
216      if self._descriptor_db:
217        file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
218      else:
219        raise error
220    if not file_proto:
221      raise KeyError('Cannot find a file containing %s' % symbol)
222    return self._ConvertFileProtoToFileDescriptor(file_proto)
223
224  def FindMessageTypeByName(self, full_name):
225    """Loads the named descriptor from the pool.
226
227    Args:
228      full_name: The full name of the descriptor to load.
229
230    Returns:
231      The descriptor for the named type.
232    """
233
234    full_name = _NormalizeFullyQualifiedName(full_name)
235    if full_name not in self._descriptors:
236      self.FindFileContainingSymbol(full_name)
237    return self._descriptors[full_name]
238
239  def FindEnumTypeByName(self, full_name):
240    """Loads the named enum descriptor from the pool.
241
242    Args:
243      full_name: The full name of the enum descriptor to load.
244
245    Returns:
246      The enum descriptor for the named type.
247    """
248
249    full_name = _NormalizeFullyQualifiedName(full_name)
250    if full_name not in self._enum_descriptors:
251      self.FindFileContainingSymbol(full_name)
252    return self._enum_descriptors[full_name]
253
  def _ConvertFileProtoToFileDescriptor(self, file_proto):
    """Creates a FileDescriptor from a proto or returns a cached copy.

    This method also has the side effect of loading all the symbols found in
    the file into the appropriate dictionaries in the pool, and of adding
    file_proto to the pool's internal database.

    Args:
      file_proto: The proto to convert.

    Returns:
      A FileDescriptor matching the passed in proto.
    """

    if file_proto.name not in self._file_descriptors:
      # built_deps feeds the type scope below; direct_deps is what gets
      # recorded on the new FileDescriptor. Both lookups go through
      # FindFileByName, which builds dependencies that are not yet cached.
      built_deps = list(self._GetDeps(file_proto.dependency))
      direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]

      file_descriptor = descriptor.FileDescriptor(
          name=file_proto.name,
          package=file_proto.package,
          options=file_proto.options,
          serialized_pb=file_proto.SerializeToString(),
          dependencies=direct_deps)
      # Maps dot-prefixed full names to message and enum descriptors.
      scope = {}

      # This loop extracts all the message and enum types from all the
      # dependencies of the file_proto. This is necessary to create the
      # scope of available message types when defining the passed in
      # file proto.
      for dependency in built_deps:
        scope.update(self._ExtractSymbols(
            dependency.message_types_by_name.values()))
        scope.update((_PrefixWithDot(enum.full_name), enum)
                     for enum in dependency.enum_types_by_name.values())

      # Build the file's own messages; each conversion also adds the message
      # (and its nested types and enums) to scope and to the pool.
      for message_type in file_proto.message_type:
        message_desc = self._ConvertMessageDescriptor(
            message_type, file_proto.package, file_descriptor, scope)
        file_descriptor.message_types_by_name[message_desc.name] = message_desc

      # Build the file's top-level enums.
      for enum_type in file_proto.enum_type:
        file_descriptor.enum_types_by_name[enum_type.name] = (
            self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                        file_descriptor, None, scope))

      # Build the file's top-level extensions. All messages of this file are
      # in scope by now, so the extendee and the field type can be resolved.
      for index, extension_proto in enumerate(file_proto.extension):
        extension_desc = self.MakeFieldDescriptor(
            extension_proto, file_proto.package, index, is_extension=True)
        extension_desc.containing_type = self._GetTypeFromScope(
            file_descriptor.package, extension_proto.extendee, scope)
        self.SetFieldType(extension_proto, extension_desc,
                          file_descriptor.package, scope)
        file_descriptor.extensions_by_name[extension_desc.name] = extension_desc

      # Second pass over the messages: resolve field types now that every
      # type defined in this file is present in scope.
      for desc_proto in file_proto.message_type:
        self.SetAllFieldTypes(file_proto.package, desc_proto, scope)

      if file_proto.package:
        desc_proto_prefix = _PrefixWithDot(file_proto.package)
      else:
        desc_proto_prefix = ''

      # NOTE(review): this looks up each top-level message in scope and stores
      # it under the same key the first message loop already filled; it
      # appears to be redundant -- confirm before removing.
      for desc_proto in file_proto.message_type:
        desc = self._GetTypeFromScope(desc_proto_prefix, desc_proto.name, scope)
        file_descriptor.message_types_by_name[desc_proto.name] = desc
      # Remember the proto in the internal database and cache the result so
      # later calls return the same FileDescriptor.
      self.Add(file_proto)
      self._file_descriptors[file_proto.name] = file_descriptor

    return self._file_descriptors[file_proto.name]
323
324  def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
325                                scope=None):
326    """Adds the proto to the pool in the specified package.
327
328    Args:
329      desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
330      package: The package the proto should be located in.
331      file_desc: The file containing this message.
332      scope: Dict mapping short and full symbols to message and enum types.
333
334    Returns:
335      The added descriptor.
336    """
337
338    if package:
339      desc_name = '.'.join((package, desc_proto.name))
340    else:
341      desc_name = desc_proto.name
342
343    if file_desc is None:
344      file_name = None
345    else:
346      file_name = file_desc.name
347
348    if scope is None:
349      scope = {}
350
351    nested = [
352        self._ConvertMessageDescriptor(nested, desc_name, file_desc, scope)
353        for nested in desc_proto.nested_type]
354    enums = [
355        self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope)
356        for enum in desc_proto.enum_type]
357    fields = [self.MakeFieldDescriptor(field, desc_name, index)
358              for index, field in enumerate(desc_proto.field)]
359    extensions = [
360        self.MakeFieldDescriptor(extension, desc_name, index, is_extension=True)
361        for index, extension in enumerate(desc_proto.extension)]
362    oneofs = [
363        descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)),
364                                   index, None, [])
365        for index, desc in enumerate(desc_proto.oneof_decl)]
366    extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
367    if extension_ranges:
368      is_extendable = True
369    else:
370      is_extendable = False
371    desc = descriptor.Descriptor(
372        name=desc_proto.name,
373        full_name=desc_name,
374        filename=file_name,
375        containing_type=None,
376        fields=fields,
377        oneofs=oneofs,
378        nested_types=nested,
379        enum_types=enums,
380        extensions=extensions,
381        options=desc_proto.options,
382        is_extendable=is_extendable,
383        extension_ranges=extension_ranges,
384        file=file_desc,
385        serialized_start=None,
386        serialized_end=None)
387    for nested in desc.nested_types:
388      nested.containing_type = desc
389    for enum in desc.enum_types:
390      enum.containing_type = desc
391    for field_index, field_desc in enumerate(desc_proto.field):
392      if field_desc.HasField('oneof_index'):
393        oneof_index = field_desc.oneof_index
394        oneofs[oneof_index].fields.append(fields[field_index])
395        fields[field_index].containing_oneof = oneofs[oneof_index]
396
397    scope[_PrefixWithDot(desc_name)] = desc
398    self._descriptors[desc_name] = desc
399    return desc
400
401  def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
402                             containing_type=None, scope=None):
403    """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
404
405    Args:
406      enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
407      package: Optional package name for the new message EnumDescriptor.
408      file_desc: The file containing the enum descriptor.
409      containing_type: The type containing this enum.
410      scope: Scope containing available types.
411
412    Returns:
413      The added descriptor
414    """
415
416    if package:
417      enum_name = '.'.join((package, enum_proto.name))
418    else:
419      enum_name = enum_proto.name
420
421    if file_desc is None:
422      file_name = None
423    else:
424      file_name = file_desc.name
425
426    values = [self._MakeEnumValueDescriptor(value, index)
427              for index, value in enumerate(enum_proto.value)]
428    desc = descriptor.EnumDescriptor(name=enum_proto.name,
429                                     full_name=enum_name,
430                                     filename=file_name,
431                                     file=file_desc,
432                                     values=values,
433                                     containing_type=containing_type,
434                                     options=enum_proto.options)
435    scope['.%s' % enum_name] = desc
436    self._enum_descriptors[enum_name] = desc
437    return desc
438
439  def MakeFieldDescriptor(self, field_proto, message_name, index,
440                          is_extension=False):
441    """Creates a field descriptor from a FieldDescriptorProto.
442
443    For message and enum type fields, this method will do a look up
444    in the pool for the appropriate descriptor for that type. If it
445    is unavailable, it will fall back to the _source function to
446    create it. If this type is still unavailable, construction will
447    fail.
448
449    Args:
450      field_proto: The proto describing the field.
451      message_name: The name of the containing message.
452      index: Index of the field
453      is_extension: Indication that this field is for an extension.
454
455    Returns:
456      An initialized FieldDescriptor object
457    """
458
459    if message_name:
460      full_name = '.'.join((message_name, field_proto.name))
461    else:
462      full_name = field_proto.name
463
464    return descriptor.FieldDescriptor(
465        name=field_proto.name,
466        full_name=full_name,
467        index=index,
468        number=field_proto.number,
469        type=field_proto.type,
470        cpp_type=None,
471        message_type=None,
472        enum_type=None,
473        containing_type=None,
474        label=field_proto.label,
475        has_default_value=False,
476        default_value=None,
477        is_extension=is_extension,
478        extension_scope=None,
479        options=field_proto.options)
480
481  def SetAllFieldTypes(self, package, desc_proto, scope):
482    """Sets all the descriptor's fields's types.
483
484    This method also sets the containing types on any extensions.
485
486    Args:
487      package: The current package of desc_proto.
488      desc_proto: The message descriptor to update.
489      scope: Enclosing scope of available types.
490    """
491
492    package = _PrefixWithDot(package)
493
494    main_desc = self._GetTypeFromScope(package, desc_proto.name, scope)
495
496    if package == '.':
497      nested_package = _PrefixWithDot(desc_proto.name)
498    else:
499      nested_package = '.'.join([package, desc_proto.name])
500
501    for field_proto, field_desc in zip(desc_proto.field, main_desc.fields):
502      self.SetFieldType(field_proto, field_desc, nested_package, scope)
503
504    for extension_proto, extension_desc in (
505        zip(desc_proto.extension, main_desc.extensions)):
506      extension_desc.containing_type = self._GetTypeFromScope(
507          nested_package, extension_proto.extendee, scope)
508      self.SetFieldType(extension_proto, extension_desc, nested_package, scope)
509
510    for nested_type in desc_proto.nested_type:
511      self.SetAllFieldTypes(nested_package, nested_type, scope)
512
513  def SetFieldType(self, field_proto, field_desc, package, scope):
514    """Sets the field's type, cpp_type, message_type and enum_type.
515
516    Args:
517      field_proto: Data about the field in proto format.
518      field_desc: The descriptor to modiy.
519      package: The package the field's container is in.
520      scope: Enclosing scope of available types.
521    """
522    if field_proto.type_name:
523      desc = self._GetTypeFromScope(package, field_proto.type_name, scope)
524    else:
525      desc = None
526
527    if not field_proto.HasField('type'):
528      if isinstance(desc, descriptor.Descriptor):
529        field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE
530      else:
531        field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM
532
533    field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType(
534        field_proto.type)
535
536    if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE
537        or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP):
538      field_desc.message_type = desc
539
540    if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
541      field_desc.enum_type = desc
542
543    if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED:
544      field_desc.has_default_value = False
545      field_desc.default_value = []
546    elif field_proto.HasField('default_value'):
547      field_desc.has_default_value = True
548      if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
549          field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
550        field_desc.default_value = float(field_proto.default_value)
551      elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
552        field_desc.default_value = field_proto.default_value
553      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
554        field_desc.default_value = field_proto.default_value.lower() == 'true'
555      elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
556        field_desc.default_value = field_desc.enum_type.values_by_name[
557            field_proto.default_value].index
558      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
559        field_desc.default_value = text_encoding.CUnescape(
560            field_proto.default_value)
561      else:
562        field_desc.default_value = int(field_proto.default_value)
563    else:
564      field_desc.has_default_value = False
565      field_desc.default_value = None
566
567    field_desc.type = field_proto.type
568
569  def _MakeEnumValueDescriptor(self, value_proto, index):
570    """Creates a enum value descriptor object from a enum value proto.
571
572    Args:
573      value_proto: The proto describing the enum value.
574      index: The index of the enum value.
575
576    Returns:
577      An initialized EnumValueDescriptor object.
578    """
579
580    return descriptor.EnumValueDescriptor(
581        name=value_proto.name,
582        index=index,
583        number=value_proto.number,
584        options=value_proto.options,
585        type=None)
586
587  def _ExtractSymbols(self, descriptors):
588    """Pulls out all the symbols from descriptor protos.
589
590    Args:
591      descriptors: The messages to extract descriptors from.
592    Yields:
593      A two element tuple of the type name and descriptor object.
594    """
595
596    for desc in descriptors:
597      yield (_PrefixWithDot(desc.full_name), desc)
598      for symbol in self._ExtractSymbols(desc.nested_types):
599        yield symbol
600      for enum in desc.enum_types:
601        yield (_PrefixWithDot(enum.full_name), enum)
602
603  def _GetDeps(self, dependencies):
604    """Recursively finds dependencies for file protos.
605
606    Args:
607      dependencies: The names of the files being depended on.
608
609    Yields:
610      Each direct and indirect dependency.
611    """
612
613    for dependency in dependencies:
614      dep_desc = self.FindFileByName(dependency)
615      yield dep_desc
616      for parent_dep in dep_desc.dependencies:
617        yield parent_dep
618
619  def _GetTypeFromScope(self, package, type_name, scope):
620    """Finds a given type name in the current scope.
621
622    Args:
623      package: The package the proto should be located in.
624      type_name: The name of the type to be found in the scope.
625      scope: Dict mapping short and full symbols to message and enum types.
626
627    Returns:
628      The descriptor for the requested type.
629    """
630    if type_name not in scope:
631      components = _PrefixWithDot(package).split('.')
632      while components:
633        possible_match = '.'.join(components + [type_name])
634        if possible_match in scope:
635          type_name = possible_match
636          break
637        else:
638          components.pop(-1)
639    return scope[type_name]
640
641
642def _PrefixWithDot(name):
643  return name if name.startswith('.') else '.%s' % name
644