1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31# This code is meant to work on Python 2.4 and above only.
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
36"""Contains a metaclass and helper functions used to create
37protocol message classes from Descriptor objects at runtime.
38
39Recall that a metaclass is the "type" of a class.
40(A class is to a metaclass what an instance is to a class.)
41
42In this case, we use the GeneratedProtocolMessageType metaclass
43to inject all the useful functionality into the classes
44output by the protocol compiler at compile-time.
45
46The upshot of all this is that the real implementation
47details for ALL pure-Python protocol buffers are *here in
48this file*.
49"""
50
51__author__ = 'robinson@google.com (Will Robinson)'
52
53from io import BytesIO
54import sys
55import struct
56import weakref
57
58import six
59try:
60  import six.moves.copyreg as copyreg
61except ImportError:
62  # On some platforms, for example gMac, we run native Python because there is
63  # nothing like hermetic Python. This means lesser control on the system and
64  # the six.moves package may be missing (is missing on 20150321 on gMac). Be
65  # extra conservative and try to load the old replacement if it fails.
66  import copy_reg as copyreg
67
68# We use "as" to avoid name collisions with variables.
69from google.protobuf.internal import containers
70from google.protobuf.internal import decoder
71from google.protobuf.internal import encoder
72from google.protobuf.internal import enum_type_wrapper
73from google.protobuf.internal import message_listener as message_listener_mod
74from google.protobuf.internal import type_checkers
75from google.protobuf.internal import well_known_types
76from google.protobuf.internal import wire_format
77from google.protobuf import descriptor as descriptor_mod
78from google.protobuf import message as message_mod
79from google.protobuf import symbol_database
80from google.protobuf import text_format
81
# Shorthand for the FieldDescriptor class; used throughout this module for its
# LABEL_*, TYPE_* and CPPTYPE_* constants and for isinstance() checks.
_FieldDescriptor = descriptor_mod.FieldDescriptor
# Full name of the google.protobuf.Any well-known message type.
_AnyFullTypeName = 'google.protobuf.Any'
84
85
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  class MyProtoClass(Message):
    __metaclass__ = GeneratedProtocolMessageType
    DESCRIPTOR = mydescriptor
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...

  The above example will not work for nested types. If you wish to include them,
  use reflection.MakeClass() instead of manually instantiating the class in
  order to create the appropriate class structure.
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the class object after adjusting its bases and dictionary.

    __slots__ has to be present in the class dictionary before type.__new__
    builds the class, so the slot list (and one class attribute per nested
    extension) is injected into |dictionary| here rather than in __init__.
    Well-known types listed in well_known_types.WKTBASES additionally get
    their mixin class appended to |bases|.

    Args:
      name: Name of the class; passed through to type.__new__.
      bases: Tuple of base classes of the class being constructed
        (normally just message.Message).
      dictionary: The class dictionary.  dictionary[_DESCRIPTOR_KEY] must
        hold the Descriptor describing this protocol message type.  It is
        modified in place.

    Returns:
      The newly-allocated class.
    """
    message_descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    full_name = message_descriptor.full_name
    if full_name in well_known_types.WKTBASES:
      bases = bases + (well_known_types.WKTBASES[full_name],)
    _AddClassAttributesForNestedExtensions(message_descriptor, dictionary)
    _AddSlots(message_descriptor, dictionary)
    return super(GeneratedProtocolMessageType, cls).__new__(
        cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates the freshly-allocated message class.

    Creates the per-class decoder and extension lookup tables, attaches
    encoding helpers to every FieldDescriptor, records |cls| as the
    descriptor's concrete class, and then adds enum values, __init__, field
    properties, extension field-number constants, static methods, Message
    methods, private helpers and a pickling reducer.

    Args:
      name: Name of the class; passed through to type.__init__.
      bases: Base classes of the class being constructed.
      dictionary: The class dictionary.  dictionary[_DESCRIPTOR_KEY] must
        hold the Descriptor describing this protocol message type.
    """
    message_descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # Per-class lookup tables, filled by _AttachFieldHelpers and by the
    # RegisterExtension static method.
    cls._decoders_by_tag = {}
    cls._extensions_by_name = {}
    cls._extensions_by_number = {}

    uses_message_set_wire_format = (
        message_descriptor.has_options and
        message_descriptor.GetOptions().message_set_wire_format)
    if uses_message_set_wire_format:
      # The by-number extension table is shared with the item decoder by
      # reference, so extensions registered later go into the same dict.
      item_decoder = decoder.MessageSetItemDecoder(cls._extensions_by_number)
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (item_decoder, None)

    # Attach encoders, sizers, decoders and default constructors to each
    # FieldDescriptor for quick lookup later on.
    for field_descriptor in message_descriptor.fields:
      _AttachFieldHelpers(cls, field_descriptor)

    message_descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(message_descriptor, cls)
    _AddInitMethod(message_descriptor, cls)
    _AddPropertiesForFields(message_descriptor, cls)
    _AddPropertiesForExtensions(message_descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(message_descriptor, cls)
    _AddPrivateHelperMethods(message_descriptor, cls)

    def _ReduceForPickle(message):
      # Reduce tuple: rebuild with cls() and restore the returned state.
      return cls, (), message.__getstate__()
    copyreg.pickle(cls, _ReduceForPickle)

    super(GeneratedProtocolMessageType, cls).__init__(name, bases, dictionary)
190
191
192# Stateless helpers for GeneratedProtocolMessageType below.
193# Outside clients should not access these directly.
194#
195# I opted not to make any of these methods on the metaclass, to make it more
196# clear that I'm not really using any state there and to keep clients from
197# thinking that they have direct access to these construction helpers.
198
199
def _PropertyName(proto_field_name):
  """Returns the public property name for a protocol message field.

  Clients use this property to get and (for some field kinds) set the value
  of the field.  The property name is currently identical to the field name
  as written in the .proto file.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    The attribute name under which the field is exposed on message classes.
  """
  # Python keywords (e.g. "yield") are deliberately NOT escaped with a
  # trailing underscore: people rely on getattr()/setattr() with the raw
  # .proto name to manipulate field values reflectively, and renaming the
  # properties would force every such caller to apply the same
  # transformation.  A keyword-named field is still reachable through
  # getattr/setattr without much trouble.
  # TODO(kenton):  Remove this method entirely if/when everyone agrees with
  #   this position.
  property_name = proto_field_name
  return property_name
227
228
def _VerifyExtensionHandle(message, extension_handle):
  """Verify that the given extension handle is valid for |message|.

  Args:
    message: The message instance the extension is being used with.
    extension_handle: The object the caller supplied as an extension key;
      expected to be an extension FieldDescriptor.

  Raises:
    KeyError: if |extension_handle| is not a FieldDescriptor, is not an
      extension, has no containing_type, or extends a message type other
      than message.DESCRIPTOR.
  """

  if not isinstance(extension_handle, _FieldDescriptor):
    # Wrap the handle in a 1-tuple: a tuple handle would otherwise be taken
    # as the whole %-argument tuple and raise TypeError (or be mis-formatted)
    # instead of producing the intended KeyError.
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   (extension_handle,))

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  if not extension_handle.containing_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % extension_handle.full_name)

  if extension_handle.containing_type is not message.DESCRIPTOR:
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' %
                   (extension_handle.full_name,
                    extension_handle.containing_type.full_name,
                    message.DESCRIPTOR.full_name))
