1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# http://code.google.com/p/protobuf/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31# This code is meant to work on Python 2.4 and above only.
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
36"""Contains a metaclass and helper functions used to create
37protocol message classes from Descriptor objects at runtime.
38
39Recall that a metaclass is the "type" of a class.
40(A class is to a metaclass what an instance is to a class.)
41
42In this case, we use the GeneratedProtocolMessageType metaclass
43to inject all the useful functionality into the classes
44output by the protocol compiler at compile-time.
45
46The upshot of all this is that the real implementation
47details for ALL pure-Python protocol buffers are *here in
48this file*.
49"""
50
51__author__ = 'robinson@google.com (Will Robinson)'
52
53try:
54  from cStringIO import StringIO
55except ImportError:
56  from StringIO import StringIO
57import copy_reg
58import struct
59import weakref
60
61# We use "as" to avoid name collisions with variables.
62from google.protobuf.internal import containers
63from google.protobuf.internal import decoder
64from google.protobuf.internal import encoder
65from google.protobuf.internal import enum_type_wrapper
66from google.protobuf.internal import message_listener as message_listener_mod
67from google.protobuf.internal import type_checkers
68from google.protobuf.internal import wire_format
69from google.protobuf import descriptor as descriptor_mod
70from google.protobuf import message as message_mod
71from google.protobuf import text_format
72
# Short module-level alias for descriptor_mod.FieldDescriptor.  The helpers in
# this file compare field labels and types against its LABEL_*, TYPE_* and
# CPPTYPE_* constants.
_FieldDescriptor = descriptor_mod.FieldDescriptor
74
75
def NewMessage(bases, descriptor, dictionary):
  """Prepares the class dictionary for a message class about to be created.

  Args:
    bases: Base classes for the class being created; handed back untouched.
    descriptor: Descriptor object describing the message type.
    dictionary: Class dictionary.  Updated in place with the nested extension
      handles and with the '__slots__' entry.

  Returns:
    The bases argument, unchanged.
  """
  for populate in (_AddClassAttributesForNestedExtensions, _AddSlots):
    populate(descriptor, dictionary)
  return bases
80
81
def InitMessage(descriptor, cls):
  """Populates a newly created message class from its descriptor.

  Creates the class-level decoder and extension registries, attaches the
  per-field helpers, installs the generated properties and methods, and
  registers a pickle reducer for the class with copy_reg.

  Args:
    descriptor: Descriptor object describing the message type.
    cls: The message class being initialized; modified in place.
  """
  decoders_by_tag = {}
  extensions_by_number = {}
  cls._decoders_by_tag = decoders_by_tag
  cls._extensions_by_name = {}
  cls._extensions_by_number = extensions_by_number

  if descriptor.has_options:
    if descriptor.GetOptions().message_set_wire_format:
      # The MessageSet item decoder is registered under its own tag and is
      # handed the by-number extension registry created above.
      decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(extensions_by_number))

  # Attach stuff to each FieldDescriptor for quick lookup later on.
  for field_descriptor in descriptor.fields:
    _AttachFieldHelpers(cls, field_descriptor)

  _AddEnumValues(descriptor, cls)
  _AddInitMethod(descriptor, cls)
  _AddPropertiesForFields(descriptor, cls)
  _AddPropertiesForExtensions(descriptor, cls)
  _AddStaticMethods(cls)
  _AddMessageMethods(descriptor, cls)
  _AddPrivateHelperMethods(cls)

  # Pickling rebuilds an instance by calling cls() with no arguments and then
  # restoring the state obtained from __getstate__().
  def ReduceForPickle(obj):
    return (cls, (), obj.__getstate__())
  copy_reg.pickle(cls, ReduceForPickle)
