1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31# This code is meant to work on Python 2.4 and above only.
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
36"""Contains a metaclass and helper functions used to create
37protocol message classes from Descriptor objects at runtime.
38
39Recall that a metaclass is the "type" of a class.
40(A class is to a metaclass what an instance is to a class.)
41
42In this case, we use the GeneratedProtocolMessageType metaclass
43to inject all the useful functionality into the classes
44output by the protocol compiler at compile-time.
45
46The upshot of all this is that the real implementation
47details for ALL pure-Python protocol buffers are *here in
48this file*.
49"""
50
51__author__ = 'robinson@google.com (Will Robinson)'
52
53from io import BytesIO
54import sys
55import struct
56import weakref
57
58import six
59try:
60  import six.moves.copyreg as copyreg
61except ImportError:
62  # On some platforms, for example gMac, we run native Python because there is
63  # nothing like hermetic Python. This means lesser control on the system and
64  # the six.moves package may be missing (is missing on 20150321 on gMac). Be
65  # extra conservative and try to load the old replacement if it fails.
66  import copy_reg as copyreg
67
68# We use "as" to avoid name collisions with variables.
69from google.protobuf.internal import containers
70from google.protobuf.internal import decoder
71from google.protobuf.internal import encoder
72from google.protobuf.internal import enum_type_wrapper
73from google.protobuf.internal import message_listener as message_listener_mod
74from google.protobuf.internal import type_checkers
75from google.protobuf.internal import well_known_types
76from google.protobuf.internal import wire_format
77from google.protobuf import descriptor as descriptor_mod
78from google.protobuf import message as message_mod
79from google.protobuf import text_format
80
# Short alias for the FieldDescriptor class; the helpers below use it to reach
# its LABEL_*, TYPE_* and CPPTYPE_* constants.
_FieldDescriptor = descriptor_mod.FieldDescriptor
# Fully-qualified proto name of the Any message type.
_AnyFullTypeName = 'google.protobuf.Any'
83
84
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class.  Not inspected here; it is forwarded to the
        superclass __new__ as required by the metaclass protocol.
      bases: Base classes of the class we're constructing.
        (Should be message.Message).  When the descriptor's full name is a
        key of well_known_types.WKTBASES, the matching class is appended
        before the tuple is forwarded to the superclass __new__.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.  It is modified in place: nested extensions and a
        '__slots__' entry are added.

    Returns:
      Newly-allocated class.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    # Mix in the extra base class registered for this full name, if any.
    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    # Both helpers mutate `dictionary`, so they must run before the class
    # object is created from it.
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.
    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    Args:
      name: Name of the class.  Not inspected here; it is forwarded to the
        superclass __init__ as required by the metaclass protocol.
      bases: Base classes of the class we're constructing.
        (Should be message.Message).  Not inspected here; forwarded to the
        superclass __init__.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    # Per-class registries, filled in by _AttachFieldHelpers and by the
    # RegisterExtension static method.
    cls._decoders_by_tag = {}
    cls._extensions_by_name = {}
    cls._extensions_by_number = {}
    # Message-set types get one extra decoder keyed by the message-set item
    # tag; it is handed the (still empty) extensions-by-number registry.
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(cls._extensions_by_number), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    # Record the generated class on its descriptor; default-value constructors
    # for submessage fields read message_type._concrete_class lazily.
    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)
    # Register a reducer so instances pickle as (cls, (), obj.__getstate__()).
    copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)
185
186
187# Stateless helpers for GeneratedProtocolMessageType below.
188# Outside clients should not access these directly.
189#
190# I opted not to make any of these methods on the metaclass, to make it more
191# clear that I'm not really using any state there and to keep clients from
192# thinking that they have direct access to these construction helpers.
193
194
def _PropertyName(proto_field_name):
  """Maps a .proto field name to the Python attribute that exposes it.

  The mapping is currently the identity: the public property carries exactly
  the name the field has in the .proto file.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    The attribute name clients use to get (and in some cases set) the field.
  """
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz pointed at the stdlib keyword module, suggesting roughly:
  #
  #   if keyword.iskeyword(proto_field_name):
  #     return proto_field_name + "_"
  #   return proto_field_name
  #
  # Kenton says:  The above is a BAD IDEA.  People rely on getattr() and
  #   setattr() to manipulate field values reflectively; renaming properties
  #   would force every such user to apply the same transformation.  A field
  #   named "yield" is already reachable through getattr/setattr without much
  #   trouble.
  # TODO(kenton):  Remove this method entirely if/when everyone agrees with
  #   that position.
  return proto_field_name