249
250
def _AddSlots(message_descriptor, dictionary):
  """Declares the fixed set of instance attributes for a message class.

  Stores a freshly built '__slots__' list in |dictionary|.  The list is the
  same for every message type; |message_descriptor| is not consulted.

  Args:
    message_descriptor: A Descriptor instance describing this message type
      (unused).
    dictionary: Class dictionary that receives the '__slots__' entry.
  """
  slot_names = [
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  ]
  dictionary['__slots__'] = slot_names
268
269
def _IsMessageSetExtension(field):
  """Tells whether |field| is an extension encoded as a MessageSet item.

  That is the case for an optional, message-typed extension whose containing
  message type has the message_set_wire_format option enabled.  Returns a
  falsy value otherwise.
  """
  return (field.is_extension
          and field.containing_type.has_options
          and field.containing_type.GetOptions().message_set_wire_format
          and field.label == _FieldDescriptor.LABEL_OPTIONAL
          and field.type == _FieldDescriptor.TYPE_MESSAGE)
276
277
def _IsMapField(field):
  """Tells whether |field| is a map field.

  A map field is a message-typed field whose message type has the map_entry
  option set.  Non-message fields yield False; otherwise the result is the
  entry type's has_options value if falsy, else its map_entry option.
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  entry_type = field.message_type
  return entry_type.has_options and entry_type.GetOptions().map_entry
282
283
def _IsMessageMapField(field):
  """Tells whether the values of map field |field| are messages.

  Args:
    field: A map FieldDescriptor; its message_type is the map entry type,
      which has a "value" field.
  """
  entry_fields = field.message_type.fields_by_name
  value_cpp_type = entry_fields["value"].cpp_type
  return value_cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
287
288
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoding/decoding helpers for one field to a message class.

  Stores an encoder, a sizer and a default-value constructor on
  |field_descriptor| (as _encoder, _sizer and _default_constructor) and
  registers the decoder(s) for the field's tag(s) in cls._decoders_by_tag.

  Args:
    cls: The message class being constructed, or the class an extension is
      being registered on.
    field_descriptor: FieldDescriptor of a field or extension of |cls|.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))
  if not is_packable:
    is_packed = False
  elif field_descriptor.containing_type.syntax == "proto2":
    # proto2: packed only when the option explicitly asks for it.
    is_packed = (field_descriptor.has_options and
                 field_descriptor.GetOptions().packed)
  else:
    # Otherwise packable repeated fields are packed unless the option is
    # explicitly set to false.
    has_packed_false = (field_descriptor.has_options and
                        field_descriptor.GetOptions().HasField("packed") and
                        not field_descriptor.GetOptions().packed)
    is_packed = not has_packed_false
  is_map_entry = _IsMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor)
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  # Enum fields for which type_checkers.SupportsOpenEnums() is true are
  # decoded with the int32 decoder.  This does not depend on the wire type,
  # so it is computed once for all AddDecoder calls.
  decode_type = field_descriptor.type
  if (decode_type == _FieldDescriptor.TYPE_ENUM and
      type_checkers.SupportsOpenEnums(field_descriptor)):
    decode_type = _FieldDescriptor.TYPE_INT32

  # Second element of each decoder-table entry: the field descriptor itself
  # for fields that belong to a oneof, None otherwise.
  oneof_descriptor = None
  if field_descriptor.containing_oneof is not None:
    oneof_descriptor = field_descriptor

  def AddDecoder(wiretype, is_packed):
    """Registers the decoder for |field_descriptor| under one wire type."""
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
              field_descriptor.number, is_repeated, is_packed,
              field_descriptor, field_descriptor._default_constructor)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_packable:
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
353
354
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each extension declared inside |descriptor| as a class attribute.

  Args:
    descriptor: Descriptor of the message type being constructed.
    dictionary: Class dictionary that receives one entry per nested
      extension, keyed by the extension's name.  A name that is already
      present trips an assert.
  """
  for name, handle in descriptor.extensions_by_name.items():
    assert name not in dictionary
    dictionary[name] = handle