103
104
105# Stateless helpers for GeneratedProtocolMessageType below.
106# Outside clients should not access these directly.
107#
108# I opted not to make any of these methods on the metaclass, to make it more
109# clear that I'm not really using any state there and to keep clients from
110# thinking that they have direct access to these construction helpers.
111
112
113def _PropertyName(proto_field_name):
114  """Returns the name of the public property attribute which
115  clients can use to get and (in some cases) set the value
116  of a protocol message field.
117
118  Args:
119    proto_field_name: The protocol message field name, exactly
120      as it appears (or would appear) in a .proto file.
121  """
122  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
123  # nnorwitz makes my day by writing:
124  # """
125  # FYI.  See the keyword module in the stdlib. This could be as simple as:
126  #
127  # if keyword.iskeyword(proto_field_name):
128  #   return proto_field_name + "_"
129  # return proto_field_name
130  # """
131  # Kenton says:  The above is a BAD IDEA.  People rely on being able to use
132  #   getattr() and setattr() to reflectively manipulate field values.  If we
133  #   rename the properties, then every such user has to also make sure to apply
134  #   the same transformation.  Note that currently if you name a field "yield",
135  #   you can still access it just fine using getattr/setattr -- it's not even
136  #   that cumbersome to do so.
137  # TODO(kenton):  Remove this method entirely if/when everyone agrees with my
138  #   position.
139  return proto_field_name
140
141
def _VerifyExtensionHandle(message, extension_handle):
  """Raises KeyError unless extension_handle is a valid extension of message.

  The checks run in order and the first failure is reported: the handle must
  be a FieldDescriptor, must be an extension, must have a containing_type,
  and that containing_type must be the very DESCRIPTOR object of message.

  Args:
    message: The message instance the extension is being used with.
    extension_handle: The candidate extension handle.

  Raises:
    KeyError: If any of the checks above fails.
  """
  problem = None
  if not isinstance(extension_handle, _FieldDescriptor):
    problem = ('HasExtension() expects an extension handle, got: %s' %
               extension_handle)
  elif not extension_handle.is_extension:
    problem = '"%s" is not an extension.' % extension_handle.full_name
  elif not extension_handle.containing_type:
    problem = ('"%s" is missing a containing_type.'
               % extension_handle.full_name)
  elif extension_handle.containing_type is not message.DESCRIPTOR:
    # Identity, not equality: the handle must extend this exact descriptor.
    problem = ('Extension "%s" extends message type "%s", but this '
               'message is of type "%s".' %
               (extension_handle.full_name,
                extension_handle.containing_type.full_name,
                message.DESCRIPTOR.full_name))

  if problem is not None:
    raise KeyError(problem)
162
163
164def _AddSlots(message_descriptor, dictionary):
165  """Adds a __slots__ entry to dictionary, containing the names of all valid
166  attributes for this message type.
167
168  Args:
169    message_descriptor: A Descriptor instance describing this message type.
170    dictionary: Class dictionary to which we'll add a '__slots__' entry.
171  """
172  dictionary['__slots__'] = ['_cached_byte_size',
173                             '_cached_byte_size_dirty',
174                             '_fields',
175                             '_unknown_fields',
176                             '_is_present_in_parent',
177                             '_listener',
178                             '_listener_for_children',
179                             '__weakref__']
180
181
182def _IsMessageSetExtension(field):
183  return (field.is_extension and
184          field.containing_type.has_options and
185          field.containing_type.GetOptions().message_set_wire_format and
186          field.type == _FieldDescriptor.TYPE_MESSAGE and
187          field.message_type == field.extension_scope and
188          field.label == _FieldDescriptor.LABEL_OPTIONAL)
189
190
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches the encoder, sizer, default constructor and decoders of a field.

  The encoder, sizer and default-value constructor are stored on
  field_descriptor as _encoder, _sizer and _default_constructor.  The
  decoders are stored in cls._decoders_by_tag, keyed by tag bytes.

  Args:
    cls: Message class whose _decoders_by_tag table receives the decoders.
    field_descriptor: FieldDescriptor of the field or extension to prepare.
  """
  repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  packed = (field_descriptor.has_options and
            field_descriptor.GetOptions().packed)

  if not _IsMessageSetExtension(field_descriptor):
    make_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]
    encode = make_encoder(field_descriptor.number, repeated, packed)
    make_sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]
    compute_size = make_sizer(field_descriptor.number, repeated, packed)
  else:
    encode = encoder.MessageSetItemEncoder(field_descriptor.number)
    compute_size = encoder.MessageSetItemSizer(field_descriptor.number)

  field_descriptor._encoder = encode
  field_descriptor._sizer = compute_size
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def RegisterDecoder(wire_type, decode_packed):
    # The decoder is keyed by the exact tag bytes (field number + wire type).
    tag = encoder.TagBytes(field_descriptor.number, wire_type)
    make_decoder = type_checkers.TYPE_TO_DECODER[field_descriptor.type]
    cls._decoders_by_tag[tag] = make_decoder(
        field_descriptor.number, repeated, decode_packed,
        field_descriptor, field_descriptor._default_constructor)

  RegisterDecoder(
      type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False)

  if repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    RegisterDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
224
225
226def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
227  extension_dict = descriptor.extensions_by_name
228  for extension_name, extension_field in extension_dict.iteritems():
229    assert extension_name not in dictionary
230    dictionary[extension_name] = extension_field
231
232
233def _AddEnumValues(descriptor, cls):
234  """Sets class-level attributes for all enum fields defined in this message.
235
236  Also exporting a class-level object that can name enum values.
237
238  Args:
239    descriptor: Descriptor object for this message type.
240    cls: Class we're constructing for this message type.
241  """
242  for enum_type in descriptor.enum_types:
243    setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
244    for enum_value in enum_type.values:
245      setattr(cls, enum_value.name, enum_value.number)
246
247
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function taking one argument:
      message: Message instance containing this field, or a weakref proxy
        of same.
    That function in turn returns a default value for this field: a new
    repeated container for repeated fields, a new sub-message instance for
    singular message fields, or field.default_value for singular scalars.
    Containers and sub-messages are wired to message._listener_for_children.

  Raises:
    ValueError: If a repeated field declares a default value other than the
      empty list.
  """

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # field.message_type is read when the default is built, not captured
      # here.  We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized, so it is only looked up
    # when the default is actually built.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      result._SetListener(message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
