# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# This code is meant to work on Python 2.4 and above only.
# NOTE(review): `from io import BytesIO` below needs Python 2.6+, and `six`
#   is required as well, so the 2.4 claim above looks stale -- confirm.
#
# TODO(robinson): Helpers for verbose, common checks like seeing if a
# descriptor's cpp_type is CPPTYPE_MESSAGE.

"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""

__author__ = 'robinson@google.com (Will Robinson)'

from io import BytesIO
import sys
import struct
import weakref

import six
try:
  import six.moves.copyreg as copyreg
except ImportError:
  # On some platforms, for example gMac, we run native Python because there is
  # nothing like hermetic Python. This means lesser control on the system and
  # the six.moves package may be missing (is missing on 20150321 on gMac). Be
  # extra conservative and try to load the old replacement if it fails.
  import copy_reg as copyreg

# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
from google.protobuf import text_format

# Short alias used throughout this module.
_FieldDescriptor = descriptor_mod.FieldDescriptor
# Full name of the Any well-known type.
_AnyFullTypeName = 'google.protobuf.Any'


class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class. We
  also create properties to allow getting/setting all fields in the protocol
  message. Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime. Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message). If the descriptor's full name is a key
        of well_known_types.WKTBASES, the matching mixin class is appended
        to these bases before the class is created.
      dictionary: The class dictionary of the class we're
        constructing. dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.

    Returns:
      Newly-allocated class.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    # Well-known types (e.g. Any) get an extra helper base class.
    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    # Both helpers mutate `dictionary` and must run before type.__new__,
    # since __slots__ only takes effect at class-creation time.
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.
    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message). We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing. dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    # Per-class lookup tables: tag bytes -> (decoder, oneof info), and the
    # extension registries filled in later by RegisterExtension().
    cls._decoders_by_tag = {}
    cls._extensions_by_name = {}
    cls._extensions_by_number = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(cls._extensions_by_number), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    # Link the descriptor back to its class; default constructors for
    # sub-message fields read _concrete_class lazily.
    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)
    # Pickling recreates an empty instance via cls() and restores state
    # from obj.__getstate__().
    copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)


# Stateless helpers for GeneratedProtocolMessageType below.
# Outside clients should not access these directly.
#
# I opted not to make any of these methods on the metaclass, to make it more
# clear that I'm not really using any state there and to keep clients from
# thinking that they have direct access to these construction helpers.


def _PropertyName(proto_field_name):
  """Returns the name of the public property attribute which
  clients can use to get and (in some cases) set the value
  of a protocol message field.

  Currently this is the identity function: the property name is the
  field name unchanged.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.
  """
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz makes my day by writing:
  # """
  # FYI. See the keyword module in the stdlib. This could be as simple as:
  #
  # if keyword.iskeyword(proto_field_name):
  #   return proto_field_name + "_"
  # return proto_field_name
  # """
  # Kenton says: The above is a BAD IDEA. People rely on being able to use
  #   getattr() and setattr() to reflectively manipulate field values. If we
  #   rename the properties, then every such user has to also make sure to apply
  #   the same transformation. Note that currently if you name a field "yield",
  #   you can still access it just fine using getattr/setattr -- it's not even
  #   that cumbersome to do so.
  # TODO(kenton): Remove this method entirely if/when everyone agrees with my
  #   position.
  return proto_field_name


def _VerifyExtensionHandle(message, extension_handle):
  """Verify that the given extension handle is valid.

  Args:
    message: The message instance the extension is being used with.
    extension_handle: The FieldDescriptor of the extension.

  Raises:
    KeyError: If extension_handle is not a FieldDescriptor, is not an
      extension, has no containing_type, or extends a message type other
      than message.DESCRIPTOR.
  """
  # NOTE(review): the first message below always says "HasExtension()", even
  #   when this check is reached from ClearExtension().

  if not isinstance(extension_handle, _FieldDescriptor):
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   extension_handle)

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  if not extension_handle.containing_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % extension_handle.full_name)

  if extension_handle.containing_type is not message.DESCRIPTOR:
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' %
                   (extension_handle.full_name,
                    extension_handle.containing_type.full_name,
                    message.DESCRIPTOR.full_name))


def _AddSlots(message_descriptor, dictionary):
  """Adds a __slots__ entry to dictionary, containing the names of all valid
  attributes for this message type.

  The slot list is fixed and does not depend on message_descriptor; field
  values live in the `_fields` dict rather than in per-field slots.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  dictionary['__slots__'] = ['_cached_byte_size',
                             '_cached_byte_size_dirty',
                             '_fields',
                             '_unknown_fields',
                             '_is_present_in_parent',
                             '_listener',
                             '_listener_for_children',
                             '__weakref__',
                             '_oneofs']


def _IsMessageSetExtension(field):
  """Returns truthy if field is an optional, message-typed extension of a
  message type that uses message_set_wire_format."""
  return (field.is_extension and
          field.containing_type.has_options and
          field.containing_type.GetOptions().message_set_wire_format and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)


def _IsMapField(field):
  """Returns truthy if field is a message field whose message type has the
  map_entry option set. The field's label is not checked here."""
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type.has_options and
          field.message_type.GetOptions().map_entry)