360
361
def _AddEnumValues(descriptor, cls):
  """Exposes the enums nested in a message type as class attributes.

  For every enum type nested in |descriptor|, |cls| gets an attribute named
  after the enum holding an EnumTypeWrapper (which can name enum values),
  followed by one integer attribute per enum value, named after that value.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_descriptor in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_descriptor)
    setattr(cls, enum_descriptor.name, wrapper)
    for value_descriptor in enum_descriptor.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
375
376
def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds an empty map container for |field|.

  Args:
    field: A repeated FieldDescriptor whose message type is a map entry
      (with "key" and "value" fields).

  Returns:
    A function taking the owning message and returning either a
    containers.MessageMap (message-valued map) or a containers.ScalarMap
    (scalar-valued map), both wired to message._listener_for_children.

  Raises:
    ValueError: if |field| is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)
    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker)
  return MakeMessageMapDefault
396
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field.  The default
    value may refer back to |message| via a weak reference.

  Raises:
    ValueError: if |field| is a non-repeated map field (raised by
      _GetInitializeDefaultForMap), or a repeated field whose declared
      default value is not an empty list.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # The container receives the message *descriptor* instead.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized, so it is looked up only
    # when the default is actually built.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      # Sub-messages inside a oneof report modifications through a
      # _OneofListener; all others use the parent's child listener.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
450
451
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception currently being handled, naming the field.

  Must be called from inside an ``except`` block.  A plain TypeError with a
  single argument is replaced by a new TypeError whose text is suffixed with
  " for field <message_name>.<field_name>"; any other exception is re-raised
  unchanged.  Either way the original traceback is preserved.

  Args:
    message_name: Name of the message type that owns the field.
    field_name: Name of the field whose value was rejected.
  """
  _, current_exc, current_tb = sys.exc_info()
  is_plain_type_error = (len(current_exc.args) == 1 and
                         type(current_exc) is TypeError)
  if is_plain_type_error:
    current_exc = TypeError('%s for field %s.%s' %
                            (str(current_exc), message_name, field_name))
  six.reraise(type(current_exc), current_exc, current_tb)
461
462
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ resets the per-instance bookkeeping slots and then
  applies each keyword argument to the field with the same name:
    * repeated fields are filled from an iterable (map fields from a mapping);
      repeated message elements may be messages or dicts of constructor
      keyword arguments;
    * singular message fields are merged from a message, or built from a dict
      of constructor keyword arguments;
    * singular scalar fields are assigned through their property setter.
  Enum fields also accept value names as strings.  Keyword arguments whose
  value is None are ignored.

  Args:
    message_descriptor: Descriptor of the message type being constructed.
    cls: The message class that receives the __init__ method.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)

    Raises:
      ValueError: if a string value does not name a value of enum_type.
    """
    if isinstance(value, six.string_types):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Constructor arguments may populate fields, so the cached size starts out
    # dirty whenever any keyword argument was given.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE(review): _GetFieldByName raises ValueError for an unknown name
      # instead of returning None, so this TypeError branch appears to be
      # unreachable; unknown keyword arguments surface as ValueError.
      if field is None:
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        # Fill a fresh container, then store it directly in _fields.
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              # Message-valued map: merge each value into the entry that
              # indexing the container yields for that key.
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            for val in field_value:
              # Elements may be dicts of constructor kwargs or messages.
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        new_val = field_value
        # A dict is treated as keyword arguments for the sub-message class.
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          # Goes through the field's property setter (type check included).
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  # Strip module/doc metadata from the generated function, as is also done for
  # the generated property accessors in this file.
  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
545
546
def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field descriptor by its .proto field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The FieldDescriptor registered under |field_name|.

  Raises:
    ValueError: if the message type has no field called |field_name|.
  """
  try:
    found = message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
  return found
561
562
def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type.

  Extendable message types additionally get a read-only 'Extensions'
  property.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  def _GetExtensions(self):
    # _ExtensionDict is just an adaptor with no state so we allocate a new one
    # every time it is accessed.
    return _ExtensionDict(self)
  cls.Extensions = property(_GetExtensions)