295
296
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts keyword arguments only, one per field to
  populate.  Repeated and message-typed values are copied into containers and
  sub-messages owned by the new instance; scalar values go through the
  field's property setter.

  Args:
    message_descriptor: Descriptor object for this message type.
    cls: The class we're constructing.
  """
  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Any initial value makes the cached size of 0 stale.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.iteritems():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE: _GetFieldByName as defined in this file raises ValueError for an
      # unknown name instead of returning None, so this guard is defensive.
      if field is None:
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, field_name))
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        container = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          for val in field_value:
            container.add().MergeFrom(val)
        else:  # Scalar
          container.extend(field_value)
        self._fields[field] = container
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        submessage = field._default_constructor(self)
        submessage.MergeFrom(field_value)
        self._fields[field] = submessage
      else:
        setattr(self, field_name, field_value)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
333
334
335def _GetFieldByName(message_descriptor, field_name):
336  """Returns a field descriptor by field name.
337
338  Args:
339    message_descriptor: A Descriptor describing all fields in message.
340    field_name: The name of the field to retrieve.
341  Returns:
342    The field descriptor associated with the field name.
343  """
344  try:
345    return message_descriptor.fields_by_name[field_name]
346  except KeyError:
347    raise ValueError('Protocol message has no "%s" field.' % field_name)
348
349
350def _AddPropertiesForFields(descriptor, cls):
351  """Adds properties for all fields in this protocol message type."""
352  for field in descriptor.fields:
353    _AddPropertiesForField(field, cls)
354
355  if descriptor.is_extendable:
356    # _ExtensionDict is just an adaptor with no state so we allocate a new one
357    # every time it is accessed.
358    cls.Extensions = property(lambda self: _ExtensionDict(self))
359
360
def _AddPropertiesForField(field, cls):
  """Adds the public property and field-number constant for one field.

  Clients can use the property to get and (in the case of non-repeated scalar
  fields) directly set the value of a protocol message field.  The class also
  gets a <FIELD_NAME>_FIELD_NUMBER constant holding the field number.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + "_FIELD_NUMBER", field.number)

  # Pick the property builder that matches the field's shape.
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_properties = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_properties = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_properties = _AddPropertiesForNonRepeatedScalarField
  add_properties(field, cls)