222
223
def _VerifyExtensionHandle(message, extension_handle):
  """Verify that the given extension handle is valid for the given message.

  Args:
    message: Message instance; its DESCRIPTOR must be the very object the
      extension names as its containing type.
    extension_handle: The object to validate; expected to be an extension
      FieldDescriptor.

  Raises:
    KeyError: If extension_handle is not a FieldDescriptor, is not an
      extension, has no containing_type, or extends a message type other
      than message's.
  """

  if not isinstance(extension_handle, _FieldDescriptor):
    # The operand is wrapped in a 1-tuple on purpose: a bare tuple handle would
    # be unpacked by the % operator and raise TypeError instead of KeyError.
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   (extension_handle,))

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  if not extension_handle.containing_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % extension_handle.full_name)

  # Identity, not equality: the handle must point at this exact descriptor.
  if extension_handle.containing_type is not message.DESCRIPTOR:
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' %
                   (extension_handle.full_name,
                    extension_handle.containing_type.full_name,
                    message.DESCRIPTOR.full_name))
244
245
def _AddSlots(message_descriptor, dictionary):
  """Installs the '__slots__' entry for a message class being built.

  The slot list is fixed; it does not depend on the descriptor.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
      Currently unused.
    dictionary: Class dictionary that receives the '__slots__' entry.
  """
  slot_names = [
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  ]
  dictionary['__slots__'] = slot_names
263
264
def _IsMessageSetExtension(field):
  """Tells whether field is an optional message extension of a message set.

  Args:
    field: FieldDescriptor to inspect.

  Returns:
    A truthy value only when the field is an extension, its containing type
    has the message_set_wire_format option set, and the field itself is an
    optional field of message type.
  """
  extends_message_set = (
      field.is_extension and
      field.containing_type.has_options and
      field.containing_type.GetOptions().message_set_wire_format)
  return (extends_message_set and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)
271
272
def _IsMapField(field):
  """Tells whether field is a map field.

  Args:
    field: FieldDescriptor to inspect.

  Returns:
    A truthy value only when the field is of message type and that message
    type carries the map_entry option.
  """
  is_message = field.type == _FieldDescriptor.TYPE_MESSAGE
  if not is_message:
    return is_message
  entry_type = field.message_type
  return entry_type.has_options and entry_type.GetOptions().map_entry
277
278
def _IsMessageMapField(field):
  """Tells whether the values of a map field are messages.

  Args:
    field: FieldDescriptor whose message_type has a field named "value".

  Returns:
    True if that "value" field has cpp_type CPPTYPE_MESSAGE.
  """
  entry_fields = field.message_type.fields_by_name
  return entry_fields["value"].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
282
283
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoder, sizer, default constructor and decoders for one field.

  The encoder, sizer and default-value constructor are stored on the
  descriptor itself (_encoder, _sizer, _default_constructor); the decoders
  are registered in cls._decoders_by_tag.

  Args:
    cls: The message class under construction; must already have a
      _decoders_by_tag dict.
    field_descriptor: FieldDescriptor of the field (or extension) to set up.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))
  # Decide whether the field is *encoded* packed.
  if not is_packable:
    is_packed = False
  elif field_descriptor.containing_type.syntax == "proto2":
    # proto2: packed only when the option is explicitly set.
    is_packed = (field_descriptor.has_options and
                field_descriptor.GetOptions().packed)
  else:
    # Any other syntax: packed unless the option is explicitly set to false.
    has_packed_false = (field_descriptor.has_options and
                        field_descriptor.GetOptions().HasField("packed") and
                        field_descriptor.GetOptions().packed == False)
    is_packed = not has_packed_false
  is_map_entry = _IsMapField(field_descriptor)

  # Maps and message-set extensions have dedicated encoders/sizers; every
  # other field uses the per-type factory tables in type_checkers.
  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor)
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    """Registers a decoder for this field under the tag for `wiretype`.

    Note that the `is_packed` parameter shadows the outer variable of the
    same name; it describes the wire form this decoder accepts, not how the
    field is encoded.
    """
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    decode_type = field_descriptor.type
    # Enum fields that support open enums are decoded with the int32 decoder.
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      decode_type = _FieldDescriptor.TYPE_INT32

    # Despite the name, this holds the field's own descriptor when the field
    # belongs to a oneof, and None otherwise.
    oneof_descriptor = None
    if field_descriptor.containing_oneof is not None:
      oneof_descriptor = field_descriptor

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
              field_descriptor.number, is_repeated, is_packed,
              field_descriptor, field_descriptor._default_constructor)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  # Decoder for the field type's own wire type, in non-packed form.
  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
348
349
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Copies the message's nested extensions into the class dictionary.

  Args:
    descriptor: Descriptor whose extensions_by_name mapping is read.
    dictionary: Class dictionary that gains one entry per extension, keyed
      by the extension name.  A name that is already present trips an
      assertion.
  """
  nested_extensions = descriptor.extensions_by_name
  for name, extension in nested_extensions.items():
    assert name not in dictionary
    dictionary[name] = extension
355
356
def _AddEnumValues(descriptor, cls):
  """Publishes the message's nested enums as class-level attributes.

  For every enum type an EnumTypeWrapper is stored under the enum's name, and
  every enum value is stored under its own name with its number as the value.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_descriptor in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_descriptor)
    setattr(cls, enum_descriptor.name, wrapper)
    for value_descriptor in enum_descriptor.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