572
573
def _AddPropertiesForField(field, cls):
  """Adds a public property and a field-number constant for one field.

  Clients can use the property to get and (for non-repeated scalar fields)
  directly set the value of a protocol message field.  The class also gets a
  <FIELD_NAME>_FIELD_NUMBER attribute holding the field number.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + "_FIELD_NUMBER", field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
597
598
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property returns the field's container, which is created on
  first access through field._default_constructor (for example a
  containers.RepeatedScalarFieldContainer or
  containers.RepeatedCompositeFieldContainer) and cached in self._fields.
  Assigning to the property raises AttributeError.

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing
    # Build a new container, then atomically install it unless another
    # thread has preempted us; in that case take theirs and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    # Exists only to raise a more helpful error than the default one.
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
641
642
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  Reading an unset field returns field.default_value.  For proto3 fields
  outside a oneof, assigning a falsy value removes the field from _fields
  instead of storing it.  Fields inside a oneof also update the oneof state
  on every assignment.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  is_proto3 = field.containing_type.syntax == "proto3"

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, False, and empty strings/bytes).
    new_value = type_checker.CheckValue(new_value)
    if clear_when_set_to_default and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
697
698
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field, but cannot
  assign to the property directly.  The sub-message is created on first read
  through field._default_constructor and cached in self._fields.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with the similar getter/setter code
  # in _AddPropertiesForRepeatedField.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    cached = self._fields.get(field)
    if cached is not None:
      return cached
    # Build a new sub-message, then atomically install it unless another
    # thread has preempted us; in that case take theirs and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    # Exists only to raise a more helpful error than the default one.
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
741
742
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <NAME>_FIELD_NUMBER constant for each nested extension.

  Despite its name, this helper creates no properties.  For every extension
  declared inside |descriptor|, |cls| gets an integer class attribute named
  after the upper-cased extension name followed by "_FIELD_NUMBER".

  Args:
    descriptor: Descriptor of the message type being constructed.
    cls: The message class receiving the constants.
  """
  for name, extension in descriptor.extensions_by_name.items():
    setattr(cls, name.upper() + "_FIELD_NUMBER", extension.number)
749
750
def _AddStaticMethods(cls):
  """Attaches the RegisterExtension() and FromString() static methods to cls.

  Args:
    cls: The generated message class being constructed.
  """
  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    """Registers extension_handle as an extension of cls's message type.

    Raises:
      AssertionError: if a different extension already claims the same field
        number on this message type.
    """
    extension_handle.containing_type = cls.DESCRIPTOR
    _AttachFieldHelpers(cls, extension_handle)

    # setdefault() only inserts when the number is still free; it returns
    # whichever handle ends up registered under that number.
    registered = cls._extensions_by_number.setdefault(
        extension_handle.number, extension_handle)
    if registered is not extension_handle:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' %
          (extension_handle.full_name, registered.full_name,
           cls.DESCRIPTOR.full_name, extension_handle.number))

    by_name = cls._extensions_by_name
    by_name[extension_handle.full_name] = extension_handle
    if _IsMessageSetExtension(extension_handle):
      # MessageSet extensions are additionally reachable by the full name of
      # their message type.
      by_name[extension_handle.message_type.full_name] = extension_handle

  def FromString(s):
    """Returns a new cls instance populated by merging the serialized s."""
    message = cls()
    message.MergeFromString(s)
    return message

  cls.RegisterExtension = staticmethod(RegisterExtension)
  cls.FromString = staticmethod(FromString)
783
784
def _IsPresent(item):
  """Returns whether a (FieldDescriptor, value) pair from _fields is "set".

  ListFields() uses this to decide which stored entries to report:
  repeated fields must be non-empty, singular sub-messages must be marked
  present in their parent, and any stored singular scalar counts as present.
  """
  field = item[0]
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
    return True
  return item[1]._is_present_in_parent
795
796
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ListFields()."""

  def ListFields(self):
    """Returns (FieldDescriptor, value) pairs for set fields, by field number."""
    present = filter(_IsPresent, self._fields.items())
    return sorted(present, key=lambda pair: pair[0].number)

  cls.ListFields = ListFields
806
# Error templates used by HasField() when asked about a name that does not
# support presence checks; "%s" receives the requested name.  proto3 only
# tracks presence for submessages (and oneof members), hence its narrower
# wording.
_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
_Proto2HasError = 'Protocol message has no non-repeated field "%s"'
809
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.HasField().

  The set of names HasField() accepts is computed once, here. In proto2,
  every non-repeated field and every oneof name qualifies. In proto3, only
  non-repeated submessages and members of a oneof qualify.
  """
  is_proto3 = (message_descriptor.syntax == "proto3")
  if is_proto3:
    error_msg = _Proto3HasError
  else:
    error_msg = _Proto2HasError

  def _TracksPresence(field):
    # Repeated fields never support HasField().
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      return False
    if not is_proto3:
      return True
    # For proto3, only submessages and fields inside a oneof have presence.
    return (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE or
            bool(field.containing_oneof))

  hassable_fields = dict((field.name, field)
                         for field in message_descriptor.fields
                         if _TracksPresence(field))
  if not is_proto3:
    # Oneof names are accepted too.  Fields inside oneofs are never repeated
    # (enforced by the compiler).
    hassable_fields.update((oneof.name, oneof)
                           for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    """Returns whether field_name (a field or, in proto2, a oneof) is set.

    Raises:
      ValueError: if field_name does not support presence checks.
    """
    try:
      field = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % field_name)

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is set iff its currently active member is set.
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields
    submessage = self._fields.get(field)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasField = HasField