384
385
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property yields the field's container, which is created through
  field._default_constructor on first access and stored in the message's
  _fields dict.  Assigning to the property raises AttributeError.

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  field_name = field.name
  attribute_name = _PropertyName(field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container

    # Construct a new object to represent this field.
    fresh = field._default_constructor(self)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % field_name

  # The setter exists only so that assignment fails with a helpful message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % field_name
  setattr(cls, attribute_name, property(getter, setter, doc=doc))
428
429
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Clients can use this property to get and directly set the value of the
  field.  Reading an unset field yields field.default_value.  Setting the
  field type-checks the new value, stores it in the message's _fields dict
  and, when the cached byte size was still valid, calls _Modified().

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
  default_value = field.default_value

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    # CheckValue runs before the store, so a rejected value leaves the
    # message untouched.
    type_checker.CheckValue(new_value)
    self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
467
468
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.

  A composite field is a "group" or "message" field.

  Reading the property yields the sub-message, which is created on first
  access, wired to the parent's _listener_for_children and stored in the
  parent's _fields dict.  Assigning to the property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  field_name = field.name
  attribute_name = _PropertyName(field_name)

  # TODO(komarek): Can anyone explain to me why we cache the message_type this
  # way, instead of referring to field.message_type inside of getter(self)?
  # What if someone sets message_type later on (which makes for simpler
  # dynamic proto descriptor and class creation code).
  submessage_type = field.message_type

  def getter(self):
    submessage = self._fields.get(field)
    if submessage is not None:
      return submessage

    # Construct a new object to represent this field.
    submessage = submessage_type._concrete_class()  # use field.message_type?
    submessage._SetListener(self._listener_for_children)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, submessage)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % field_name

  # The setter exists only so that assignment fails with a helpful message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % field_name
  setattr(cls, attribute_name, property(getter, setter, doc=doc))
518
519
520def _AddPropertiesForExtensions(descriptor, cls):
521  """Adds properties for all fields in this protocol message type."""
522  extension_dict = descriptor.extensions_by_name
523  for extension_name, extension_field in extension_dict.iteritems():
524    constant_name = extension_name.upper() + "_FIELD_NUMBER"
525    setattr(cls, constant_name, extension_field.number)
526
527
def _AddStaticMethods(cls):
  """Attaches the RegisterExtension and FromString static methods to cls.

  Args:
    cls: The message class we're constructing.
  """
  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    # Registers extension_handle as an extension of cls.  Raises
    # AssertionError if a different handle already uses the same field number.

    def RaiseNumberConflict(existing_handle):
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' %
          (extension_handle.full_name, existing_handle.full_name,
           cls.DESCRIPTOR.full_name, extension_handle.number))

    # Reject a conflicting registration before mutating anything.  Otherwise
    # the rejected handle would get its containing_type rewritten, and
    # _AttachFieldHelpers would overwrite the decoders of the extension that
    # is already registered under this number.
    registered = cls._extensions_by_number.get(extension_handle.number)
    if registered is not None and registered is not extension_handle:
      RaiseNumberConflict(registered)

    extension_handle.containing_type = cls.DESCRIPTOR
    _AttachFieldHelpers(cls, extension_handle)

    # Try to insert our extension, failing if an extension with the same number
    # already exists.  setdefault() keeps the check-and-insert a single step.
    actual_handle = cls._extensions_by_number.setdefault(
        extension_handle.number, extension_handle)
    if actual_handle is not extension_handle:
      RaiseNumberConflict(actual_handle)

    cls._extensions_by_name[extension_handle.full_name] = extension_handle

    if _IsMessageSetExtension(extension_handle):
      # MessageSet extension.  Also register under type name.
      cls._extensions_by_name[
          extension_handle.message_type.full_name] = extension_handle

  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    # Builds a new cls instance and merges the serialized data s into it.
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)
560
561
def _IsPresent(item):
  """Decides whether a _fields entry belongs in the ListFields() result.

  Args:
    item: A (FieldDescriptor, value) tuple taken from a message's _fields.

  Returns:
    For a repeated field, whether the container is non-empty; for a singular
    message field, the sub-message's _is_present_in_parent; otherwise True.
  """
  field_descriptor = item[0]
  if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  if field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  return True