370
371
def _GetInitializeDefaultForMap(field):
  """Builds the default-value constructor for a map field.

  Args:
    field: FieldDescriptor of the map field; its message_type must have
      'key' and 'value' fields.

  Returns:
    A one-argument function that takes the owning message and returns a new
    containers.MessageMap (message values) or containers.ScalarMap (all
    other values) wired to message._listener_for_children.

  Raises:
    ValueError: If the field is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))

  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)
    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker)
  return MakeMessageMapDefault
391
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function taking one argument:
      message: Message instance containing this field, or a weakref proxy
        of same.
    That function in turn returns a default value for this field.  The
    default value may refer back to |message| via a weak reference.

  Raises:
    ValueError: If a repeated field declares a default value other than the
      empty list, or if a map field is not repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      # The operand is a real 1-tuple so that a tuple-valued default is
      # reported in the message instead of breaking the % formatting.
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value,))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # field.message_type is read when the default is built, not here.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized, so it is looked up when the
    # default is actually built.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      # Members of a oneof get a dedicated listener; all other submessages
      # share the parent's listener for children.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
445
446
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception being handled, naming the field if possible.

  Must be called while an exception is being handled.  A plain TypeError with
  exactly one argument is replaced by a TypeError whose message also names
  the field; anything else is re-raised unchanged.  The original traceback is
  kept in both cases.

  Args:
    message_name: Name of the message type, used in the amended text.
    field_name: Name of the field, used in the amended text.
  """
  _, error, trace = sys.exc_info()
  if len(error.args) == 1 and type(error) is TypeError:
    # Exact-type check: only a simple TypeError gets the field name appended.
    amended_text = '%s for field %s.%s' % (str(error), message_name, field_name)
    error = TypeError(amended_text)

  six.reraise(type(error), error, trace)
456
457
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts only keyword arguments, one per field name,
  and uses them to pre-populate the new message.

  Args:
    message_descriptor: Descriptor of the message type; used to look up
      fields by name and to build error messages.
    cls: The message class that receives the __init__ method.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)

    Raises:
      ValueError: If value is a string that names no value of enum_type.
    """
    if isinstance(value, six.string_types):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    # No docstring on purpose: init.__doc__ is reset to None below.
    self._cached_byte_size = 0
    # Any keyword argument marks the cached size dirty, even one whose value
    # is None and is therefore skipped below.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE(review): _GetFieldByName as defined in this file raises
      # ValueError for an unknown name instead of returning None, so this
      # branch looks unreachable -- confirm before relying on the TypeError.
      if field is None:
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        # Fill a fresh container and store it directly in _fields.
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              # Message-valued map: merge each supplied value into the entry
              # obtained by indexing the new map with its key.
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            # Repeated message: a dict element becomes keyword arguments for
            # the new element; anything else is merged into a new element.
            for val in field_value:
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            # Enum values may be given by name; convert them to numbers.
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        # Singular submessage: merge the supplied value (or a message built
        # from the supplied dict) into a freshly constructed default.
        copy = field._default_constructor(self)
        new_val = field_value
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        # Singular scalar: go through the attribute of the same name so the
        # value gets the same handling as a normal assignment.
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
540
541
def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field descriptor by its name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor associated with the field name.

  Raises:
    ValueError: If the message has no field with that name.
  """
  known_fields = message_descriptor.fields_by_name
  try:
    found = known_fields[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
  return found
556
557
def _AddPropertiesForFields(descriptor, cls):
  """Adds a property per field, plus `Extensions` for extendable messages.

  Args:
    descriptor: Descriptor of the message type.
    cls: The class we're constructing.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  def _GetExtensions(self):
    # _ExtensionDict is just an adaptor with no state so we allocate a new one
    # every time it is accessed.
    return _ExtensionDict(self)

  cls.Extensions = property(_GetExtensions)
567
568
def _AddPropertiesForField(field, cls):
  """Adds the public property and the field-number constant for one field.

  The class gains a <FIELD_NAME>_FIELD_NUMBER constant and a property named
  after the field.  Which kind of property is installed depends on whether
  the field is repeated, a singular message, or a singular scalar.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + "_FIELD_NUMBER", field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
592
593
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property returns the field's container, creating it with the
  field's default constructor and caching it in the message's _fields on
  first access.  Assigning to the property always raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container

    # Nothing stored yet: build a new container for this field.
    fresh = field._default_constructor(self)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only so that assignment fails with a helpful message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  setattr(
      cls,
      _PropertyName(proto_field_name),
      property(
          getter, setter,
          doc='Magic attribute generated for "%s" proto field.' %
          proto_field_name))
636
637
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  is_proto3 = field.containing_type.syntax == "proto3"

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # In proto3, a field outside any oneof that is set to a falsy value is
  # removed from _fields instead of being stored.
  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    # CheckValue's return value, not the raw argument, is what gets stored.
    new_value = type_checker.CheckValue(new_value)
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False).
    if clear_when_set_to_default and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    # Oneof members must also update the oneof bookkeeping after every set.
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
692
693
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Reading the property returns the submessage, creating it with the field's
  default constructor and caching it in the message's _fields on first
  access.  Assigning to the property always raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name

  def getter(self):
    submessage = self._fields.get(field)
    if submessage is not None:
      return submessage

    # Nothing stored yet: build a new object for this field.
    fresh = field._default_constructor(self)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only so that assignment fails with a helpful message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  setattr(
      cls,
      _PropertyName(proto_field_name),
      property(
          getter, setter,
          doc='Magic attribute generated for "%s" proto field.' %
          proto_field_name))