def _IsMessageMapField(field):
  """Returns True if the map entry's "value" field is message-typed.

  Presumably only called after _IsMapField(field) was true; the lookup of
  "value" in field.message_type raises KeyError otherwise.
  """
  value_type = field.message_type.fields_by_name["value"]
  return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE


def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches serialization helpers for one field and registers its decoders.

  Sets field_descriptor._encoder, ._sizer and ._default_constructor, and adds
  one or two entries to cls._decoders_by_tag (keyed by tag bytes).

  Args:
    cls: The message class under construction.
    field_descriptor: The FieldDescriptor (regular field or extension).
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))
  # Packing: proto2 packs only when `packed` is set; any other syntax packs
  # packable repeated fields unless `packed` is explicitly false.
  if not is_packable:
    is_packed = False
  elif field_descriptor.containing_type.syntax == "proto2":
    is_packed = (field_descriptor.has_options and
                field_descriptor.GetOptions().packed)
  else:
    has_packed_false = (field_descriptor.has_options and
                        field_descriptor.GetOptions().HasField("packed") and
                        field_descriptor.GetOptions().packed == False)
    is_packed = not has_packed_false
  is_map_entry = _IsMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor)
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    """Registers a decoder for this field under the tag for `wiretype`."""
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    decode_type = field_descriptor.type
    # Open enums are decoded as plain int32 values.
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      decode_type = _FieldDescriptor.TYPE_INT32

    # NOTE(review): despite its name, this holds the *field* descriptor (not
    #   the OneofDescriptor) when the field belongs to a oneof, else None.
    oneof_descriptor = None
    if field_descriptor.containing_oneof is not None:
      oneof_descriptor = field_descriptor

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
              field_descriptor.number, is_repeated, is_packed,
              field_descriptor, field_descriptor._default_constructor)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)


def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes extensions declared inside this message as class attributes.

  Each entry of descriptor.extensions_by_name is stored in the class
  dictionary under the extension's name.
  """
  extension_dict = descriptor.extensions_by_name
  for extension_name, extension_field in extension_dict.items():
    # Guards against clobbering an existing class attribute.
    # NOTE(review): assert statements are stripped under `python -O`.
    assert extension_name not in dictionary
    dictionary[extension_name] = extension_field


def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for all enum fields defined in this message.

  Also exporting a class-level object that can name enum values.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    # e.g. MyMessage.Color is an EnumTypeWrapper, MyMessage.RED an int.
    setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
    for enum_value in enum_type.values:
      setattr(cls, enum_value.name, enum_value.number)


def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds an empty map container for a map field.

  The factory takes the owning message and returns a containers.MessageMap
  (message-valued maps) or a containers.ScalarMap (all other value types).

  Raises:
    ValueError: If field is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  fields_by_name = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])

  value_field = fields_by_name['value']
  if _IsMessageMapField(field):
    def MakeMessageMapDefault(message):
      return containers.MessageMap(
          message._listener_for_children, value_field.message_type, key_checker)
    return MakeMessageMapDefault
  else:
    value_checker = type_checkers.GetTypeChecker(value_field)
    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker)
    return MakePrimitiveMapDefault


def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field. The default
    value may refer back to |message| via a weak reference.

  Raises:
    ValueError: If a repeated field declares a non-empty default value, or
      (via _GetInitializeDefaultForMap) a map field is not repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set. (Depends on order in which we initialize the classes).
      # NOTE(review): this local is unused; the closure below reads
      #   field.message_type directly.
      message_type = field.message_type
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      # Sub-messages inside a oneof report changes through a oneof-aware
      # listener; others use the parent's shared child listener.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault


def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raise the currently-handled TypeError with the field name added.

  Must be called from inside an `except` block, since it reads
  sys.exc_info(). Only an exact TypeError with a single argument gets its
  message amended; any other exception is re-raised unchanged.
  """
  exc = sys.exc_info()[1]
  if len(exc.args) == 1 and type(exc) is TypeError:
    # simple TypeError; add field name to exception message
    exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))

  # re-raise possibly-amended exception with original traceback:
  six.reraise(type(exc), exc, sys.exc_info()[2])


