1# Protocol Buffers - Google's data interchange format 2# Copyright 2008 Google Inc. All rights reserved. 3# http://code.google.com/p/protobuf/ 4# 5# Redistribution and use in source and binary forms, with or without 6# modification, are permitted provided that the following conditions are 7# met: 8# 9# * Redistributions of source code must retain the above copyright 10# notice, this list of conditions and the following disclaimer. 11# * Redistributions in binary form must reproduce the above 12# copyright notice, this list of conditions and the following disclaimer 13# in the documentation and/or other materials provided with the 14# distribution. 15# * Neither the name of Google Inc. nor the names of its 16# contributors may be used to endorse or promote products derived from 17# this software without specific prior written permission. 18# 19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 31# This code is meant to work on Python 2.4 and above only. 32# 33# TODO(robinson): Helpers for verbose, common checks like seeing if a 34# descriptor's cpp_type is CPPTYPE_MESSAGE. 35 36"""Contains a metaclass and helper functions used to create 37protocol message classes from Descriptor objects at runtime. 
38 39Recall that a metaclass is the "type" of a class. 40(A class is to a metaclass what an instance is to a class.) 41 42In this case, we use the GeneratedProtocolMessageType metaclass 43to inject all the useful functionality into the classes 44output by the protocol compiler at compile-time. 45 46The upshot of all this is that the real implementation 47details for ALL pure-Python protocol buffers are *here in 48this file*. 49""" 50 51__author__ = 'robinson@google.com (Will Robinson)' 52 53try: 54 from cStringIO import StringIO 55except ImportError: 56 from StringIO import StringIO 57import copy_reg 58import struct 59import weakref 60 61# We use "as" to avoid name collisions with variables. 62from google.protobuf.internal import containers 63from google.protobuf.internal import decoder 64from google.protobuf.internal import encoder 65from google.protobuf.internal import enum_type_wrapper 66from google.protobuf.internal import message_listener as message_listener_mod 67from google.protobuf.internal import type_checkers 68from google.protobuf.internal import wire_format 69from google.protobuf import descriptor as descriptor_mod 70from google.protobuf import message as message_mod 71from google.protobuf import text_format 72 73_FieldDescriptor = descriptor_mod.FieldDescriptor 74 75 76def NewMessage(bases, descriptor, dictionary): 77 _AddClassAttributesForNestedExtensions(descriptor, dictionary) 78 _AddSlots(descriptor, dictionary) 79 return bases 80 81 82def InitMessage(descriptor, cls): 83 cls._decoders_by_tag = {} 84 cls._extensions_by_name = {} 85 cls._extensions_by_number = {} 86 if (descriptor.has_options and 87 descriptor.GetOptions().message_set_wire_format): 88 cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( 89 decoder.MessageSetItemDecoder(cls._extensions_by_number)) 90 91 # Attach stuff to each FieldDescriptor for quick lookup later on. 
92 for field in descriptor.fields: 93 _AttachFieldHelpers(cls, field) 94 95 _AddEnumValues(descriptor, cls) 96 _AddInitMethod(descriptor, cls) 97 _AddPropertiesForFields(descriptor, cls) 98 _AddPropertiesForExtensions(descriptor, cls) 99 _AddStaticMethods(cls) 100 _AddMessageMethods(descriptor, cls) 101 _AddPrivateHelperMethods(cls) 102 copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) 103 104 105# Stateless helpers for GeneratedProtocolMessageType below. 106# Outside clients should not access these directly. 107# 108# I opted not to make any of these methods on the metaclass, to make it more 109# clear that I'm not really using any state there and to keep clients from 110# thinking that they have direct access to these construction helpers. 111 112 113def _PropertyName(proto_field_name): 114 """Returns the name of the public property attribute which 115 clients can use to get and (in some cases) set the value 116 of a protocol message field. 117 118 Args: 119 proto_field_name: The protocol message field name, exactly 120 as it appears (or would appear) in a .proto file. 121 """ 122 # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. 123 # nnorwitz makes my day by writing: 124 # """ 125 # FYI. See the keyword module in the stdlib. This could be as simple as: 126 # 127 # if keyword.iskeyword(proto_field_name): 128 # return proto_field_name + "_" 129 # return proto_field_name 130 # """ 131 # Kenton says: The above is a BAD IDEA. People rely on being able to use 132 # getattr() and setattr() to reflectively manipulate field values. If we 133 # rename the properties, then every such user has to also make sure to apply 134 # the same transformation. Note that currently if you name a field "yield", 135 # you can still access it just fine using getattr/setattr -- it's not even 136 # that cumbersome to do so. 137 # TODO(kenton): Remove this method entirely if/when everyone agrees with my 138 # position. 
139 return proto_field_name 140 141 142def _VerifyExtensionHandle(message, extension_handle): 143 """Verify that the given extension handle is valid.""" 144 145 if not isinstance(extension_handle, _FieldDescriptor): 146 raise KeyError('HasExtension() expects an extension handle, got: %s' % 147 extension_handle) 148 149 if not extension_handle.is_extension: 150 raise KeyError('"%s" is not an extension.' % extension_handle.full_name) 151 152 if not extension_handle.containing_type: 153 raise KeyError('"%s" is missing a containing_type.' 154 % extension_handle.full_name) 155 156 if extension_handle.containing_type is not message.DESCRIPTOR: 157 raise KeyError('Extension "%s" extends message type "%s", but this ' 158 'message is of type "%s".' % 159 (extension_handle.full_name, 160 extension_handle.containing_type.full_name, 161 message.DESCRIPTOR.full_name)) 162 163 164def _AddSlots(message_descriptor, dictionary): 165 """Adds a __slots__ entry to dictionary, containing the names of all valid 166 attributes for this message type. 167 168 Args: 169 message_descriptor: A Descriptor instance describing this message type. 170 dictionary: Class dictionary to which we'll add a '__slots__' entry. 
171 """ 172 dictionary['__slots__'] = ['_cached_byte_size', 173 '_cached_byte_size_dirty', 174 '_fields', 175 '_unknown_fields', 176 '_is_present_in_parent', 177 '_listener', 178 '_listener_for_children', 179 '__weakref__'] 180 181 182def _IsMessageSetExtension(field): 183 return (field.is_extension and 184 field.containing_type.has_options and 185 field.containing_type.GetOptions().message_set_wire_format and 186 field.type == _FieldDescriptor.TYPE_MESSAGE and 187 field.message_type == field.extension_scope and 188 field.label == _FieldDescriptor.LABEL_OPTIONAL) 189 190 191def _AttachFieldHelpers(cls, field_descriptor): 192 is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) 193 is_packed = (field_descriptor.has_options and 194 field_descriptor.GetOptions().packed) 195 196 if _IsMessageSetExtension(field_descriptor): 197 field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) 198 sizer = encoder.MessageSetItemSizer(field_descriptor.number) 199 else: 200 field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( 201 field_descriptor.number, is_repeated, is_packed) 202 sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( 203 field_descriptor.number, is_repeated, is_packed) 204 205 field_descriptor._encoder = field_encoder 206 field_descriptor._sizer = sizer 207 field_descriptor._default_constructor = _DefaultValueConstructorForField( 208 field_descriptor) 209 210 def AddDecoder(wiretype, is_packed): 211 tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) 212 cls._decoders_by_tag[tag_bytes] = ( 213 type_checkers.TYPE_TO_DECODER[field_descriptor.type]( 214 field_descriptor.number, is_repeated, is_packed, 215 field_descriptor, field_descriptor._default_constructor)) 216 217 AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], 218 False) 219 220 if is_repeated and wire_format.IsTypePackable(field_descriptor.type): 221 # To support wire compatibility of adding packed = true, add a 
decoder for 222 # packed values regardless of the field's options. 223 AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) 224 225 226def _AddClassAttributesForNestedExtensions(descriptor, dictionary): 227 extension_dict = descriptor.extensions_by_name 228 for extension_name, extension_field in extension_dict.iteritems(): 229 assert extension_name not in dictionary 230 dictionary[extension_name] = extension_field 231 232 233def _AddEnumValues(descriptor, cls): 234 """Sets class-level attributes for all enum fields defined in this message. 235 236 Also exporting a class-level object that can name enum values. 237 238 Args: 239 descriptor: Descriptor object for this message type. 240 cls: Class we're constructing for this message type. 241 """ 242 for enum_type in descriptor.enum_types: 243 setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) 244 for enum_value in enum_type.values: 245 setattr(cls, enum_value.name, enum_value.number) 246 247 248def _DefaultValueConstructorForField(field): 249 """Returns a function which returns a default value for a field. 250 251 Args: 252 field: FieldDescriptor object for this field. 253 254 The returned function has one argument: 255 message: Message instance containing this field, or a weakref proxy 256 of same. 257 258 That function in turn returns a default value for this field. The default 259 value may refer back to |message| via a weak reference. 260 """ 261 262 if field.label == _FieldDescriptor.LABEL_REPEATED: 263 if field.has_default_value and field.default_value != []: 264 raise ValueError('Repeated field default value not empty list: %s' % ( 265 field.default_value)) 266 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 267 # We can't look at _concrete_class yet since it might not have 268 # been set. (Depends on order in which we initialize the classes). 
269 message_type = field.message_type 270 def MakeRepeatedMessageDefault(message): 271 return containers.RepeatedCompositeFieldContainer( 272 message._listener_for_children, field.message_type) 273 return MakeRepeatedMessageDefault 274 else: 275 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) 276 def MakeRepeatedScalarDefault(message): 277 return containers.RepeatedScalarFieldContainer( 278 message._listener_for_children, type_checker) 279 return MakeRepeatedScalarDefault 280 281 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 282 # _concrete_class may not yet be initialized. 283 message_type = field.message_type 284 def MakeSubMessageDefault(message): 285 result = message_type._concrete_class() 286 result._SetListener(message._listener_for_children) 287 return result 288 return MakeSubMessageDefault 289 290 def MakeScalarDefault(message): 291 # TODO(protobuf-team): This may be broken since there may not be 292 # default_value. Combine with has_default_value somehow. 293 return field.default_value 294 return MakeScalarDefault 295 296 297def _AddInitMethod(message_descriptor, cls): 298 """Adds an __init__ method to cls.""" 299 fields = message_descriptor.fields 300 def init(self, **kwargs): 301 self._cached_byte_size = 0 302 self._cached_byte_size_dirty = len(kwargs) > 0 303 self._fields = {} 304 # _unknown_fields is () when empty for efficiency, and will be turned into 305 # a list if fields are added. 
306 self._unknown_fields = () 307 self._is_present_in_parent = False 308 self._listener = message_listener_mod.NullMessageListener() 309 self._listener_for_children = _Listener(self) 310 for field_name, field_value in kwargs.iteritems(): 311 field = _GetFieldByName(message_descriptor, field_name) 312 if field is None: 313 raise TypeError("%s() got an unexpected keyword argument '%s'" % 314 (message_descriptor.name, field_name)) 315 if field.label == _FieldDescriptor.LABEL_REPEATED: 316 copy = field._default_constructor(self) 317 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite 318 for val in field_value: 319 copy.add().MergeFrom(val) 320 else: # Scalar 321 copy.extend(field_value) 322 self._fields[field] = copy 323 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 324 copy = field._default_constructor(self) 325 copy.MergeFrom(field_value) 326 self._fields[field] = copy 327 else: 328 setattr(self, field_name, field_value) 329 330 init.__module__ = None 331 init.__doc__ = None 332 cls.__init__ = init 333 334 335def _GetFieldByName(message_descriptor, field_name): 336 """Returns a field descriptor by field name. 337 338 Args: 339 message_descriptor: A Descriptor describing all fields in message. 340 field_name: The name of the field to retrieve. 341 Returns: 342 The field descriptor associated with the field name. 343 """ 344 try: 345 return message_descriptor.fields_by_name[field_name] 346 except KeyError: 347 raise ValueError('Protocol message has no "%s" field.' % field_name) 348 349 350def _AddPropertiesForFields(descriptor, cls): 351 """Adds properties for all fields in this protocol message type.""" 352 for field in descriptor.fields: 353 _AddPropertiesForField(field, cls) 354 355 if descriptor.is_extendable: 356 # _ExtensionDict is just an adaptor with no state so we allocate a new one 357 # every time it is accessed. 
358 cls.Extensions = property(lambda self: _ExtensionDict(self)) 359 360 361def _AddPropertiesForField(field, cls): 362 """Adds a public property for a protocol message field. 363 Clients can use this property to get and (in the case 364 of non-repeated scalar fields) directly set the value 365 of a protocol message field. 366 367 Args: 368 field: A FieldDescriptor for this field. 369 cls: The class we're constructing. 370 """ 371 # Catch it if we add other types that we should 372 # handle specially here. 373 assert _FieldDescriptor.MAX_CPPTYPE == 10 374 375 constant_name = field.name.upper() + "_FIELD_NUMBER" 376 setattr(cls, constant_name, field.number) 377 378 if field.label == _FieldDescriptor.LABEL_REPEATED: 379 _AddPropertiesForRepeatedField(field, cls) 380 elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 381 _AddPropertiesForNonRepeatedCompositeField(field, cls) 382 else: 383 _AddPropertiesForNonRepeatedScalarField(field, cls) 384 385 386def _AddPropertiesForRepeatedField(field, cls): 387 """Adds a public property for a "repeated" protocol message field. Clients 388 can use this property to get the value of the field, which will be either a 389 _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see 390 below). 391 392 Note that when clients add values to these containers, we perform 393 type-checking in the case of repeated scalar fields, and we also set any 394 necessary "has" bits as a side-effect. 395 396 Args: 397 field: A FieldDescriptor for this field. 398 cls: The class we're constructing. 399 """ 400 proto_field_name = field.name 401 property_name = _PropertyName(proto_field_name) 402 403 def getter(self): 404 field_value = self._fields.get(field) 405 if field_value is None: 406 # Construct a new object to represent this field. 407 field_value = field._default_constructor(self) 408 409 # Atomically check if another thread has preempted us and, if not, swap 410 # in the new object we just created. 
If someone has preempted us, we 411 # take that object and discard ours. 412 # WARNING: We are relying on setdefault() being atomic. This is true 413 # in CPython but we haven't investigated others. This warning appears 414 # in several other locations in this file. 415 field_value = self._fields.setdefault(field, field_value) 416 return field_value 417 getter.__module__ = None 418 getter.__doc__ = 'Getter for %s.' % proto_field_name 419 420 # We define a setter just so we can throw an exception with a more 421 # helpful error message. 422 def setter(self, new_value): 423 raise AttributeError('Assignment not allowed to repeated field ' 424 '"%s" in protocol message object.' % proto_field_name) 425 426 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name 427 setattr(cls, property_name, property(getter, setter, doc=doc)) 428 429 430def _AddPropertiesForNonRepeatedScalarField(field, cls): 431 """Adds a public property for a nonrepeated, scalar protocol message field. 432 Clients can use this property to get and directly set the value of the field. 433 Note that when the client sets the value of a field by using this property, 434 all necessary "has" bits are set as a side-effect, and we also perform 435 type-checking. 436 437 Args: 438 field: A FieldDescriptor for this field. 439 cls: The class we're constructing. 440 """ 441 proto_field_name = field.name 442 property_name = _PropertyName(proto_field_name) 443 type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) 444 default_value = field.default_value 445 valid_values = set() 446 447 def getter(self): 448 # TODO(protobuf-team): This may be broken since there may not be 449 # default_value. Combine with has_default_value somehow. 450 return self._fields.get(field, default_value) 451 getter.__module__ = None 452 getter.__doc__ = 'Getter for %s.' 
% proto_field_name 453 def setter(self, new_value): 454 type_checker.CheckValue(new_value) 455 self._fields[field] = new_value 456 # Check _cached_byte_size_dirty inline to improve performance, since scalar 457 # setters are called frequently. 458 if not self._cached_byte_size_dirty: 459 self._Modified() 460 461 setter.__module__ = None 462 setter.__doc__ = 'Setter for %s.' % proto_field_name 463 464 # Add a property to encapsulate the getter/setter. 465 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name 466 setattr(cls, property_name, property(getter, setter, doc=doc)) 467 468 469def _AddPropertiesForNonRepeatedCompositeField(field, cls): 470 """Adds a public property for a nonrepeated, composite protocol message field. 471 A composite field is a "group" or "message" field. 472 473 Clients can use this property to get the value of the field, but cannot 474 assign to the property directly. 475 476 Args: 477 field: A FieldDescriptor for this field. 478 cls: The class we're constructing. 479 """ 480 # TODO(robinson): Remove duplication with similar method 481 # for non-repeated scalars. 482 proto_field_name = field.name 483 property_name = _PropertyName(proto_field_name) 484 485 # TODO(komarek): Can anyone explain to me why we cache the message_type this 486 # way, instead of referring to field.message_type inside of getter(self)? 487 # What if someone sets message_type later on (which makes for simpler 488 # dyanmic proto descriptor and class creation code). 489 message_type = field.message_type 490 491 def getter(self): 492 field_value = self._fields.get(field) 493 if field_value is None: 494 # Construct a new object to represent this field. 495 field_value = message_type._concrete_class() # use field.message_type? 496 field_value._SetListener(self._listener_for_children) 497 498 # Atomically check if another thread has preempted us and, if not, swap 499 # in the new object we just created. 
If someone has preempted us, we 500 # take that object and discard ours. 501 # WARNING: We are relying on setdefault() being atomic. This is true 502 # in CPython but we haven't investigated others. This warning appears 503 # in several other locations in this file. 504 field_value = self._fields.setdefault(field, field_value) 505 return field_value 506 getter.__module__ = None 507 getter.__doc__ = 'Getter for %s.' % proto_field_name 508 509 # We define a setter just so we can throw an exception with a more 510 # helpful error message. 511 def setter(self, new_value): 512 raise AttributeError('Assignment not allowed to composite field ' 513 '"%s" in protocol message object.' % proto_field_name) 514 515 # Add a property to encapsulate the getter. 516 doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name 517 setattr(cls, property_name, property(getter, setter, doc=doc)) 518 519 520def _AddPropertiesForExtensions(descriptor, cls): 521 """Adds properties for all fields in this protocol message type.""" 522 extension_dict = descriptor.extensions_by_name 523 for extension_name, extension_field in extension_dict.iteritems(): 524 constant_name = extension_name.upper() + "_FIELD_NUMBER" 525 setattr(cls, constant_name, extension_field.number) 526 527 528def _AddStaticMethods(cls): 529 # TODO(robinson): This probably needs to be thread-safe(?) 530 def RegisterExtension(extension_handle): 531 extension_handle.containing_type = cls.DESCRIPTOR 532 _AttachFieldHelpers(cls, extension_handle) 533 534 # Try to insert our extension, failing if an extension with the same number 535 # already exists. 536 actual_handle = cls._extensions_by_number.setdefault( 537 extension_handle.number, extension_handle) 538 if actual_handle is not extension_handle: 539 raise AssertionError( 540 'Extensions "%s" and "%s" both try to extend message type "%s" with ' 541 'field number %d.' 
% 542 (extension_handle.full_name, actual_handle.full_name, 543 cls.DESCRIPTOR.full_name, extension_handle.number)) 544 545 cls._extensions_by_name[extension_handle.full_name] = extension_handle 546 547 handle = extension_handle # avoid line wrapping 548 if _IsMessageSetExtension(handle): 549 # MessageSet extension. Also register under type name. 550 cls._extensions_by_name[ 551 extension_handle.message_type.full_name] = extension_handle 552 553 cls.RegisterExtension = staticmethod(RegisterExtension) 554 555 def FromString(s): 556 message = cls() 557 message.MergeFromString(s) 558 return message 559 cls.FromString = staticmethod(FromString) 560 561 562def _IsPresent(item): 563 """Given a (FieldDescriptor, value) tuple from _fields, return true if the 564 value should be included in the list returned by ListFields().""" 565 566 if item[0].label == _FieldDescriptor.LABEL_REPEATED: 567 return bool(item[1]) 568 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 569 return item[1]._is_present_in_parent 570 else: 571 return True 572 573 574def _AddListFieldsMethod(message_descriptor, cls): 575 """Helper for _AddMessageMethods().""" 576 577 def ListFields(self): 578 all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] 579 all_fields.sort(key = lambda item: item[0].number) 580 return all_fields 581 582 cls.ListFields = ListFields 583 584 585def _AddHasFieldMethod(message_descriptor, cls): 586 """Helper for _AddMessageMethods().""" 587 588 singular_fields = {} 589 for field in message_descriptor.fields: 590 if field.label != _FieldDescriptor.LABEL_REPEATED: 591 singular_fields[field.name] = field 592 593 def HasField(self, field_name): 594 try: 595 field = singular_fields[field_name] 596 except KeyError: 597 raise ValueError( 598 'Protocol message has no singular "%s" field.' 
% field_name) 599 600 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 601 value = self._fields.get(field) 602 return value is not None and value._is_present_in_parent 603 else: 604 return field in self._fields 605 cls.HasField = HasField 606 607 608def _AddClearFieldMethod(message_descriptor, cls): 609 """Helper for _AddMessageMethods().""" 610 def ClearField(self, field_name): 611 try: 612 field = message_descriptor.fields_by_name[field_name] 613 except KeyError: 614 raise ValueError('Protocol message has no "%s" field.' % field_name) 615 616 if field in self._fields: 617 # Note: If the field is a sub-message, its listener will still point 618 # at us. That's fine, because the worst than can happen is that it 619 # will call _Modified() and invalidate our byte size. Big deal. 620 del self._fields[field] 621 622 # Always call _Modified() -- even if nothing was changed, this is 623 # a mutating method, and thus calling it should cause the field to become 624 # present in the parent message. 625 self._Modified() 626 627 cls.ClearField = ClearField 628 629 630def _AddClearExtensionMethod(cls): 631 """Helper for _AddMessageMethods().""" 632 def ClearExtension(self, extension_handle): 633 _VerifyExtensionHandle(self, extension_handle) 634 635 # Similar to ClearField(), above. 636 if extension_handle in self._fields: 637 del self._fields[extension_handle] 638 self._Modified() 639 cls.ClearExtension = ClearExtension 640 641 642def _AddClearMethod(message_descriptor, cls): 643 """Helper for _AddMessageMethods().""" 644 def Clear(self): 645 # Clear fields. 646 self._fields = {} 647 self._unknown_fields = () 648 self._Modified() 649 cls.Clear = Clear 650 651 652def _AddHasExtensionMethod(cls): 653 """Helper for _AddMessageMethods().""" 654 def HasExtension(self, extension_handle): 655 _VerifyExtensionHandle(self, extension_handle) 656 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: 657 raise KeyError('"%s" is repeated.' 
% extension_handle.full_name) 658 659 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 660 value = self._fields.get(extension_handle) 661 return value is not None and value._is_present_in_parent 662 else: 663 return extension_handle in self._fields 664 cls.HasExtension = HasExtension 665 666 667def _AddEqualsMethod(message_descriptor, cls): 668 """Helper for _AddMessageMethods().""" 669 def __eq__(self, other): 670 if (not isinstance(other, message_mod.Message) or 671 other.DESCRIPTOR != self.DESCRIPTOR): 672 return False 673 674 if self is other: 675 return True 676 677 if not self.ListFields() == other.ListFields(): 678 return False 679 680 # Sort unknown fields because their order shouldn't affect equality test. 681 unknown_fields = list(self._unknown_fields) 682 unknown_fields.sort() 683 other_unknown_fields = list(other._unknown_fields) 684 other_unknown_fields.sort() 685 686 return unknown_fields == other_unknown_fields 687 688 cls.__eq__ = __eq__ 689 690 691def _AddStrMethod(message_descriptor, cls): 692 """Helper for _AddMessageMethods().""" 693 def __str__(self): 694 return text_format.MessageToString(self) 695 cls.__str__ = __str__ 696 697 698def _AddUnicodeMethod(unused_message_descriptor, cls): 699 """Helper for _AddMessageMethods().""" 700 701 def __unicode__(self): 702 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') 703 cls.__unicode__ = __unicode__ 704 705 706def _AddSetListenerMethod(cls): 707 """Helper for _AddMessageMethods().""" 708 def SetListener(self, listener): 709 if listener is None: 710 self._listener = message_listener_mod.NullMessageListener() 711 else: 712 self._listener = listener 713 cls._SetListener = SetListener 714 715 716def _BytesForNonRepeatedElement(value, field_number, field_type): 717 """Returns the number of bytes needed to serialize a non-repeated element. 
718 The returned byte count includes space for tag information and any 719 other additional space associated with serializing value. 720 721 Args: 722 value: Value we're serializing. 723 field_number: Field number of this value. (Since the field number 724 is stored as part of a varint-encoded tag, this has an impact 725 on the total bytes required to serialize the value). 726 field_type: The type of the field. One of the TYPE_* constants 727 within FieldDescriptor. 728 """ 729 try: 730 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] 731 return fn(field_number, value) 732 except KeyError: 733 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) 734 735 736def _AddByteSizeMethod(message_descriptor, cls): 737 """Helper for _AddMessageMethods().""" 738 739 def ByteSize(self): 740 if not self._cached_byte_size_dirty: 741 return self._cached_byte_size 742 743 size = 0 744 for field_descriptor, field_value in self.ListFields(): 745 size += field_descriptor._sizer(field_value) 746 747 for tag_bytes, value_bytes in self._unknown_fields: 748 size += len(tag_bytes) + len(value_bytes) 749 750 self._cached_byte_size = size 751 self._cached_byte_size_dirty = False 752 self._listener_for_children.dirty = False 753 return size 754 755 cls.ByteSize = ByteSize 756 757 758def _AddSerializeToStringMethod(message_descriptor, cls): 759 """Helper for _AddMessageMethods().""" 760 761 def SerializeToString(self): 762 # Check if the message has all of its required fields set. 
763 errors = [] 764 if not self.IsInitialized(): 765 raise message_mod.EncodeError( 766 'Message %s is missing required fields: %s' % ( 767 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) 768 return self.SerializePartialToString() 769 cls.SerializeToString = SerializeToString 770 771 772def _AddSerializePartialToStringMethod(message_descriptor, cls): 773 """Helper for _AddMessageMethods().""" 774 775 def SerializePartialToString(self): 776 out = StringIO() 777 self._InternalSerialize(out.write) 778 return out.getvalue() 779 cls.SerializePartialToString = SerializePartialToString 780 781 def InternalSerialize(self, write_bytes): 782 for field_descriptor, field_value in self.ListFields(): 783 field_descriptor._encoder(write_bytes, field_value) 784 for tag_bytes, value_bytes in self._unknown_fields: 785 write_bytes(tag_bytes) 786 write_bytes(value_bytes) 787 cls._InternalSerialize = InternalSerialize 788 789 790def _AddMergeFromStringMethod(message_descriptor, cls): 791 """Helper for _AddMessageMethods().""" 792 def MergeFromString(self, serialized): 793 length = len(serialized) 794 try: 795 if self._InternalParse(serialized, 0, length) != length: 796 # The only reason _InternalParse would return early is if it 797 # encountered an end-group tag. 798 raise message_mod.DecodeError('Unexpected end-group tag.') 799 except IndexError: 800 raise message_mod.DecodeError('Truncated message.') 801 except struct.error, e: 802 raise message_mod.DecodeError(e) 803 return length # Return this for legacy reasons. 
804 cls.MergeFromString = MergeFromString 805 806 local_ReadTag = decoder.ReadTag 807 local_SkipField = decoder.SkipField 808 decoders_by_tag = cls._decoders_by_tag 809 810 def InternalParse(self, buffer, pos, end): 811 self._Modified() 812 field_dict = self._fields 813 unknown_field_list = self._unknown_fields 814 while pos != end: 815 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) 816 field_decoder = decoders_by_tag.get(tag_bytes) 817 if field_decoder is None: 818 value_start_pos = new_pos 819 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) 820 if new_pos == -1: 821 return pos 822 if not unknown_field_list: 823 unknown_field_list = self._unknown_fields = [] 824 unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) 825 pos = new_pos 826 else: 827 pos = field_decoder(buffer, new_pos, end, self, field_dict) 828 return pos 829 cls._InternalParse = InternalParse 830 831 832def _AddIsInitializedMethod(message_descriptor, cls): 833 """Adds the IsInitialized and FindInitializationError methods to the 834 protocol message class.""" 835 836 required_fields = [field for field in message_descriptor.fields 837 if field.label == _FieldDescriptor.LABEL_REQUIRED] 838 839 def IsInitialized(self, errors=None): 840 """Checks if all required fields of a message are set. 841 842 Args: 843 errors: A list which, if provided, will be populated with the field 844 paths of all missing required fields. 845 846 Returns: 847 True iff the specified message has all required fields set. 848 """ 849 850 # Performance is critical so we avoid HasField() and ListFields(). 
851 852 for field in required_fields: 853 if (field not in self._fields or 854 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and 855 not self._fields[field]._is_present_in_parent)): 856 if errors is not None: 857 errors.extend(self.FindInitializationErrors()) 858 return False 859 860 for field, value in self._fields.iteritems(): 861 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 862 if field.label == _FieldDescriptor.LABEL_REPEATED: 863 for element in value: 864 if not element.IsInitialized(): 865 if errors is not None: 866 errors.extend(self.FindInitializationErrors()) 867 return False 868 elif value._is_present_in_parent and not value.IsInitialized(): 869 if errors is not None: 870 errors.extend(self.FindInitializationErrors()) 871 return False 872 873 return True 874 875 cls.IsInitialized = IsInitialized 876 877 def FindInitializationErrors(self): 878 """Finds required fields which are not initialized. 879 880 Returns: 881 A list of strings. Each string is a path to an uninitialized field from 882 the top-level message, e.g. "foo.bar[5].baz". 883 """ 884 885 errors = [] # simplify things 886 887 for field in required_fields: 888 if not self.HasField(field.name): 889 errors.append(field.name) 890 891 for field, value in self.ListFields(): 892 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 893 if field.is_extension: 894 name = "(%s)" % field.full_name 895 else: 896 name = field.name 897 898 if field.label == _FieldDescriptor.LABEL_REPEATED: 899 for i in xrange(len(value)): 900 element = value[i] 901 prefix = "%s[%d]." % (name, i) 902 sub_errors = element.FindInitializationErrors() 903 errors += [ prefix + error for error in sub_errors ] 904 else: 905 prefix = name + "." 
906 sub_errors = value.FindInitializationErrors() 907 errors += [ prefix + error for error in sub_errors ] 908 909 return errors 910 911 cls.FindInitializationErrors = FindInitializationErrors 912 913 914def _AddMergeFromMethod(cls): 915 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED 916 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE 917 918 def MergeFrom(self, msg): 919 if not isinstance(msg, cls): 920 raise TypeError( 921 "Parameter to MergeFrom() must be instance of same class: " 922 "expected %s got %s." % (cls.__name__, type(msg).__name__)) 923 924 assert msg is not self 925 self._Modified() 926 927 fields = self._fields 928 929 for field, value in msg._fields.iteritems(): 930 if field.label == LABEL_REPEATED: 931 field_value = fields.get(field) 932 if field_value is None: 933 # Construct a new object to represent this field. 934 field_value = field._default_constructor(self) 935 fields[field] = field_value 936 field_value.MergeFrom(value) 937 elif field.cpp_type == CPPTYPE_MESSAGE: 938 if value._is_present_in_parent: 939 field_value = fields.get(field) 940 if field_value is None: 941 # Construct a new object to represent this field. 
942 field_value = field._default_constructor(self) 943 fields[field] = field_value 944 field_value.MergeFrom(value) 945 else: 946 self._fields[field] = value 947 948 if msg._unknown_fields: 949 if not self._unknown_fields: 950 self._unknown_fields = [] 951 self._unknown_fields.extend(msg._unknown_fields) 952 953 cls.MergeFrom = MergeFrom 954 955 956def _AddMessageMethods(message_descriptor, cls): 957 """Adds implementations of all Message methods to cls.""" 958 _AddListFieldsMethod(message_descriptor, cls) 959 _AddHasFieldMethod(message_descriptor, cls) 960 _AddClearFieldMethod(message_descriptor, cls) 961 if message_descriptor.is_extendable: 962 _AddClearExtensionMethod(cls) 963 _AddHasExtensionMethod(cls) 964 _AddClearMethod(message_descriptor, cls) 965 _AddEqualsMethod(message_descriptor, cls) 966 _AddStrMethod(message_descriptor, cls) 967 _AddUnicodeMethod(message_descriptor, cls) 968 _AddSetListenerMethod(cls) 969 _AddByteSizeMethod(message_descriptor, cls) 970 _AddSerializeToStringMethod(message_descriptor, cls) 971 _AddSerializePartialToStringMethod(message_descriptor, cls) 972 _AddMergeFromStringMethod(message_descriptor, cls) 973 _AddIsInitializedMethod(message_descriptor, cls) 974 _AddMergeFromMethod(cls) 975 976 977def _AddPrivateHelperMethods(cls): 978 """Adds implementation of private helper methods to cls.""" 979 980 def Modified(self): 981 """Sets the _cached_byte_size_dirty bit to true, 982 and propagates this to our listener iff this was a state change. 983 """ 984 985 # Note: Some callers check _cached_byte_size_dirty before calling 986 # _Modified() as an extra optimization. So, if this method is ever 987 # changed such that it does stuff even when _cached_byte_size_dirty is 988 # already true, the callers need to be updated. 
989 if not self._cached_byte_size_dirty: 990 self._cached_byte_size_dirty = True 991 self._listener_for_children.dirty = True 992 self._is_present_in_parent = True 993 self._listener.Modified() 994 995 cls._Modified = Modified 996 cls.SetInParent = Modified 997 998 999class _Listener(object): 1000 1001 """MessageListener implementation that a parent message registers with its 1002 child message. 1003 1004 In order to support semantics like: 1005 1006 foo.bar.baz.qux = 23 1007 assert foo.HasField('bar') 1008 1009 ...child objects must have back references to their parents. 1010 This helper class is at the heart of this support. 1011 """ 1012 1013 def __init__(self, parent_message): 1014 """Args: 1015 parent_message: The message whose _Modified() method we should call when 1016 we receive Modified() messages. 1017 """ 1018 # This listener establishes a back reference from a child (contained) object 1019 # to its parent (containing) object. We make this a weak reference to avoid 1020 # creating cyclic garbage when the client finishes with the 'parent' object 1021 # in the tree. 1022 if isinstance(parent_message, weakref.ProxyType): 1023 self._parent_message_weakref = parent_message 1024 else: 1025 self._parent_message_weakref = weakref.proxy(parent_message) 1026 1027 # As an optimization, we also indicate directly on the listener whether 1028 # or not the parent message is dirty. This way we can avoid traversing 1029 # up the tree in the common case. 1030 self.dirty = False 1031 1032 def Modified(self): 1033 if self.dirty: 1034 return 1035 try: 1036 # Propagate the signal to our parents iff this is the first field set. 1037 self._parent_message_weakref._Modified() 1038 except ReferenceError: 1039 # We can get here if a client has kept a reference to a child object, 1040 # and is now setting a field on it, but the child's parent has been 1041 # garbage-collected. This is not an error. 1042 pass 1043 1044 1045# TODO(robinson): Move elsewhere? 
# This file is getting pretty ridiculous...
# TODO(robinson): Unify error handling of "unknown extension" crap.
# TODO(robinson): Support iteritems()-style iteration over all
# extensions with the "has" bits turned on?
class _ExtensionDict(object):

  """Dict-like container for supporting an indexable "Extensions"
  field on proto instances.

  Note that in all cases we expect extension handles to be
  FieldDescriptors.
  """

  def __init__(self, extended_message):
    """extended_message: Message instance for which we are the Extensions dict.
    """

    self._extended_message = extended_message

  def __getitem__(self, extension_handle):
    """Returns the current value of the given extension handle.

    Repeated and message-typed extensions are created on first access and
    stored in the extended message; singular scalars return their default
    without being stored.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    result = self._extended_message._fields.get(extension_handle)
    if result is not None:
      return result

    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      result = extension_handle._default_constructor(self._extended_message)
    elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      result = extension_handle.message_type._concrete_class()
      try:
        result._SetListener(self._extended_message._listener_for_children)
      except ReferenceError:
        pass
    else:
      # Singular scalar -- just return the default without inserting into the
      # dict.
      return extension_handle.default_value

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    result = self._extended_message._fields.setdefault(
        extension_handle, result)

    return result

  def __eq__(self, other):
    """Two Extensions dicts are equal iff their set extension fields match.

    Non-extension fields of the extended messages are ignored.
    """
    if not isinstance(other, self.__class__):
      return False

    # ListFields() yields (field_descriptor, value) pairs, so the filter must
    # test the descriptor; the pair itself has no is_extension attribute.
    my_fields = [(field, value)
                 for field, value in self._extended_message.ListFields()
                 if field.is_extension]
    other_fields = [(field, value)
                    for field, value in other._extended_message.ListFields()
                    if field.is_extension]

    return my_fields == other_fields

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    raise TypeError('unhashable object')

  # Note that this is only meaningful for non-repeated, scalar extension
  # fields.  Note also that we may have to call _Modified() when we do
  # successfully set a field this way, to set any necessary "has" bits in the
  # ancestors of the extended message.
  def __setitem__(self, extension_handle, value):
    """If extension_handle specifies a non-repeated, scalar extension
    field, sets the value of that field.

    Raises:
      TypeError: if the extension is repeated or message-typed, or (via the
        type checker) if value has the wrong type.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
        extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
      raise TypeError(
          'Cannot assign to extension "%s" because it is a repeated or '
          'composite type.' % extension_handle.full_name)

    # It's slightly wasteful to lookup the type checker each time,
    # but we expect this to be a vanishingly uncommon case anyway.
    type_checker = type_checkers.GetTypeChecker(
        extension_handle.cpp_type, extension_handle.type)
    type_checker.CheckValue(value)
    self._extended_message._fields[extension_handle] = value
    self._extended_message._Modified()

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if no such extension is known.
    """
    return self._extended_message._extensions_by_name.get(name, None)