572
573
574def _AddListFieldsMethod(message_descriptor, cls):
575  """Helper for _AddMessageMethods()."""
576
577  def ListFields(self):
578    all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
579    all_fields.sort(key = lambda item: item[0].number)
580    return all_fields
581
582  cls.ListFields = ListFields
583
584
585def _AddHasFieldMethod(message_descriptor, cls):
586  """Helper for _AddMessageMethods()."""
587
588  singular_fields = {}
589  for field in message_descriptor.fields:
590    if field.label != _FieldDescriptor.LABEL_REPEATED:
591      singular_fields[field.name] = field
592
593  def HasField(self, field_name):
594    try:
595      field = singular_fields[field_name]
596    except KeyError:
597      raise ValueError(
598          'Protocol message has no singular "%s" field.' % field_name)
599
600    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
601      value = self._fields.get(field)
602      return value is not None and value._is_present_in_parent
603    else:
604      return field in self._fields
605  cls.HasField = HasField
606
607
608def _AddClearFieldMethod(message_descriptor, cls):
609  """Helper for _AddMessageMethods()."""
610  def ClearField(self, field_name):
611    try:
612      field = message_descriptor.fields_by_name[field_name]
613    except KeyError:
614      raise ValueError('Protocol message has no "%s" field.' % field_name)
615
616    if field in self._fields:
617      # Note:  If the field is a sub-message, its listener will still point
618      #   at us.  That's fine, because the worst than can happen is that it
619      #   will call _Modified() and invalidate our byte size.  Big deal.
620      del self._fields[field]
621
622    # Always call _Modified() -- even if nothing was changed, this is
623    # a mutating method, and thus calling it should cause the field to become
624    # present in the parent message.
625    self._Modified()
626
627  cls.ClearField = ClearField
628
629
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension() on cls.

  ClearExtension(extension_handle) validates the handle, removes its entry
  from _fields, if any, and then calls _Modified().

  Args:
    cls: The class we're constructing.
  """
  def ClearExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)

    # Similar to ClearField(): drop the entry when there is one, and report
    # the modification either way.
    self._fields.pop(extension_handle, None)
    self._Modified()
  cls.ClearExtension = ClearExtension
640
641
642def _AddClearMethod(message_descriptor, cls):
643  """Helper for _AddMessageMethods()."""
644  def Clear(self):
645    # Clear fields.
646    self._fields = {}
647    self._unknown_fields = ()
648    self._Modified()
649  cls.Clear = Clear
650
651
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension() on cls.

  HasExtension(extension_handle) validates the handle and raises KeyError for
  repeated extensions.  A message extension counts as set when its sub-message
  exists and is marked present in the parent; a scalar extension counts as
  set when it has an entry in _fields.

  Args:
    cls: The class we're constructing.
  """
  def HasExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return extension_handle in self._fields
    submessage = self._fields.get(extension_handle)
    return submessage is not None and submessage._is_present_in_parent
  cls.HasExtension = HasExtension
665
666
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on cls.

  Two messages are equal when both are Message instances with equal
  DESCRIPTORs, their ListFields() results are equal, and their unknown fields
  are equal regardless of order.

  Args:
    message_descriptor: Descriptor object for this message type (unused).
    cls: The class we're constructing.
  """
  def __eq__(self, other):
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False

    if self is other:
      return True

    if not self.ListFields() == other.ListFields():
      return False

    # Sort unknown fields because their order shouldn't affect equality test.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
689
690
691def _AddStrMethod(message_descriptor, cls):
692  """Helper for _AddMessageMethods()."""
693  def __str__(self):
694    return text_format.MessageToString(self)
695  cls.__str__ = __str__
696
697
698def _AddUnicodeMethod(unused_message_descriptor, cls):
699  """Helper for _AddMessageMethods()."""
700
701  def __unicode__(self):
702    return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
703  cls.__unicode__ = __unicode__
704
705
706def _AddSetListenerMethod(cls):
707  """Helper for _AddMessageMethods()."""
708  def SetListener(self, listener):
709    if listener is None:
710      self._listener = message_listener_mod.NullMessageListener()
711    else:
712      self._listener = listener
713  cls._SetListener = SetListener
714
715
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    message_mod.EncodeError: If field_type has no byte-size function.
  """
  try:
    # NOTE: the size computation sits inside the try as well, so a KeyError
    # raised by the size function is also reported as an unrecognized type.
    compute_size = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
    byte_count = compute_size(field_number, value)
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  else:
    return byte_count