def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts field values as keyword arguments:
  repeated fields take iterables (message elements may be dicts of
  constructor kwargs), map fields take mappings, sub-message fields take a
  message or a dict, enum fields accept a label string or a number, and
  None is treated as "not set".
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name. If the value is not a string, it's
    returned as-is. (No conversion or bounds-checking is done.)

    Raises:
      ValueError: If value is a string that is not a label of enum_type.
    """
    if isinstance(value, six.string_types):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Any keyword argument sets a field, so the cached size starts dirty.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE(review): _GetFieldByName raises ValueError for unknown names
      #   rather than returning None, so this TypeError branch is unreachable.
      if field is None:
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        new_val = field_value
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          # Goes through the property setter, so type checking applies.
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init


def _GetFieldByName(message_descriptor, field_name):
  """Returns a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.
  Returns:
    The field descriptor associated with the field name.
  Raises:
    ValueError: If the message has no field with that name.
  """
  try:
    return message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))


def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type."""
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if descriptor.is_extendable:
    # _ExtensionDict is just an adaptor with no state so we allocate a new one
    # every time it is accessed.
    cls.Extensions = property(lambda self: _ExtensionDict(self))


def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.
  Clients can use this property to get and (in the case
  of non-repeated scalar fields) directly set the value
  of a protocol message field.

  Also sets a class constant <FIELD_NAME_UPPERCASE>_FIELD_NUMBER holding the
  field number.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  constant_name = field.name.upper() + "_FIELD_NUMBER"
  setattr(cls, constant_name, field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    _AddPropertiesForRepeatedField(field, cls)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
  else:
    _AddPropertiesForNonRepeatedScalarField(field, cls)


def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field. Clients
  can use this property to get the value of the field, which will be the
  container built by the field's _default_constructor (e.g.
  containers.RepeatedScalarFieldContainer,
  containers.RepeatedCompositeFieldContainer, or a map container).

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Assigning to the property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created. If someone has preempted us, we
      # take that object and discard ours.
      # WARNING: We are relying on setdefault() being atomic. This is true
      #   in CPython but we haven't investigated others. This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  For proto3 fields outside a oneof, assigning a falsy value (the proto3
  default) removes the field from self._fields instead of storing it.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  # NOTE(review): valid_values is never used in this function.
  valid_values = set()
  is_proto3 = field.containing_type.syntax == "proto3"

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False).
    new_value = type_checker.CheckValue(new_value)
    if clear_when_set_to_default and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    # Fields in a oneof also update which member of the oneof is set.
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field, but cannot
  assign to the property directly.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created. If someone has preempted us, we
      # take that object and discard ours.
      # WARNING: We are relying on setdefault() being atomic. This is true
      #   in CPython but we haven't investigated others. This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))


def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <EXTENSION_NAME_UPPERCASE>_FIELD_NUMBER class constant for every
  extension declared inside this message type.

  No accessor properties are added here; extension values are reached via
  the Extensions attribute.
  """
  extension_dict = descriptor.extensions_by_name
  for extension_name, extension_field in extension_dict.items():
    constant_name = extension_name.upper() + "_FIELD_NUMBER"
    setattr(cls, constant_name, extension_field.number)


def _AddStaticMethods(cls):
  """Adds the RegisterExtension() and FromString() static methods to cls."""
  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    """Registers an extension FieldDescriptor with this message class.

    Raises:
      AssertionError: If another extension already uses the same number.
    """
    extension_handle.containing_type = cls.DESCRIPTOR
    # NOTE(review): helpers/decoders are attached before the duplicate-number
    #   check below, so a rejected extension may already have overwritten the
    #   entry for its tag in cls._decoders_by_tag.
    _AttachFieldHelpers(cls, extension_handle)

    # Try to insert our extension, failing if an extension with the same number
    # already exists.
    actual_handle = cls._extensions_by_number.setdefault(
        extension_handle.number, extension_handle)
    if actual_handle is not extension_handle:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' %
          (extension_handle.full_name, actual_handle.full_name,
           cls.DESCRIPTOR.full_name, extension_handle.number))

    cls._extensions_by_name[extension_handle.full_name] = extension_handle

    handle = extension_handle  # avoid line wrapping
    if _IsMessageSetExtension(handle):
      # MessageSet extension. Also register under type name.
      cls._extensions_by_name[
          extension_handle.message_type.full_name] = extension_handle

  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    """Returns a new instance of cls parsed from serialized bytes `s`."""
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)


def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
  value should be included in the list returned by ListFields()."""

  if item[0].label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  else:
    return True