850
851
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ClearField()."""

  def _ActiveOneofMember(self, oneof_name):
    # Resolves a oneof name to its currently set field, or None when nothing
    # in the oneof is set.  Names that are neither fields nor oneofs raise.
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError('Protocol message %s() has no "%s" field.' %
                       (message_descriptor.name, oneof_name))
    return self._oneofs.get(oneof)

  def ClearField(self, field_name):
    """Clears field_name, which may name either a field or a oneof.

    Clearing a oneof clears whichever of its members is set.  Clearing a
    oneof with no member set returns early without calling _Modified().

    Raises:
      ValueError: if the message has no field or oneof called field_name.
    """
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      field = _ActiveOneofMember(self, field_name)
      if field is None:
        return

    if field in self._fields:
      current = self._fields[field]
      # To match the C++ implementation, map fields must invalidate their
      # iterators when ClearField() happens.
      if hasattr(current, 'InvalidateIterators'):
        current.InvalidateIterators()

      # A cleared sub-message's listener still points at us.  That is fine:
      # the worst it can do is call _Modified() and invalidate our byte size.
      del self._fields[field]

      containing_oneof = field.containing_oneof
      if self._oneofs.get(containing_oneof, None) is field:
        del self._oneofs[containing_oneof]

    # Always call _Modified() -- even if nothing changed, this is a mutating
    # method, so calling it should make the message present in its parent.
    self._Modified()

  cls.ClearField = ClearField
888
889
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.ClearExtension()."""

  def ClearExtension(self, extension_handle):
    """Drops any stored value for extension_handle and marks self modified."""
    _VerifyExtensionHandle(self, extension_handle)
    # Like ClearField(): _Modified() runs even when nothing was stored.
    self._fields.pop(extension_handle, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
900
901
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.HasExtension()."""

  def HasExtension(self, extension_handle):
    """Returns whether the singular extension_handle is set on self.

    Raises:
      KeyError: if extension_handle is a repeated extension.
    """
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return extension_handle in self._fields
    # A stored sub-message only counts once it is marked present.
    stored = self._fields.get(extension_handle)
    return stored is not None and stored._is_present_in_parent

  cls.HasExtension = HasExtension
915
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from the public Any Unpack method, which
  takes the target message as an argument.  _InternalUnpackAny has no target
  message type, so it must find the message type in the default descriptor
  pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None if msg has no type_url or its type cannot
    be found in the default descriptor pool.
  """
  type_url = msg.type_url
  if not type_url:
    return None

  # TODO(haberman): For now we just strip the hostname.  Better logic will be
  # required.
  type_name = type_url.split("/")[-1]

  db = symbol_database.Default()
  try:
    descriptor = db.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # The pool reports unknown types with KeyError.  Treat that as "cannot
    # unpack" so callers such as __eq__ fall back to comparing raw fields
    # instead of crashing.
    descriptor = None

  if descriptor is None:
    return None

  message_class = db.GetPrototype(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
948
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__eq__()."""

  def __eq__(self, other):
    """Compares two messages of the same type field by field.

    For Any messages whose payloads both unpack, the unpacked messages are
    compared instead.  Unknown fields take part in the comparison, but their
    order does not.
    """
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False
    if self is other:
      return True

    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if not self.ListFields() == other.ListFields():
      return False

    # Sort unknown fields because their order shouldn't affect equality test.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
977
978
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): str(msg) renders protobuf text format."""

  def __str__(self):
    """Returns the text-format rendering of self."""
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__
984
985
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): repr(msg) renders protobuf text format."""

  def __repr__(self):
    """Returns the text-format rendering of self (same as str())."""
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__
991
992
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds a text-format __unicode__()."""

  def __unicode__(self):
    """Returns the UTF-8 text-format rendering of self, decoded to text."""
    utf8_text = text_format.MessageToString(self, as_utf8=True)
    return utf8_text.decode('utf-8')

  cls.__unicode__ = __unicode__
999
1000
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    message_mod.EncodeError: if field_type has no registered size function.
  """
  # Only the table lookup is guarded: a KeyError raised by the size function
  # itself is a genuine error and must not be reported as an unknown type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  return fn(field_number, value)
1019
1020
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.ByteSize()."""

  def ByteSize(self):
    """Returns the serialized size of self in bytes.

    The size is recomputed only while the message is marked dirty.  The
    result is cached, and both self and the children's listener are then
    marked clean.
    """
    if self._cached_byte_size_dirty:
      known_size = sum(field_desc._sizer(field_value)
                       for field_desc, field_value in self.ListFields())
      unknown_size = sum(len(raw_tag) + len(raw_value)
                         for raw_tag, raw_value in self._unknown_fields)
      self._cached_byte_size = known_size + unknown_size
      self._cached_byte_size_dirty = False
      self._listener_for_children.dirty = False
    return self._cached_byte_size

  cls.ByteSize = ByteSize
1041
1042
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.SerializeToString()."""

  def SerializeToString(self):
    """Serializes self, refusing to do so if required fields are missing.

    Returns:
      The bytes produced by SerializePartialToString().

    Raises:
      message_mod.EncodeError: if self.IsInitialized() is false.  The error
        text lists the paths returned by FindInitializationErrors().
    """
    # Check if the message has all of its required fields set.
    if not self.IsInitialized():
      raise message_mod.EncodeError(
          'Message %s is missing required fields: %s' % (
          self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
    return self.SerializePartialToString()
  cls.SerializeToString = SerializeToString
1055
1056
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs serialization methods.

  Adds the public SerializePartialToString() and the private
  _InternalSerialize(write_bytes) that it is built on.
  """

  def InternalSerialize(self, write_bytes):
    """Streams the wire encoding of self through the write_bytes callable.

    Set fields go first, in ListFields() order.  Retained unknown fields
    follow as their raw tag and value bytes.
    """
    for field_desc, field_value in self.ListFields():
      field_desc._encoder(write_bytes, field_value)
    for raw_tag, raw_value in self._unknown_fields:
      write_bytes(raw_tag)
      write_bytes(raw_value)

  def SerializePartialToString(self):
    """Returns the serialized bytes of self without checking required fields."""
    buf = BytesIO()
    self._InternalSerialize(buf.write)
    return buf.getvalue()

  cls._InternalSerialize = InternalSerialize
  cls.SerializePartialToString = SerializePartialToString
1073
1074
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs the public MergeFromString() and the private _InternalParse() it
  delegates to.
  """
  def MergeFromString(self, serialized):
    """Parses serialized bytes and merges the decoded fields into self.

    Args:
      serialized: The encoded message bytes.

    Returns:
      len(serialized); kept only for backward compatibility.

    Raises:
      message_mod.DecodeError: on an unexpected end-group tag, a truncated
        buffer, or a struct unpacking error.
    """
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Reading past the end of the buffer shows up as IndexError, or as
      # TypeError when ord() is applied to an empty slice:
      # ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Resolved once per class, so the per-tag loop below only reads closure
  # variables instead of doing module/class attribute lookups.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag
  is_proto3 = message_descriptor.syntax == "proto3"

  def InternalParse(self, buffer, pos, end):
    """Decodes fields from buffer[pos:end] into self.

    Returns:
      The position where parsing stopped.  This is `end` when the whole
      range was consumed.  When SkipField() returns -1 (which MergeFromString
      treats as an end-group tag), it is the position of that tag.
    """
    self._Modified()
    field_dict = self._fields
    unknown_field_list = self._unknown_fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is None:
        # Unrecognized tag: skip over its value.
        value_start_pos = new_pos
        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        # proto2 keeps unknown fields as raw (tag, value) byte pairs; this
        # implementation drops them for proto3.
        if not is_proto3:
          if not unknown_field_list:
            # _unknown_fields may be the immutable empty tuple installed by
            # _Clear(), so swap in a fresh list before appending.
            unknown_field_list = self._unknown_fields = []
          unknown_field_list.append(
              (tag_bytes, buffer[value_start_pos:new_pos]))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        # NOTE(review): field_desc is presumably only non-None for oneof
        # members (depends on how _decoders_by_tag is built) -- confirm.
        if field_desc:
          self._UpdateOneofState(field_desc)
    return pos
  cls._InternalParse = InternalParse
1121
1122
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  required_fields = [field for field in message_descriptor.fields
                           if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def _NotInitialized(self, errors):
    """Shared failure path of IsInitialized().

    Fills errors with the detailed paths when a list was supplied, and always
    returns False.
    """
    if errors is not None:
      errors.extend(self.FindInitializationErrors())
    return False

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Checks this message's required fields first.  It then recurses into every
    present sub-message, every element of repeated sub-message fields, and
    every value of message-valued map fields.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        return _NotInitialized(self, errors)

    for field, value in list(self._fields.items()):  # dict can change size!
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        if (field.message_type.has_options and
            field.message_type.GetOptions().map_entry):
          # Map field.  Scalar-valued maps cannot be uninitialized.  The
          # values of a message-valued map must be checked, matching what
          # FindInitializationErrors() reports.
          if _IsMessageMapField(field):
            for key in value:
              if not value[key].IsInitialized():
                return _NotInitialized(self, errors)
          continue
        for element in value:
          if not element.IsInitialized():
            return _NotInitialized(self, errors)
      elif value._is_present_in_parent and not value.IsInitialized():
        return _NotInitialized(self, errors)

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    errors = []  # simplify things

    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.is_extension:
        name = "(%s)" % field.full_name
      else:
        name = field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors.
        if _IsMessageMapField(field):
          for key in value:
            element = value[key]
            prefix = "%s[%s]." % (name, key)
            sub_errors = element.FindInitializationErrors()
            errors += [prefix + error for error in sub_errors]
      elif field.label == _FieldDescriptor.LABEL_REPEATED:
        for i, element in enumerate(value):
          prefix = "%s[%d]." % (name, i)
          sub_errors = element.FindInitializationErrors()
          errors += [prefix + error for error in sub_errors]
      else:
        prefix = name + "."
        sub_errors = value.FindInitializationErrors()
        errors += [prefix + error for error in sub_errors]

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
1216
1217
def _AddMergeFromMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.MergeFrom()."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def _GetOrCreate(self, field):
    # Returns self's container/sub-message for field, creating and storing a
    # fresh one via the field's default constructor when it is absent.
    field_value = self._fields.get(field)
    if field_value is None:
      field_value = field._default_constructor(self)
      self._fields[field] = field_value
    return field_value

  def MergeFrom(self, msg):
    """Merges the contents of msg (an instance of cls) into self.

    Repeated fields are appended, sub-messages are merged recursively, and
    singular scalars are overwritten.  Unknown fields of msg are appended.

    Raises:
      TypeError: if msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          "Parameter to MergeFrom() must be instance of same class: "
          "expected %s got %s." % (cls.__name__, type(msg).__name__))

    assert msg is not self
    self._Modified()

    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        _GetOrCreate(self, field).MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        # Sub-messages never set in msg contribute nothing.
        if value._is_present_in_parent:
          _GetOrCreate(self, field).MergeFrom(value)
      else:
        self._fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)

    if msg._unknown_fields:
      if not self._unknown_fields:
        # May be an immutable empty tuple; replace it with a list.
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
1260
1261
def _AddWhichOneofMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.WhichOneof()."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    active = self._oneofs.get(oneof, None)
    # The recorded member must also still report itself as set.
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
1278
1279
def _Clear(self):
  """Resets self to an empty message and marks it modified.

  Known fields and oneof state become empty dicts.  Unknown fields become
  the empty tuple.
  """
  self._fields, self._unknown_fields, self._oneofs = {}, (), {}
  self._Modified()