734
735
736def _AddByteSizeMethod(message_descriptor, cls):
737  """Helper for _AddMessageMethods()."""
738
739  def ByteSize(self):
740    if not self._cached_byte_size_dirty:
741      return self._cached_byte_size
742
743    size = 0
744    for field_descriptor, field_value in self.ListFields():
745      size += field_descriptor._sizer(field_value)
746
747    for tag_bytes, value_bytes in self._unknown_fields:
748      size += len(tag_bytes) + len(value_bytes)
749
750    self._cached_byte_size = size
751    self._cached_byte_size_dirty = False
752    self._listener_for_children.dirty = False
753    return size
754
755  cls.ByteSize = ByteSize
756
757
758def _AddSerializeToStringMethod(message_descriptor, cls):
759  """Helper for _AddMessageMethods()."""
760
761  def SerializeToString(self):
762    # Check if the message has all of its required fields set.
763    errors = []
764    if not self.IsInitialized():
765      raise message_mod.EncodeError(
766          'Message %s is missing required fields: %s' % (
767          self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
768    return self.SerializePartialToString()
769  cls.SerializeToString = SerializeToString
770
771
772def _AddSerializePartialToStringMethod(message_descriptor, cls):
773  """Helper for _AddMessageMethods()."""
774
775  def SerializePartialToString(self):
776    out = StringIO()
777    self._InternalSerialize(out.write)
778    return out.getvalue()
779  cls.SerializePartialToString = SerializePartialToString
780
781  def InternalSerialize(self, write_bytes):
782    for field_descriptor, field_value in self.ListFields():
783      field_descriptor._encoder(write_bytes, field_value)
784    for tag_bytes, value_bytes in self._unknown_fields:
785      write_bytes(tag_bytes)
786      write_bytes(value_bytes)
787  cls._InternalSerialize = InternalSerialize
788
789
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and the lower-level _InternalParse() on cls.

  MergeFromString(serialized) parses the whole of serialized into the message
  and returns len(serialized).  It raises message_mod.DecodeError when
  _InternalParse() stops before the end of the input, when an IndexError
  escapes the parse, or when a struct.error escapes the parse.

  _InternalParse(buffer, pos, end) parses buffer from pos towards end and
  returns the position at which it stopped.

  Args:
    message_descriptor: Descriptor object for this message type (unused).
    cls: The class we're constructing.
  """
  def MergeFromString(self, serialized):
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except IndexError:
      # Reading past the end of the buffer is reported as a truncated message.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error, e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Looked up once here and reached through the closure by InternalParse.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    # _Modified() is called up front, before anything is parsed.
    self._Modified()
    field_dict = self._fields
    unknown_field_list = self._unknown_fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder = decoders_by_tag.get(tag_bytes)
      if field_decoder is None:
        # No decoder is registered for this tag: skip over the value and keep
        # its raw bytes as an unknown field.
        value_start_pos = new_pos
        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
        if new_pos == -1:
          # SkipField declined to skip.  Stop here and return the position of
          # the tag, which MergeFromString reports as an end-group tag.
          return pos
        if not unknown_field_list:
          # _unknown_fields starts out as (); switch to a list on first use
          # and keep the local alias pointing at the same list.
          unknown_field_list = self._unknown_fields = []
        unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
        pos = new_pos
      else:
        # The field decoder returns the position just past what it consumed.
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
    return pos
  cls._InternalParse = InternalParse
830
831
def _AddIsInitializedMethod(message_descriptor, cls):
  """Installs IsInitialized() and FindInitializationErrors() on cls.

  Args:
    message_descriptor: Descriptor whose `fields` are scanned once, up front,
      for the ones labelled required.
    cls: The class that receives the two methods.
  """

  required_fields = [candidate for candidate in message_descriptor.fields
                     if candidate.label == _FieldDescriptor.LABEL_REQUIRED]

  def _Uninitialized(message, errors):
    """Appends message's missing-field paths to errors, if a list was given.

    Always returns False so that callers can `return _Uninitialized(...)`.
    """
    if errors is not None:
      errors.extend(message.FindInitializationErrors())
    return False

  def IsInitialized(self, errors=None):
    """Reports whether every required field of the message is set.

    Args:
      errors:  Optional list.  When given and the message is not initialized,
               it is extended with the paths of the missing required fields.

    Returns:
      True iff the message has all of its required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().
    fields = self._fields

    # Pass 1: this message's own required fields.
    for field in required_fields:
      if field not in fields:
        return _Uninitialized(self, errors)
      if (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
          not fields[field]._is_present_in_parent):
        return _Uninitialized(self, errors)

    # Pass 2: recurse into the message-typed values that are stored.
    for field, value in fields.iteritems():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        for element in value:
          if not element.IsInitialized():
            return _Uninitialized(self, errors)
      elif value._is_present_in_parent and not value.IsInitialized():
        return _Uninitialized(self, errors)

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Collects the paths of all required fields that are not set.

    Returns:
      A list of strings.  Each string is the path from this message to an
      uninitialized field, e.g. "foo.bar[5].baz".
    """

    # Start with the required fields missing directly on this message.
    missing = [field.name for field in required_fields
               if not self.HasField(field.name)]

    # Then gather the paths reported by nested messages, prefixed with the
    # name (and index, for repeated fields) that leads to them.
    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      if field.is_extension:
        name = "(%s)" % field.full_name
      else:
        name = field.name

      if field.label == _FieldDescriptor.LABEL_REPEATED:
        for index, element in enumerate(value):
          prefix = "%s[%d]." % (name, index)
          missing.extend(
              [prefix + path for path in element.FindInitializationErrors()])
      else:
        prefix = name + "."
        missing.extend(
            [prefix + path for path in value.FindInitializationErrors()])

    return missing

  cls.FindInitializationErrors = FindInitializationErrors
912
913
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on cls."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the contents of msg into this message.

    Repeated fields and present singular message fields are merged via the
    value's own MergeFrom(); any other stored field is overwritten with
    msg's value.  msg's unknown fields are appended to this message's.

    Args:
      msg: Another instance of cls; must not be self.

    Raises:
      TypeError: if msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          "Parameter to MergeFrom() must be instance of same class: "
          "expected %s got %s." % (cls.__name__, type(msg).__name__))

    assert msg is not self
    self._Modified()

    own_fields = self._fields

    for field, incoming in msg._fields.iteritems():
      if field.label != LABEL_REPEATED:
        if field.cpp_type != CPPTYPE_MESSAGE:
          # Singular non-message field: msg's value simply replaces ours.
          own_fields[field] = incoming
          continue
        if not incoming._is_present_in_parent:
          # Singular message that msg holds but has not marked present.
          continue

      # Repeated field or present sub-message: merge into our own value,
      # constructing it first if this message has none yet.
      target = own_fields.get(field)
      if target is None:
        target = field._default_constructor(self)
        own_fields[field] = target
      target.MergeFrom(incoming)

    foreign_unknown = msg._unknown_fields
    if foreign_unknown:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(foreign_unknown)

  cls.MergeFrom = MergeFrom
954
955
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  Each helper called below attaches one or more methods to cls.

  Args:
    message_descriptor: Descriptor handed on to the helpers that take one;
      its is_extendable flag decides whether the extension helpers run.
    cls: The class that receives the methods.
  """
  _AddListFieldsMethod(message_descriptor, cls)
  _AddHasFieldMethod(message_descriptor, cls)
  _AddClearFieldMethod(message_descriptor, cls)
  # The extension helpers are only applied to extendable message types.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)
  _AddClearMethod(message_descriptor, cls)
  _AddEqualsMethod(message_descriptor, cls)
  _AddStrMethod(message_descriptor, cls)
  _AddUnicodeMethod(message_descriptor, cls)
  _AddSetListenerMethod(cls)
  _AddByteSizeMethod(message_descriptor, cls)
  _AddSerializeToStringMethod(message_descriptor, cls)
  _AddSerializePartialToStringMethod(message_descriptor, cls)
  _AddMergeFromStringMethod(message_descriptor, cls)
  _AddIsInitializedMethod(message_descriptor, cls)
  _AddMergeFromMethod(cls)
975
976
def _AddPrivateHelperMethods(cls):
  """Attaches the private _Modified() helper to cls.

  The same function is also exposed under the name SetInParent.
  """

  def Modified(self):
    """Flags the cached byte size as stale and notifies the listener.

    Nothing happens when _cached_byte_size_dirty is already set, so the
    listener is only told about clean-to-dirty transitions.
    """

    # Note:  Some callers test _cached_byte_size_dirty themselves before
    #   calling _Modified(), as an extra optimization.  If this method ever
    #   starts doing work while the flag is already true, those callers have
    #   to be updated as well.
    if self._cached_byte_size_dirty:
      return

    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  cls._Modified = Modified
  cls.SetInParent = Modified
997
998
class _Listener(object):
  """Listener that a parent message hands to a child message.

  Supporting semantics such as:

    foo.bar.baz.qux = 23
    assert foo.HasField('bar')

  ...requires each child to hold a back reference to its parent, so that a
  change to the child can be reported upwards.  This class is that back
  reference.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method is called when
        this listener receives Modified().  May already be a weakref proxy.
    """
    # The child-to-parent link is kept weak so that dropping the parent does
    # not leave cyclic garbage behind.  A proxy passed in is stored as is;
    # anything else is wrapped in one first.
    if not isinstance(parent_message, weakref.ProxyType):
      parent_message = weakref.proxy(parent_message)
    self._parent_message_weakref = parent_message

    # Optimization: the parent's dirty state is mirrored on the listener so
    # that the common case does not have to walk up the tree.
    self.dirty = False

  def Modified(self):
    """Forwards a modification to the parent unless it is already dirty."""
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # A client may still hold a child whose parent has been
        # garbage-collected and set a field on it.  This is not an error.
        pass
1043
1044
1045# TODO(robinson): Move elsewhere?  This file is getting pretty ridiculous...
1046# TODO(robinson): Unify error handling of "unknown extension" crap.
1047# TODO(robinson): Support iteritems()-style iteration over all
1048# extensions with the "has" bits turned on?
class _ExtensionDict(object):
  """Dict-like container for supporting an indexable "Extensions"
  field on proto instances.

  Note that in all cases we expect extension handles to be
  FieldDescriptors.
  """

  def __init__(self, extended_message):
    """Args:
      extended_message: Message instance for which we are the Extensions dict.
    """

    self._extended_message = extended_message

  def __getitem__(self, extension_handle):
    """Returns the current value of the given extension handle.

    A stored value is returned as is.  Otherwise a repeated or message-typed
    extension gets a freshly constructed value which is inserted into the
    extended message's _fields, while a singular scalar extension just
    yields extension_handle.default_value without inserting anything.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    result = self._extended_message._fields.get(extension_handle)
    if result is not None:
      return result

    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      result = extension_handle._default_constructor(self._extended_message)
    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      result = extension_handle.message_type._concrete_class()
      try:
        result._SetListener(self._extended_message._listener_for_children)
      except ReferenceError:
        # Deliberately best-effort.  NOTE(review): presumably this covers a
        # weakly referenced message that has already been collected --
        # confirm against _SetListener.
        pass
    else:
      # Singular scalar -- just return the default without inserting into the
      # dict.
      return extension_handle.default_value

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    result = self._extended_message._fields.setdefault(
        extension_handle, result)

    return result

  def __eq__(self, other):
    """Compares the set extension fields of the two extended messages.

    Returns:
      False if other is not an instance of this class; otherwise whether the
      (descriptor, value) pairs of all set extension fields are equal.
    """
    if not isinstance(other, self.__class__):
      return False

    my_fields = self._extended_message.ListFields()
    other_fields = other._extended_message.ListFields()

    # Get rid of non-extension fields.  ListFields() yields
    # (descriptor, value) pairs, so the descriptor has to be unpacked before
    # is_extension can be consulted.
    my_fields = [(field, value) for field, value in my_fields
                 if field.is_extension]
    other_fields = [(field, value) for field, value in other_fields
                    if field.is_extension]

    return my_fields == other_fields

  def __ne__(self, other):
    """Returns the negation of __eq__."""
    return not self == other

  def __hash__(self):
    """Always raises TypeError: instances are not hashable."""
    raise TypeError('unhashable object')

  # Note that this is only meaningful for non-repeated, scalar extension
  # fields.  Note also that we may have to call _Modified() when we do
  # successfully set a field this way, to set any necessary "has" bits in the
  # ancestors of the extended message.
  def __setitem__(self, extension_handle, value):
    """If extension_handle specifies a non-repeated, scalar extension
    field, sets the value of that field.

    Raises:
      TypeError: if the extension is repeated or message-typed.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
      raise TypeError(
          'Cannot assign to extension "%s" because it is a repeated or '
          'composite type.' % extension_handle.full_name)

    # It's slightly wasteful to lookup the type checker each time,
    # but we expect this to be a vanishingly uncommon case anyway.
    type_checker = type_checkers.GetTypeChecker(
        extension_handle.cpp_type, extension_handle.type)
    type_checker.CheckValue(value)
    self._extended_message._fields[extension_handle] = value
    self._extended_message._Modified()

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if the extended message's
      _extensions_by_name has no entry for name.
    """
    return self._extended_message._extensions_by_name.get(name)
1151