def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  ListFields() returns the present (FieldDescriptor, value) pairs, sorted by
  field number.
  """

  def ListFields(self):
    all_fields = [item for item in self._fields.items() if _IsPresent(item)]
    all_fields.sort(key = lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields

# Error templates for HasField() on names without presence.
_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
_Proto2HasError = 'Protocol message has no non-repeated field "%s"'

def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  HasField() accepts the name of a non-repeated field that tracks presence
  (for proto3: sub-messages and oneof members only) or, for proto2, the name
  of a oneof. Other names raise ValueError.
  """

  is_proto3 = (message_descriptor.syntax == "proto3")
  error_msg = _Proto3HasError if is_proto3 else _Proto2HasError

  hassable_fields = {}
  for field in message_descriptor.fields:
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      continue
    # For proto3, only submessages and fields inside a oneof have presence.
    if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
        not field.containing_oneof):
      continue
    hassable_fields[field.name] = field

  if not is_proto3:
    # Fields inside oneofs are never repeated (enforced by the compiler).
    for oneof in message_descriptor.oneofs:
      hassable_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % field_name)

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof "has" a value if its currently-set member has one.
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False
    else:
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        value = self._fields.get(field)
        return value is not None and value._is_present_in_parent
      else:
        return field in self._fields

  cls.HasField = HasField


def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  ClearField() accepts a field name or a oneof name (which clears whichever
  member is set); unknown names raise ValueError.
  """
  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        field = message_descriptor.oneofs_by_name[field_name]
        if field in self._oneofs:
          field = self._oneofs[field]
        else:
          return
      except KeyError:
        raise ValueError('Protocol message %s() has no "%s" field.' %
                         (message_descriptor.name, field_name))

    if field in self._fields:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(self._fields[field], 'InvalidateIterators'):
        self._fields[field].InvalidateIterators()

      # Note: If the field is a sub-message, its listener will still point
      #   at us. That's fine, because the worst that can happen is that it
      #   will call _Modified() and invalidate our byte size. Big deal.
      del self._fields[field]

      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField


def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods().

  ClearExtension() validates the handle (KeyError if invalid) and removes
  the extension's value if present.
  """
  def ClearExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)

    # Similar to ClearField(), above.
    if extension_handle in self._fields:
      del self._fields[extension_handle]
    self._Modified()
  cls.ClearExtension = ClearExtension


def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods().

  HasExtension() validates the handle and raises KeyError for repeated
  extensions, which have no presence.
  """
  def HasExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      value = self._fields.get(extension_handle)
      return value is not None and value._is_present_in_parent
    else:
      return extension_handle in self._fields
  cls.HasExtension = HasExtension

def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from the public Any Unpack method, which
  takes the target message as an argument. _InternalUnpackAny does not have
  a target message type and needs to find the message type in the
  descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message.
  """
  # TODO(amauryfa): Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO(haberman): For now we just strip the hostname. Better logic will be
  # required.