1286
1287
def _DiscardUnknownFields(self):
  """Drops unknown fields from self and, recursively, from all sub-messages.

  Map fields are repeated message fields whose iteration yields keys, so
  they must not be walked like repeated sub-messages.  Message-valued maps
  recurse into their values.  Scalar-valued maps hold no sub-messages and
  are skipped.
  """
  self._unknown_fields = []
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue
    if _IsMapField(field):
      if _IsMessageMapField(field):
        for key in value:
          value[key].DiscardUnknownFields()
    elif field.label == _FieldDescriptor.LABEL_REPEATED:
      for sub_message in value:
        sub_message.DiscardUnknownFields()
    else:
      value.DiscardUnknownFields()
1297
1298
def _SetListener(self, listener):
  """Installs listener on self, substituting a no-op listener for None."""
  if listener is not None:
    self._listener = listener
    return
  self._listener = message_listener_mod.NullMessageListener()
1304
1305
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  The installers run in a fixed order.  The extension accessors
  (ClearExtension/HasExtension) are only added for extendable messages.
  """
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  if message_descriptor.is_extendable:
    for add_extension_method in (_AddClearExtensionMethod,
                                 _AddHasExtensionMethod):
      add_extension_method(cls)

  for add_method in (_AddEqualsMethod,
                     _AddStrMethod,
                     _AddReprMethod,
                     _AddUnicodeMethod,
                     _AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # These are shared module-level implementations; they do not depend on cls.
  cls.Clear = _Clear
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
1329
1330
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls.

  Installs _Modified() (also exposed as SetInParent()) and
  _UpdateOneofState().
  """

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.

    The first transition to dirty also marks the children's listener dirty
    and marks self as present in its parent.
    """

    # Note:  Some callers check _cached_byte_size_dirty before calling
    #   _Modified() as an extra optimization.  So, if this method is ever
    #   changed such that it does stuff even when _cached_byte_size_dirty is
    #   already true, the callers need to be updated.
    if not self._cached_byte_size_dirty:
      self._cached_byte_size_dirty = True
      self._listener_for_children.dirty = True
      self._is_present_in_parent = True
      self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete the previously active field of the oneof from
    self._fields, if it is different from the argument. Does not mark the
    message as modified.
    """
    other_field = self._oneofs.setdefault(field.containing_oneof, field)
    if other_field is not field:
      # pop() rather than del: the previously active field is not guaranteed
      # to still be in _fields.  A _OneofListener of a sub-message that was
      # already swapped out can re-activate that field without re-inserting
      # it.  Switching to the new member must not fail with KeyError then.
      self._fields.pop(other_field, None)
      self._oneofs[field.containing_oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
1363
1364
class _Listener(object):

  """MessageListener that a parent message registers with each child message.

  Children need back references to their parents to support semantics like:

    foo.bar.baz.qux = 23
    assert foo.HasField('bar')

  This helper class provides that back reference.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # The child -> parent back reference is held through a weak proxy, so it
    # does not create cyclic garbage once the client drops the parent.  A
    # parent that is already a proxy is stored as-is.
    if not isinstance(parent_message, weakref.ProxyType):
      parent_message = weakref.proxy(parent_message)
    self._parent_message_weakref = parent_message

    # Mirror of whether the parent is already dirty, so Modified() can
    # usually return without walking up the tree.
    self.dirty = False

  def Modified(self):
    """Forwards a modification to the parent unless it is already dirty."""
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # The client kept a reference to a child whose parent has since been
        # garbage-collected; modifying such a child is not an error.
        pass
1409
1410
class _OneofListener(_Listener):
  """Listener for composite fields that live inside a oneof."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super(_OneofListener, self).__init__(parent_message)
    self._field = field

  def Modified(self):
    """Marks self._field as the active oneof member, then notifies as usual.

    The oneof state is updated even when the parent is already dirty; only
    the propagation step is subject to the base class's dirty short-cut.
    """
    parent = self._parent_message_weakref
    try:
      parent._UpdateOneofState(self._field)
      super(_OneofListener, self).Modified()
    except ReferenceError:
      # Parent already garbage-collected; nothing left to update.
      pass
1430
1431
1432# TODO(robinson): Move elsewhere?  This file is getting pretty ridiculous...
1433# TODO(robinson): Unify error handling of "unknown extension" crap.
1434# TODO(robinson): Support iteritems()-style iteration over all
1435# extensions with the "has" bits turned on?
1436class _ExtensionDict(object):
1437
1438  """Dict-like container for supporting an indexable "Extensions"
1439  field on proto instances.
1440
1441  Note that in all cases we expect extension handles to be
1442  FieldDescriptors.
1443  """
1444
  def __init__(self, extended_message):
    """Creates the Extensions view of extended_message.

    Args:
      extended_message: Message instance for which we are the Extensions dict.
    """

    # The accessors read and write extension values through this message's
    # _fields.
    self._extended_message = extended_message
1450
1451  def __getitem__(self, extension_handle):
1452    """Returns the current value of the given extension handle."""
1453
1454    _VerifyExtensionHandle(self._extended_message, extension_handle)
1455
1456    result = self._extended_message._fields.get(extension_handle)
1457    if result is not None:
1458      return result
1459
1460    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1461      result = extension_handle._default_constructor(self._extended_message)
1462    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1463      result = extension_handle.message_type._concrete_class()
1464      try:
1465        result._SetListener(self._extended_message._listener_for_children)
1466      except ReferenceError:
1467        pass
1468    else:
1469      # Singular scalar -- just return the default without inserting into the
1470      # dict.
1471      return extension_handle.default_value
1472
1473    # Atomically check if another thread has preempted us and, if not, swap
1474    # in the new object we just created.  If someone has preempted us, we
1475    # take that object and discard ours.
1476    # WARNING:  We are relying on setdefault() being atomic.  This is true
1477    #   in CPython but we haven't investigated others.  This warning appears
1478    #   in several other locations in this file.
1479    result = self._extended_message._fields.setdefault(
1480        extension_handle, result)
1481
1482    return result
1483
1484  def __eq__(self, other):
1485    if not isinstance(other, self.__class__):
1486      return False
1487
1488    my_fields = self._extended_message.ListFields()
1489    other_fields = other._extended_message.ListFields()
1490
1491    # Get rid of non-extension fields.
1492    my_fields    = [ field for field in my_fields    if field.is_extension ]
1493    other_fields = [ field for field in other_fields if field.is_extension ]
1494
1495    return my_fields == other_fields
1496
  def __ne__(self, other):
    """Inverse of __eq__.

    Python 2 does not derive != from __eq__, so it is spelled out here.
    """
    return not self == other
1499
  def __hash__(self):
    """Always raises: Extensions dicts are unhashable.

    They define value-based __eq__ and are mutable through __setitem__, so a
    hash would not stay consistent with equality.
    """
    raise TypeError('unhashable object')
1502
  # Note that this is only meaningful for non-repeated, scalar extension
  # fields.  Note also that we may have to call _Modified() when we do
  # successfully set a field this way, to set any necessary "has" bits in the
  # ancestors of the extended message.
1507  def __setitem__(self, extension_handle, value):
1508    """If extension_handle specifies a non-repeated, scalar extension
1509    field, sets the value of that field.
1510    """
1511
1512    _VerifyExtensionHandle(self._extended_message, extension_handle)
1513
1514    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
1515        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
1516      raise TypeError(
1517          'Cannot assign to extension "%s" because it is a repeated or '
1518          'composite type.' % extension_handle.full_name)
1519
1520    # It's slightly wasteful to lookup the type checker each time,
1521    # but we expect this to be a vanishingly uncommon case anyway.
1522    type_checker = type_checkers.GetTypeChecker(extension_handle)
1523    # pylint: disable=protected-access
1524    self._extended_message._fields[extension_handle] = (
1525        type_checker.CheckValue(value))
1526    self._extended_message._Modified()
1527
1528  def _FindExtensionByName(self, name):
1529    """Tries to find a known extension with the specified name.
1530
1531    Args:
1532      name: Extension full name.
1533
1534    Returns:
1535      Extension field descriptor.
1536    """
1537    return self._extended_message._extensions_by_name.get(name, None)
1538
1539  def _FindExtensionByNumber(self, number):
1540    """Tries to find a known extension with the field number.
1541
1542    Args:
1543      number: Extension field number.
1544
1545    Returns:
1546      Extension field descriptor.
1547    """
1548    return self._extended_message._extensions_by_number.get(number, None)
1549