736
737
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <NAME>_FIELD_NUMBER class constant for every nested extension.

  Args:
    descriptor: Descriptor whose extensions_by_name mapping is read.
    cls: The class we're constructing; it receives one constant per
      extension, named after the upper-cased extension name.
  """
  for name, extension in descriptor.extensions_by_name.items():
    setattr(cls, name.upper() + "_FIELD_NUMBER", extension.number)
744
745
def _AddStaticMethods(cls):
  """Adds the RegisterExtension and FromString static methods to cls.

  Args:
    cls: The message class we're constructing.
  """
  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    extension_handle.containing_type = cls.DESCRIPTOR
    _AttachFieldHelpers(cls, extension_handle)

    # Claim the field number; setdefault hands back whichever handle owns it,
    # so a different handle coming back means the number is already taken.
    number = extension_handle.number
    owner = cls._extensions_by_number.setdefault(number, extension_handle)
    if owner is not extension_handle:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' %
          (extension_handle.full_name, owner.full_name,
           cls.DESCRIPTOR.full_name, number))

    by_name = cls._extensions_by_name
    by_name[extension_handle.full_name] = extension_handle

    if _IsMessageSetExtension(extension_handle):
      # MessageSet extension.  Also register under type name.
      by_name[extension_handle.message_type.full_name] = extension_handle

  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed

  cls.FromString = staticmethod(FromString)
778
779
def _IsPresent(item):
  """Decides whether a _fields entry belongs in the ListFields() result.

  Args:
    item: A (FieldDescriptor, value) tuple taken from _fields.

  Returns:
    For repeated fields, whether the value is non-empty; for singular
    message fields, the value's _is_present_in_parent flag; True otherwise.
  """
  field = item[0]
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  return True
790
791
792def _AddListFieldsMethod(message_descriptor, cls):
793  """Helper for _AddMessageMethods()."""
794
795  def ListFields(self):
796    all_fields = [item for item in self._fields.items() if _IsPresent(item)]
797    all_fields.sort(key = lambda item: item[0].number)
798    return all_fields
799
800  cls.ListFields = ListFields
801
802_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
803_Proto2HasError = 'Protocol message has no non-repeated field "%s"'
804
805def _AddHasFieldMethod(message_descriptor, cls):
806  """Helper for _AddMessageMethods()."""
807
808  is_proto3 = (message_descriptor.syntax == "proto3")
809  error_msg = _Proto3HasError if is_proto3 else _Proto2HasError
810
811  hassable_fields = {}
812  for field in message_descriptor.fields:
813    if field.label == _FieldDescriptor.LABEL_REPEATED:
814      continue
815    # For proto3, only submessages and fields inside a oneof have presence.
816    if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
817        not field.containing_oneof):
818      continue
819    hassable_fields[field.name] = field
820
821  if not is_proto3:
822    # Fields inside oneofs are never repeated (enforced by the compiler).
823    for oneof in message_descriptor.oneofs:
824      hassable_fields[oneof.name] = oneof
825
826  def HasField(self, field_name):
827    try:
828      field = hassable_fields[field_name]
829    except KeyError:
830      raise ValueError(error_msg % field_name)
831
832    if isinstance(field, descriptor_mod.OneofDescriptor):
833      try:
834        return HasField(self, self._oneofs[field].name)
835      except KeyError:
836        return False
837    else:
838      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
839        value = self._fields.get(field)
840        return value is not None and value._is_present_in_parent
841      else:
842        return field in self._fields
843
844  cls.HasField = HasField
845
846
847def _AddClearFieldMethod(message_descriptor, cls):
848  """Helper for _AddMessageMethods()."""
849  def ClearField(self, field_name):
850    try:
851      field = message_descriptor.fields_by_name[field_name]
852    except KeyError:
853      try:
854        field = message_descriptor.oneofs_by_name[field_name]
855        if field in self._oneofs:
856          field = self._oneofs[field]
857        else:
858          return
859      except KeyError:
860        raise ValueError('Protocol message %s() has no "%s" field.' %
861                         (message_descriptor.name, field_name))
862
863    if field in self._fields:
864      # To match the C++ implementation, we need to invalidate iterators
865      # for map fields when ClearField() happens.
866      if hasattr(self._fields[field], 'InvalidateIterators'):
867        self._fields[field].InvalidateIterators()
868
869      # Note:  If the field is a sub-message, its listener will still point
870      #   at us.  That's fine, because the worst than can happen is that it
871      #   will call _Modified() and invalidate our byte size.  Big deal.
872      del self._fields[field]
873
874      if self._oneofs.get(field.containing_oneof, None) is field:
875        del self._oneofs[field.containing_oneof]
876
877    # Always call _Modified() -- even if nothing was changed, this is
878    # a mutating method, and thus calling it should cause the field to become
879    # present in the parent message.
880    self._Modified()
881
882  cls.ClearField = ClearField
883
884
885def _AddClearExtensionMethod(cls):
886  """Helper for _AddMessageMethods()."""
887  def ClearExtension(self, extension_handle):
888    _VerifyExtensionHandle(self, extension_handle)
889
890    # Similar to ClearField(), above.
891    if extension_handle in self._fields:
892      del self._fields[extension_handle]
893    self._Modified()
894  cls.ClearExtension = ClearExtension
895
896
897def _AddHasExtensionMethod(cls):
898  """Helper for _AddMessageMethods()."""
899  def HasExtension(self, extension_handle):
900    _VerifyExtensionHandle(self, extension_handle)
901    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
902      raise KeyError('"%s" is repeated.' % extension_handle.full_name)
903
904    if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
905      value = self._fields.get(extension_handle)
906      return value is not None and value._is_present_in_parent
907    else:
908      return extension_handle in self._fields
909  cls.HasExtension = HasExtension
910
def _InternalUnpackAny(msg):
  """Unpacks an Any message and returns the message it contains.

  This internal helper is different from the public Any Unpack method, which
  takes the target message as an argument.  Here no target message type is
  given, so the type has to be looked up in the descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None if msg has no type_url or the pool lookup
    yields None.
  """
  # TODO(amauryfa): Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url
  if not type_url:
    return None

  # TODO(haberman): For now we just strip the hostname.  Better logic will be
  # required.
  type_name = type_url.rsplit('/', 1)[-1]
  descriptor = factory.pool.FindMessageTypeByName(type_name)
  if descriptor is None:
    return None

  unpacked = factory.GetPrototype(descriptor)()
  unpacked.ParseFromString(msg.value)
  return unpacked
949
950
951def _AddEqualsMethod(message_descriptor, cls):
952  """Helper for _AddMessageMethods()."""
953  def __eq__(self, other):
954    if (not isinstance(other, message_mod.Message) or
955        other.DESCRIPTOR != self.DESCRIPTOR):
956      return False
957
958    if self is other:
959      return True
960
961    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
962      any_a = _InternalUnpackAny(self)
963      any_b = _InternalUnpackAny(other)
964      if any_a and any_b:
965        return any_a == any_b
966
967    if not self.ListFields() == other.ListFields():
968      return False
969
970    # Sort unknown fields because their order shouldn't affect equality test.
971    unknown_fields = list(self._unknown_fields)
972    unknown_fields.sort()
973    other_unknown_fields = list(other._unknown_fields)
974    other_unknown_fields.sort()
975
976    return unknown_fields == other_unknown_fields
977
978  cls.__eq__ = __eq__
979
980
981def _AddStrMethod(message_descriptor, cls):
982  """Helper for _AddMessageMethods()."""
983  def __str__(self):
984    return text_format.MessageToString(self)
985  cls.__str__ = __str__
986
987
988def _AddReprMethod(message_descriptor, cls):
989  """Helper for _AddMessageMethods()."""
990  def __repr__(self):
991    return text_format.MessageToString(self)
992  cls.__repr__ = __repr__
993
994
995def _AddUnicodeMethod(unused_message_descriptor, cls):
996  """Helper for _AddMessageMethods()."""
997
998  def __unicode__(self):
999    return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
1000  cls.__unicode__ = __unicode__
1001
1002
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Computes the serialized size, in bytes, of one non-repeated element.

  The count covers the tag information as well as any other space that
  serializing value requires.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (The field number is stored
      inside a varint-encoded tag, so it influences the total size.)
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    message_mod.EncodeError: if a KeyError is raised while looking up or
      running the size function for field_type.
  """
  try:
    return type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type](field_number, value)
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
1021
1022
1023def _AddByteSizeMethod(message_descriptor, cls):
1024  """Helper for _AddMessageMethods()."""
1025
1026  def ByteSize(self):
1027    if not self._cached_byte_size_dirty:
1028      return self._cached_byte_size
1029
1030    size = 0
1031    for field_descriptor, field_value in self.ListFields():
1032      size += field_descriptor._sizer(field_value)
1033
1034    for tag_bytes, value_bytes in self._unknown_fields:
1035      size += len(tag_bytes) + len(value_bytes)
1036
1037    self._cached_byte_size = size
1038    self._cached_byte_size_dirty = False
1039    self._listener_for_children.dirty = False
1040    return size
1041
1042  cls.ByteSize = ByteSize
1043
1044
1045def _AddSerializeToStringMethod(message_descriptor, cls):
1046  """Helper for _AddMessageMethods()."""
1047
1048  def SerializeToString(self):
1049    # Check if the message has all of its required fields set.
1050    errors = []
1051    if not self.IsInitialized():
1052      raise message_mod.EncodeError(
1053          'Message %s is missing required fields: %s' % (
1054          self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
1055    return self.SerializePartialToString()
1056  cls.SerializeToString = SerializeToString
1057
1058
1059def _AddSerializePartialToStringMethod(message_descriptor, cls):
1060  """Helper for _AddMessageMethods()."""
1061
1062  def SerializePartialToString(self):
1063    out = BytesIO()
1064    self._InternalSerialize(out.write)
1065    return out.getvalue()
1066  cls.SerializePartialToString = SerializePartialToString
1067
1068  def InternalSerialize(self, write_bytes):
1069    for field_descriptor, field_value in self.ListFields():
1070      field_descriptor._encoder(write_bytes, field_value)
1071    for tag_bytes, value_bytes in self._unknown_fields:
1072      write_bytes(tag_bytes)
1073      write_bytes(value_bytes)
1074  cls._InternalSerialize = InternalSerialize
1075
1076
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and _InternalParse() on cls.
  """
  def MergeFromString(self, serialized):
    """Parses serialized and merges the result into this message.

    Args:
      serialized: The serialized data to parse; it is handed to
        _InternalParse() together with its length.

    Returns:
      len(serialized).

    Raises:
      message_mod.DecodeError: if parsing stops before the end of the input,
        or if IndexError, TypeError or struct.error is raised while parsing.
    """
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bind everything the parse loop needs to closure locals once, up front,
  # so the loop body does not repeat module/attribute lookups.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag
  is_proto3 = message_descriptor.syntax == "proto3"

  def InternalParse(self, buffer, pos, end):
    """Parses buffer[pos:end] into this message.

    Returns:
      The position at which parsing stopped.  This is `end` unless
      SkipField() returned -1, in which case it is the position of the tag
      that caused it.
    """
    self._Modified()
    field_dict = self._fields
    unknown_field_list = self._unknown_fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is None:
        # No decoder is registered for this tag: skip over the value.
        value_start_pos = new_pos
        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
        if new_pos == -1:
          # Stop here and report the position of the offending tag.
          return pos
        if not is_proto3:
          # Outside proto3 the raw bytes of the skipped value are kept.
          if not unknown_field_list:
            # Swap any empty container for a fresh list before appending.
            unknown_field_list = self._unknown_fields = []
          unknown_field_list.append(
              (tag_bytes, buffer[value_start_pos:new_pos]))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        # NOTE(review): field_desc is presumably only provided for fields
        # that belong to a oneof -- confirm where _decoders_by_tag is built.
        if field_desc:
          self._UpdateOneofState(field_desc)
    return pos
  cls._InternalParse = InternalParse
1123
1124
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  # Computed once per class; both methods below close over this list.
  required_fields = [field for field in message_descriptor.fields
                           if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    # First pass: this message's own required fields.  A message-typed field
    # only counts as set once it is marked present in its parent.
    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    # Second pass: recurse into the message-typed values that are stored.
    for field, value in list(self._fields.items()):  # dict can change size!
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.label == _FieldDescriptor.LABEL_REPEATED:
          # Fields whose message type is a map entry are not descended into.
          if (field.message_type.has_options and
              field.message_type.GetOptions().map_entry):
            continue
          for element in value:
            if not element.IsInitialized():
              if errors is not None:
                errors.extend(self.FindInitializationErrors())
              return False
        elif value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    errors = []  # simplify things

    # Missing required fields of this message itself.
    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    # Errors reported by sub-messages, each prefixed with the path to it.
    for field, value in self.ListFields():
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        # Extensions are written as "(full.name)" in the path.
        if field.is_extension:
          name = "(%s)" % field.full_name
        else:
          name = field.name

        if _IsMapField(field):
          if _IsMessageMapField(field):
            # Iterating the map yields keys; values are fetched by key.
            for key in value:
              element = value[key]
              prefix = "%s[%s]." % (name, key)
              sub_errors = element.FindInitializationErrors()
              errors += [prefix + error for error in sub_errors]
          else:
            # ScalarMaps can't have any initialization errors.
            pass
        elif field.label == _FieldDescriptor.LABEL_REPEATED:
          for i in range(len(value)):
            element = value[i]
            prefix = "%s[%d]." % (name, i)
            sub_errors = element.FindInitializationErrors()
            errors += [prefix + error for error in sub_errors]
        else:
          prefix = name + "."
          sub_errors = value.FindInitializationErrors()
          errors += [prefix + error for error in sub_errors]

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
1218
1219
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on cls."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the contents of msg into this message.

    Repeated fields and present sub-messages are merged through their own
    MergeFrom(); other values from msg overwrite ours.  Unknown fields of
    msg are appended to ours.

    Raises:
      TypeError: if msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          "Parameter to MergeFrom() must be instance of same class: "
          "expected %s got %s." % (cls.__name__, type(msg).__name__))

    assert msg is not self
    self._Modified()

    own_fields = self._fields

    for field, incoming in msg._fields.items():
      repeated = field.label == LABEL_REPEATED

      if not repeated and field.cpp_type != CPPTYPE_MESSAGE:
        # Plain singular value: the incoming value simply replaces ours.
        self._fields[field] = incoming
        if field.containing_oneof:
          self._UpdateOneofState(field)
        continue

      if not repeated and not incoming._is_present_in_parent:
        # A singular sub-message that is not marked present adds nothing.
        continue

      target = own_fields.get(field)
      if target is None:
        # Construct a new object to represent this field.
        target = field._default_constructor(self)
        own_fields[field] = target
      target.MergeFrom(incoming)

    if msg._unknown_fields:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
1262
1263
1264def _AddWhichOneofMethod(message_descriptor, cls):
1265  def WhichOneof(self, oneof_name):
1266    """Returns the name of the currently set field inside a oneof, or None."""
1267    try:
1268      field = message_descriptor.oneofs_by_name[oneof_name]
1269    except KeyError:
1270      raise ValueError(
1271          'Protocol message has no oneof "%s" field.' % oneof_name)
1272
1273    nested_field = self._oneofs.get(field, None)
1274    if nested_field is not None and self.HasField(nested_field.name):
1275      return nested_field.name
1276    else:
1277      return None
1278
1279  cls.WhichOneof = WhichOneof
1280
1281
1282def _Clear(self):
1283  # Clear fields.
1284  self._fields = {}
1285  self._unknown_fields = ()
1286  self._oneofs = {}
1287  self._Modified()
1288
1289
1290def _DiscardUnknownFields(self):
1291  self._unknown_fields = []
1292  for field, value in self.ListFields():
1293    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1294      if field.label == _FieldDescriptor.LABEL_REPEATED:
1295        for sub_message in value:
1296          sub_message.DiscardUnknownFields()
1297      else:
1298        value.DiscardUnknownFields()
1299
1300
1301def _SetListener(self, listener):
1302  if listener is None:
1303    self._listener = message_listener_mod.NullMessageListener()
1304  else:
1305    self._listener = listener
1306
1307
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  # The extension accessors are only installed on extendable messages.
  if message_descriptor.is_extendable:
    for add_method in (_AddClearExtensionMethod, _AddHasExtensionMethod):
      add_method(cls)

  for add_method in (_AddEqualsMethod,
                     _AddStrMethod,
                     _AddReprMethod,
                     _AddUnicodeMethod,
                     _AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Adds methods which do not depend on cls.
  for method_name, method in (('Clear', _Clear),
                              ('DiscardUnknownFields', _DiscardUnknownFields),
                              ('_SetListener', _SetListener)):
    setattr(cls, method_name, method)
1331
1332
1333def _AddPrivateHelperMethods(message_descriptor, cls):
1334  """Adds implementation of private helper methods to cls."""
1335
1336  def Modified(self):
1337    """Sets the _cached_byte_size_dirty bit to true,
1338    and propagates this to our listener iff this was a state change.
1339    """
1340
1341    # Note:  Some callers check _cached_byte_size_dirty before calling
1342    #   _Modified() as an extra optimization.  So, if this method is ever
1343    #   changed such that it does stuff even when _cached_byte_size_dirty is
1344    #   already true, the callers need to be updated.
1345    if not self._cached_byte_size_dirty:
1346      self._cached_byte_size_dirty = True
1347      self._listener_for_children.dirty = True
1348      self._is_present_in_parent = True
1349      self._listener.Modified()
1350
1351  def _UpdateOneofState(self, field):
1352    """Sets field as the active field in its containing oneof.
1353
1354    Will also delete currently active field in the oneof, if it is different
1355    from the argument. Does not mark the message as modified.
1356    """
1357    other_field = self._oneofs.setdefault(field.containing_oneof, field)
1358    if other_field is not field:
1359      del self._fields[other_field]
1360      self._oneofs[field.containing_oneof] = field
1361
1362  cls._Modified = Modified
1363  cls.SetInParent = Modified
1364  cls._UpdateOneofState = _UpdateOneofState
1365
1366
1367class _Listener(object):
1368
1369  """MessageListener implementation that a parent message registers with its
1370  child message.
1371
1372  In order to support semantics like:
1373
1374    foo.bar.baz.qux = 23
1375    assert foo.HasField('bar')
1376
1377  ...child objects must have back references to their parents.
1378  This helper class is at the heart of this support.
1379  """
1380
1381  def __init__(self, parent_message):
1382    """Args:
1383      parent_message: The message whose _Modified() method we should call when
1384        we receive Modified() messages.
1385    """
1386    # This listener establishes a back reference from a child (contained) object
1387    # to its parent (containing) object.  We make this a weak reference to avoid
1388    # creating cyclic garbage when the client finishes with the 'parent' object
1389    # in the tree.
1390    if isinstance(parent_message, weakref.ProxyType):
1391      self._parent_message_weakref = parent_message
1392    else:
1393      self._parent_message_weakref = weakref.proxy(parent_message)
1394
1395    # As an optimization, we also indicate directly on the listener whether
1396    # or not the parent message is dirty.  This way we can avoid traversing
1397    # up the tree in the common case.
1398    self.dirty = False
1399
1400  def Modified(self):
1401    if self.dirty:
1402      return
1403    try:
1404      # Propagate the signal to our parents iff this is the first field set.
1405      self._parent_message_weakref._Modified()
1406    except ReferenceError:
1407      # We can get here if a client has kept a reference to a child object,
1408      # and is now setting a field on it, but the child's parent has been
1409      # garbage-collected.  This is not an error.
1410      pass
1411
1412
class _OneofListener(_Listener):
  """Listener used for composite fields that are members of a oneof."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super(_OneofListener, self).__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      parent._UpdateOneofState(self._field)
      super(_OneofListener, self).Modified()
    except ReferenceError:
      # The parent message no longer exists; there is nothing to update.
      pass
1432
1433
1434# TODO(robinson): Move elsewhere?  This file is getting pretty ridiculous...
1435# TODO(robinson): Unify error handling of "unknown extension" crap.
1436# TODO(robinson): Support iteritems()-style iteration over all
1437# extensions with the "has" bits turned on?
1438class _ExtensionDict(object):
1439
1440  """Dict-like container for supporting an indexable "Extensions"
1441  field on proto instances.
1442
1443  Note that in all cases we expect extension handles to be
1444  FieldDescriptors.
1445  """
1446
1447  def __init__(self, extended_message):
1448    """extended_message: Message instance for which we are the Extensions dict.
1449    """
1450
1451    self._extended_message = extended_message
1452
1453  def __getitem__(self, extension_handle):
1454    """Returns the current value of the given extension handle."""
1455
1456    _VerifyExtensionHandle(self._extended_message, extension_handle)
1457
1458    result = self._extended_message._fields.get(extension_handle)
1459    if result is not None:
1460      return result
1461
1462    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1463      result = extension_handle._default_constructor(self._extended_message)
1464    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1465      result = extension_handle.message_type._concrete_class()
1466      try:
1467        result._SetListener(self._extended_message._listener_for_children)
1468      except ReferenceError:
1469        pass
1470    else:
1471      # Singular scalar -- just return the default without inserting into the
1472      # dict.
1473      return extension_handle.default_value
1474
1475    # Atomically check if another thread has preempted us and, if not, swap
1476    # in the new object we just created.  If someone has preempted us, we
1477    # take that object and discard ours.
1478    # WARNING:  We are relying on setdefault() being atomic.  This is true
1479    #   in CPython but we haven't investigated others.  This warning appears
1480    #   in several other locations in this file.
1481    result = self._extended_message._fields.setdefault(
1482        extension_handle, result)
1483
1484    return result
1485
1486  def __eq__(self, other):
1487    if not isinstance(other, self.__class__):
1488      return False
1489
1490    my_fields = self._extended_message.ListFields()
1491    other_fields = other._extended_message.ListFields()
1492
1493    # Get rid of non-extension fields.
1494    my_fields    = [ field for field in my_fields    if field.is_extension ]
1495    other_fields = [ field for field in other_fields if field.is_extension ]
1496
1497    return my_fields == other_fields
1498
1499  def __ne__(self, other):
1500    return not self == other
1501
1502  def __hash__(self):
1503    raise TypeError('unhashable object')
1504
1505  # Note that this is only meaningful for non-repeated, scalar extension
1506  # fields.  Note also that we may have to call _Modified() when we do
1507  # successfully set a field this way, to set any necssary "has" bits in the
1508  # ancestors of the extended message.
1509  def __setitem__(self, extension_handle, value):
1510    """If extension_handle specifies a non-repeated, scalar extension
1511    field, sets the value of that field.
1512    """
1513
1514    _VerifyExtensionHandle(self._extended_message, extension_handle)
1515
1516    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
1517        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
1518      raise TypeError(
1519          'Cannot assign to extension "%s" because it is a repeated or '
1520          'composite type.' % extension_handle.full_name)
1521
1522    # It's slightly wasteful to lookup the type checker each time,
1523    # but we expect this to be a vanishingly uncommon case anyway.
1524    type_checker = type_checkers.GetTypeChecker(extension_handle)
1525    # pylint: disable=protected-access
1526    self._extended_message._fields[extension_handle] = (
1527        type_checker.CheckValue(value))
1528    self._extended_message._Modified()
1529
1530  def _FindExtensionByName(self, name):
1531    """Tries to find a known extension with the specified name.
1532
1533    Args:
1534      name: Extension full name.
1535
1536    Returns:
1537      Extension field descriptor.
1538    """
1539    return self._extended_message._extensions_by_name.get(name, None)
1540
1541  def _FindExtensionByNumber(self, number):
1542    """Tries to find a known extension with the field number.
1543
1544    Args:
1545      number: Extension field number.
1546
1547    Returns:
1548      Extension field descriptor.
1549    """
1550    return self._extended_message._extensions_by_number.get(number, None)
1551