938 type_name = type_url.split('/')[-1] 939 descriptor = factory.pool.FindMessageTypeByName(type_name) 940 941 if descriptor is None: 942 return None 943 944 message_class = factory.GetPrototype(descriptor) 945 message = message_class() 946 947 message.ParseFromString(msg.value) 948 return message 949 950 951def _AddEqualsMethod(message_descriptor, cls): 952 """Helper for _AddMessageMethods().""" 953 def __eq__(self, other): 954 if (not isinstance(other, message_mod.Message) or 955 other.DESCRIPTOR != self.DESCRIPTOR): 956 return False 957 958 if self is other: 959 return True 960 961 if self.DESCRIPTOR.full_name == _AnyFullTypeName: 962 any_a = _InternalUnpackAny(self) 963 any_b = _InternalUnpackAny(other) 964 if any_a and any_b: 965 return any_a == any_b 966 967 if not self.ListFields() == other.ListFields(): 968 return False 969 970 # Sort unknown fields because their order shouldn't affect equality test. 971 unknown_fields = list(self._unknown_fields) 972 unknown_fields.sort() 973 other_unknown_fields = list(other._unknown_fields) 974 other_unknown_fields.sort() 975 976 return unknown_fields == other_unknown_fields 977 978 cls.__eq__ = __eq__ 979 980 981def _AddStrMethod(message_descriptor, cls): 982 """Helper for _AddMessageMethods().""" 983 def __str__(self): 984 return text_format.MessageToString(self) 985 cls.__str__ = __str__ 986 987 988def _AddReprMethod(message_descriptor, cls): 989 """Helper for _AddMessageMethods().""" 990 def __repr__(self): 991 return text_format.MessageToString(self) 992 cls.__repr__ = __repr__ 993 994 995def _AddUnicodeMethod(unused_message_descriptor, cls): 996 """Helper for _AddMessageMethods().""" 997 998 def __unicode__(self): 999 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') 1000 cls.__unicode__ = __unicode__ 1001 1002 1003def _BytesForNonRepeatedElement(value, field_number, field_type): 1004 """Returns the number of bytes needed to serialize a non-repeated element. 
1005 The returned byte count includes space for tag information and any 1006 other additional space associated with serializing value. 1007 1008 Args: 1009 value: Value we're serializing. 1010 field_number: Field number of this value. (Since the field number 1011 is stored as part of a varint-encoded tag, this has an impact 1012 on the total bytes required to serialize the value). 1013 field_type: The type of the field. One of the TYPE_* constants 1014 within FieldDescriptor. 1015 """ 1016 try: 1017 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] 1018 return fn(field_number, value) 1019 except KeyError: 1020 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) 1021 1022 1023def _AddByteSizeMethod(message_descriptor, cls): 1024 """Helper for _AddMessageMethods().""" 1025 1026 def ByteSize(self): 1027 if not self._cached_byte_size_dirty: 1028 return self._cached_byte_size 1029 1030 size = 0 1031 for field_descriptor, field_value in self.ListFields(): 1032 size += field_descriptor._sizer(field_value) 1033 1034 for tag_bytes, value_bytes in self._unknown_fields: 1035 size += len(tag_bytes) + len(value_bytes) 1036 1037 self._cached_byte_size = size 1038 self._cached_byte_size_dirty = False 1039 self._listener_for_children.dirty = False 1040 return size 1041 1042 cls.ByteSize = ByteSize 1043 1044 1045def _AddSerializeToStringMethod(message_descriptor, cls): 1046 """Helper for _AddMessageMethods().""" 1047 1048 def SerializeToString(self): 1049 # Check if the message has all of its required fields set. 
1050 errors = [] 1051 if not self.IsInitialized(): 1052 raise message_mod.EncodeError( 1053 'Message %s is missing required fields: %s' % ( 1054 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) 1055 return self.SerializePartialToString() 1056 cls.SerializeToString = SerializeToString 1057 1058 1059def _AddSerializePartialToStringMethod(message_descriptor, cls): 1060 """Helper for _AddMessageMethods().""" 1061 1062 def SerializePartialToString(self): 1063 out = BytesIO() 1064 self._InternalSerialize(out.write) 1065 return out.getvalue() 1066 cls.SerializePartialToString = SerializePartialToString 1067 1068 def InternalSerialize(self, write_bytes): 1069 for field_descriptor, field_value in self.ListFields(): 1070 field_descriptor._encoder(write_bytes, field_value) 1071 for tag_bytes, value_bytes in self._unknown_fields: 1072 write_bytes(tag_bytes) 1073 write_bytes(value_bytes) 1074 cls._InternalSerialize = InternalSerialize 1075 1076 1077def _AddMergeFromStringMethod(message_descriptor, cls): 1078 """Helper for _AddMessageMethods().""" 1079 def MergeFromString(self, serialized): 1080 length = len(serialized) 1081 try: 1082 if self._InternalParse(serialized, 0, length) != length: 1083 # The only reason _InternalParse would return early is if it 1084 # encountered an end-group tag. 1085 raise message_mod.DecodeError('Unexpected end-group tag.') 1086 except (IndexError, TypeError): 1087 # Now ord(buf[p:p+1]) == ord('') gets TypeError. 1088 raise message_mod.DecodeError('Truncated message.') 1089 except struct.error as e: 1090 raise message_mod.DecodeError(e) 1091 return length # Return this for legacy reasons. 
1092 cls.MergeFromString = MergeFromString 1093 1094 local_ReadTag = decoder.ReadTag 1095 local_SkipField = decoder.SkipField 1096 decoders_by_tag = cls._decoders_by_tag 1097 is_proto3 = message_descriptor.syntax == "proto3" 1098 1099 def InternalParse(self, buffer, pos, end): 1100 self._Modified() 1101 field_dict = self._fields 1102 unknown_field_list = self._unknown_fields 1103 while pos != end: 1104 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) 1105 field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) 1106 if field_decoder is None: 1107 value_start_pos = new_pos 1108 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) 1109 if new_pos == -1: 1110 return pos 1111 if not is_proto3: 1112 if not unknown_field_list: 1113 unknown_field_list = self._unknown_fields = [] 1114 unknown_field_list.append( 1115 (tag_bytes, buffer[value_start_pos:new_pos])) 1116 pos = new_pos 1117 else: 1118 pos = field_decoder(buffer, new_pos, end, self, field_dict) 1119 if field_desc: 1120 self._UpdateOneofState(field_desc) 1121 return pos 1122 cls._InternalParse = InternalParse 1123 1124 1125def _AddIsInitializedMethod(message_descriptor, cls): 1126 """Adds the IsInitialized and FindInitializationError methods to the 1127 protocol message class.""" 1128 1129 required_fields = [field for field in message_descriptor.fields 1130 if field.label == _FieldDescriptor.LABEL_REQUIRED] 1131 1132 def IsInitialized(self, errors=None): 1133 """Checks if all required fields of a message are set. 1134 1135 Args: 1136 errors: A list which, if provided, will be populated with the field 1137 paths of all missing required fields. 1138 1139 Returns: 1140 True iff the specified message has all required fields set. 1141 """ 1142 1143 # Performance is critical so we avoid HasField() and ListFields(). 
1144 1145 for field in required_fields: 1146 if (field not in self._fields or 1147 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and 1148 not self._fields[field]._is_present_in_parent)): 1149 if errors is not None: 1150 errors.extend(self.FindInitializationErrors()) 1151 return False 1152 1153 for field, value in list(self._fields.items()): # dict can change size! 1154 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1155 if field.label == _FieldDescriptor.LABEL_REPEATED: 1156 if (field.message_type.has_options and 1157 field.message_type.GetOptions().map_entry): 1158 continue 1159 for element in value: 1160 if not element.IsInitialized(): 1161 if errors is not None: 1162 errors.extend(self.FindInitializationErrors()) 1163 return False 1164 elif value._is_present_in_parent and not value.IsInitialized(): 1165 if errors is not None: 1166 errors.extend(self.FindInitializationErrors()) 1167 return False 1168 1169 return True 1170 1171 cls.IsInitialized = IsInitialized 1172 1173 def FindInitializationErrors(self): 1174 """Finds required fields which are not initialized. 1175 1176 Returns: 1177 A list of strings. Each string is a path to an uninitialized field from 1178 the top-level message, e.g. "foo.bar[5].baz". 1179 """ 1180 1181 errors = [] # simplify things 1182 1183 for field in required_fields: 1184 if not self.HasField(field.name): 1185 errors.append(field.name) 1186 1187 for field, value in self.ListFields(): 1188 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1189 if field.is_extension: 1190 name = "(%s)" % field.full_name 1191 else: 1192 name = field.name 1193 1194 if _IsMapField(field): 1195 if _IsMessageMapField(field): 1196 for key in value: 1197 element = value[key] 1198 prefix = "%s[%s]." % (name, key) 1199 sub_errors = element.FindInitializationErrors() 1200 errors += [prefix + error for error in sub_errors] 1201 else: 1202 # ScalarMaps can't have any initialization errors. 
1203 pass 1204 elif field.label == _FieldDescriptor.LABEL_REPEATED: 1205 for i in range(len(value)): 1206 element = value[i] 1207 prefix = "%s[%d]." % (name, i) 1208 sub_errors = element.FindInitializationErrors() 1209 errors += [prefix + error for error in sub_errors] 1210 else: 1211 prefix = name + "." 1212 sub_errors = value.FindInitializationErrors() 1213 errors += [prefix + error for error in sub_errors] 1214 1215 return errors 1216 1217 cls.FindInitializationErrors = FindInitializationErrors 1218 1219 1220def _AddMergeFromMethod(cls): 1221 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED 1222 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE 1223 1224 def MergeFrom(self, msg): 1225 if not isinstance(msg, cls): 1226 raise TypeError( 1227 "Parameter to MergeFrom() must be instance of same class: " 1228 "expected %s got %s." % (cls.__name__, type(msg).__name__)) 1229 1230 assert msg is not self 1231 self._Modified() 1232 1233 fields = self._fields 1234 1235 for field, value in msg._fields.items(): 1236 if field.label == LABEL_REPEATED: 1237 field_value = fields.get(field) 1238 if field_value is None: 1239 # Construct a new object to represent this field. 1240 field_value = field._default_constructor(self) 1241 fields[field] = field_value 1242 field_value.MergeFrom(value) 1243 elif field.cpp_type == CPPTYPE_MESSAGE: 1244 if value._is_present_in_parent: 1245 field_value = fields.get(field) 1246 if field_value is None: 1247 # Construct a new object to represent this field. 
1248 field_value = field._default_constructor(self) 1249 fields[field] = field_value 1250 field_value.MergeFrom(value) 1251 else: 1252 self._fields[field] = value 1253 if field.containing_oneof: 1254 self._UpdateOneofState(field) 1255 1256 if msg._unknown_fields: 1257 if not self._unknown_fields: 1258 self._unknown_fields = [] 1259 self._unknown_fields.extend(msg._unknown_fields) 1260 1261 cls.MergeFrom = MergeFrom 1262 1263 1264def _AddWhichOneofMethod(message_descriptor, cls): 1265 def WhichOneof(self, oneof_name): 1266 """Returns the name of the currently set field inside a oneof, or None.""" 1267 try: 1268 field = message_descriptor.oneofs_by_name[oneof_name] 1269 except KeyError: 1270 raise ValueError( 1271 'Protocol message has no oneof "%s" field.' % oneof_name) 1272 1273 nested_field = self._oneofs.get(field, None) 1274 if nested_field is not None and self.HasField(nested_field.name): 1275 return nested_field.name 1276 else: 1277 return None 1278 1279 cls.WhichOneof = WhichOneof 1280 1281 1282def _Clear(self): 1283 # Clear fields. 
1284 self._fields = {} 1285 self._unknown_fields = () 1286 self._oneofs = {} 1287 self._Modified() 1288 1289 1290def _DiscardUnknownFields(self): 1291 self._unknown_fields = [] 1292 for field, value in self.ListFields(): 1293 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1294 if field.label == _FieldDescriptor.LABEL_REPEATED: 1295 for sub_message in value: 1296 sub_message.DiscardUnknownFields() 1297 else: 1298 value.DiscardUnknownFields() 1299 1300 1301def _SetListener(self, listener): 1302 if listener is None: 1303 self._listener = message_listener_mod.NullMessageListener() 1304 else: 1305 self._listener = listener 1306 1307 1308def _AddMessageMethods(message_descriptor, cls): 1309 """Adds implementations of all Message methods to cls.""" 1310 _AddListFieldsMethod(message_descriptor, cls) 1311 _AddHasFieldMethod(message_descriptor, cls) 1312 _AddClearFieldMethod(message_descriptor, cls) 1313 if message_descriptor.is_extendable: 1314 _AddClearExtensionMethod(cls) 1315 _AddHasExtensionMethod(cls) 1316 _AddEqualsMethod(message_descriptor, cls) 1317 _AddStrMethod(message_descriptor, cls) 1318 _AddReprMethod(message_descriptor, cls) 1319 _AddUnicodeMethod(message_descriptor, cls) 1320 _AddByteSizeMethod(message_descriptor, cls) 1321 _AddSerializeToStringMethod(message_descriptor, cls) 1322 _AddSerializePartialToStringMethod(message_descriptor, cls) 1323 _AddMergeFromStringMethod(message_descriptor, cls) 1324 _AddIsInitializedMethod(message_descriptor, cls) 1325 _AddMergeFromMethod(cls) 1326 _AddWhichOneofMethod(message_descriptor, cls) 1327 # Adds methods which do not depend on cls. 
1328 cls.Clear = _Clear 1329 cls.DiscardUnknownFields = _DiscardUnknownFields 1330 cls._SetListener = _SetListener 1331 1332 1333def _AddPrivateHelperMethods(message_descriptor, cls): 1334 """Adds implementation of private helper methods to cls.""" 1335 1336 def Modified(self): 1337 """Sets the _cached_byte_size_dirty bit to true, 1338 and propagates this to our listener iff this was a state change. 1339 """ 1340 1341 # Note: Some callers check _cached_byte_size_dirty before calling 1342 # _Modified() as an extra optimization. So, if this method is ever 1343 # changed such that it does stuff even when _cached_byte_size_dirty is 1344 # already true, the callers need to be updated. 1345 if not self._cached_byte_size_dirty: 1346 self._cached_byte_size_dirty = True 1347 self._listener_for_children.dirty = True 1348 self._is_present_in_parent = True 1349 self._listener.Modified() 1350 1351 def _UpdateOneofState(self, field): 1352 """Sets field as the active field in its containing oneof. 1353 1354 Will also delete currently active field in the oneof, if it is different 1355 from the argument. Does not mark the message as modified. 1356 """ 1357 other_field = self._oneofs.setdefault(field.containing_oneof, field) 1358 if other_field is not field: 1359 del self._fields[other_field] 1360 self._oneofs[field.containing_oneof] = field 1361 1362 cls._Modified = Modified 1363 cls.SetInParent = Modified 1364 cls._UpdateOneofState = _UpdateOneofState 1365 1366 1367class _Listener(object): 1368 1369 """MessageListener implementation that a parent message registers with its 1370 child message. 1371 1372 In order to support semantics like: 1373 1374 foo.bar.baz.qux = 23 1375 assert foo.HasField('bar') 1376 1377 ...child objects must have back references to their parents. 1378 This helper class is at the heart of this support. 
1379 """ 1380 1381 def __init__(self, parent_message): 1382 """Args: 1383 parent_message: The message whose _Modified() method we should call when 1384 we receive Modified() messages. 1385 """ 1386 # This listener establishes a back reference from a child (contained) object 1387 # to its parent (containing) object. We make this a weak reference to avoid 1388 # creating cyclic garbage when the client finishes with the 'parent' object 1389 # in the tree. 1390 if isinstance(parent_message, weakref.ProxyType): 1391 self._parent_message_weakref = parent_message 1392 else: 1393 self._parent_message_weakref = weakref.proxy(parent_message) 1394 1395 # As an optimization, we also indicate directly on the listener whether 1396 # or not the parent message is dirty. This way we can avoid traversing 1397 # up the tree in the common case. 1398 self.dirty = False 1399 1400 def Modified(self): 1401 if self.dirty: 1402 return 1403 try: 1404 # Propagate the signal to our parents iff this is the first field set. 1405 self._parent_message_weakref._Modified() 1406 except ReferenceError: 1407 # We can get here if a client has kept a reference to a child object, 1408 # and is now setting a field on it, but the child's parent has been 1409 # garbage-collected. This is not an error. 1410 pass 1411 1412 1413class _OneofListener(_Listener): 1414 """Special listener implementation for setting composite oneof fields.""" 1415 1416 def __init__(self, parent_message, field): 1417 """Args: 1418 parent_message: The message whose _Modified() method we should call when 1419 we receive Modified() messages. 1420 field: The descriptor of the field being set in the parent message. 
1421 """ 1422 super(_OneofListener, self).__init__(parent_message) 1423 self._field = field 1424 1425 def Modified(self): 1426 """Also updates the state of the containing oneof in the parent message.""" 1427 try: 1428 self._parent_message_weakref._UpdateOneofState(self._field) 1429 super(_OneofListener, self).Modified() 1430 except ReferenceError: 1431 pass 1432 1433 1434# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... 1435# TODO(robinson): Unify error handling of "unknown extension" crap. 1436# TODO(robinson): Support iteritems()-style iteration over all 1437# extensions with the "has" bits turned on? 1438class _ExtensionDict(object): 1439 1440 """Dict-like container for supporting an indexable "Extensions" 1441 field on proto instances. 1442 1443 Note that in all cases we expect extension handles to be 1444 FieldDescriptors. 1445 """ 1446 1447 def __init__(self, extended_message): 1448 """extended_message: Message instance for which we are the Extensions dict. 1449 """ 1450 1451 self._extended_message = extended_message 1452 1453 def __getitem__(self, extension_handle): 1454 """Returns the current value of the given extension handle.""" 1455 1456 _VerifyExtensionHandle(self._extended_message, extension_handle) 1457 1458 result = self._extended_message._fields.get(extension_handle) 1459 if result is not None: 1460 return result 1461 1462 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: 1463 result = extension_handle._default_constructor(self._extended_message) 1464 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1465 result = extension_handle.message_type._concrete_class() 1466 try: 1467 result._SetListener(self._extended_message._listener_for_children) 1468 except ReferenceError: 1469 pass 1470 else: 1471 # Singular scalar -- just return the default without inserting into the 1472 # dict. 
1473 return extension_handle.default_value 1474 1475 # Atomically check if another thread has preempted us and, if not, swap 1476 # in the new object we just created. If someone has preempted us, we 1477 # take that object and discard ours. 1478 # WARNING: We are relying on setdefault() being atomic. This is true 1479 # in CPython but we haven't investigated others. This warning appears 1480 # in several other locations in this file. 1481 result = self._extended_message._fields.setdefault( 1482 extension_handle, result) 1483 1484 return result 1485 1486 def __eq__(self, other): 1487 if not isinstance(other, self.__class__): 1488 return False 1489 1490 my_fields = self._extended_message.ListFields() 1491 other_fields = other._extended_message.ListFields() 1492 1493 # Get rid of non-extension fields. 1494 my_fields = [ field for field in my_fields if field.is_extension ] 1495 other_fields = [ field for field in other_fields if field.is_extension ] 1496 1497 return my_fields == other_fields 1498 1499 def __ne__(self, other): 1500 return not self == other 1501 1502 def __hash__(self): 1503 raise TypeError('unhashable object') 1504 1505 # Note that this is only meaningful for non-repeated, scalar extension 1506 # fields. Note also that we may have to call _Modified() when we do 1507 # successfully set a field this way, to set any necssary "has" bits in the 1508 # ancestors of the extended message. 1509 def __setitem__(self, extension_handle, value): 1510 """If extension_handle specifies a non-repeated, scalar extension 1511 field, sets the value of that field. 1512 """ 1513 1514 _VerifyExtensionHandle(self._extended_message, extension_handle) 1515 1516 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or 1517 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): 1518 raise TypeError( 1519 'Cannot assign to extension "%s" because it is a repeated or ' 1520 'composite type.' 
% extension_handle.full_name) 1521 1522 # It's slightly wasteful to lookup the type checker each time, 1523 # but we expect this to be a vanishingly uncommon case anyway. 1524 type_checker = type_checkers.GetTypeChecker(extension_handle) 1525 # pylint: disable=protected-access 1526 self._extended_message._fields[extension_handle] = ( 1527 type_checker.CheckValue(value)) 1528 self._extended_message._Modified() 1529 1530 def _FindExtensionByName(self, name): 1531 """Tries to find a known extension with the specified name. 1532 1533 Args: 1534 name: Extension full name. 1535 1536 Returns: 1537 Extension field descriptor. 1538 """ 1539 return self._extended_message._extensions_by_name.get(name, None) 1540 1541 def _FindExtensionByNumber(self, number): 1542 """Tries to find a known extension with the field number. 1543 1544 Args: 1545 number: Extension field number. 1546 1547 Returns: 1548 Extension field descriptor. 1549 """ 1550 return self._extended_message._extensions_by_number.get(number, None) 1551