1#! /usr/bin/env python 2# -*- coding: utf-8 -*- 3# 4# Protocol Buffers - Google's data interchange format 5# Copyright 2008 Google Inc. All rights reserved. 6# https://developers.google.com/protocol-buffers/ 7# 8# Redistribution and use in source and binary forms, with or without 9# modification, are permitted provided that the following conditions are 10# met: 11# 12# * Redistributions of source code must retain the above copyright 13# notice, this list of conditions and the following disclaimer. 14# * Redistributions in binary form must reproduce the above 15# copyright notice, this list of conditions and the following disclaimer 16# in the documentation and/or other materials provided with the 17# distribution. 18# * Neither the name of Google Inc. nor the names of its 19# contributors may be used to endorse or promote products derived from 20# this software without specific prior written permission. 21# 22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 34"""Unittest for reflection.py, which also indirectly tests the output of the 35pure-Python protocol compiler. 
36""" 37 38import copy 39import gc 40import operator 41import six 42import struct 43 44try: 45 import unittest2 as unittest #PY26 46except ImportError: 47 import unittest 48 49from google.protobuf import unittest_import_pb2 50from google.protobuf import unittest_mset_pb2 51from google.protobuf import unittest_pb2 52from google.protobuf import descriptor_pb2 53from google.protobuf import descriptor 54from google.protobuf import message 55from google.protobuf import reflection 56from google.protobuf import text_format 57from google.protobuf.internal import api_implementation 58from google.protobuf.internal import more_extensions_pb2 59from google.protobuf.internal import more_messages_pb2 60from google.protobuf.internal import message_set_extensions_pb2 61from google.protobuf.internal import wire_format 62from google.protobuf.internal import test_util 63from google.protobuf.internal import decoder 64 65 66class _MiniDecoder(object): 67 """Decodes a stream of values from a string. 68 69 Once upon a time we actually had a class called decoder.Decoder. Then we 70 got rid of it during a redesign that made decoding much, much faster overall. 71 But a couple tests in this file used it to check that the serialized form of 72 a message was correct. So, this class implements just the methods that were 73 used by said tests, so that we don't have to rewrite the tests. 
74 """ 75 76 def __init__(self, bytes): 77 self._bytes = bytes 78 self._pos = 0 79 80 def ReadVarint(self): 81 result, self._pos = decoder._DecodeVarint(self._bytes, self._pos) 82 return result 83 84 ReadInt32 = ReadVarint 85 ReadInt64 = ReadVarint 86 ReadUInt32 = ReadVarint 87 ReadUInt64 = ReadVarint 88 89 def ReadSInt64(self): 90 return wire_format.ZigZagDecode(self.ReadVarint()) 91 92 ReadSInt32 = ReadSInt64 93 94 def ReadFieldNumberAndWireType(self): 95 return wire_format.UnpackTag(self.ReadVarint()) 96 97 def ReadFloat(self): 98 result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0] 99 self._pos += 4 100 return result 101 102 def ReadDouble(self): 103 result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0] 104 self._pos += 8 105 return result 106 107 def EndOfStream(self): 108 return self._pos == len(self._bytes) 109 110 111class ReflectionTest(unittest.TestCase): 112 113 def assertListsEqual(self, values, others): 114 self.assertEqual(len(values), len(others)) 115 for i in range(len(values)): 116 self.assertEqual(values[i], others[i]) 117 118 def testScalarConstructor(self): 119 # Constructor with only scalar types should succeed. 120 proto = unittest_pb2.TestAllTypes( 121 optional_int32=24, 122 optional_double=54.321, 123 optional_string='optional_string', 124 optional_float=None) 125 126 self.assertEqual(24, proto.optional_int32) 127 self.assertEqual(54.321, proto.optional_double) 128 self.assertEqual('optional_string', proto.optional_string) 129 self.assertFalse(proto.HasField("optional_float")) 130 131 def testRepeatedScalarConstructor(self): 132 # Constructor with only repeated scalar types should succeed. 
def testMixedConstructor(self):
  """Builds TestAllTypes from scalar, repeated and message kwargs at once.

  Passing None for optional_nested_message must leave that field unset.
  """
  all_types = unittest_pb2.TestAllTypes

  # Factories rather than shared lists, so the constructor arguments and the
  # expected values are distinct message instances.
  def NestedMessages():
    return [all_types.NestedMessage(bb=all_types.FOO),
            all_types.NestedMessage(bb=all_types.BAR)]

  def ForeignMessages():
    return [unittest_pb2.ForeignMessage(c=-43),
            unittest_pb2.ForeignMessage(c=45324),
            unittest_pb2.ForeignMessage(c=12)]

  proto = all_types(
      optional_int32=24,
      optional_string='optional_string',
      repeated_double=[1.23, 54.321],
      repeated_bool=[True, False, False],
      repeated_nested_message=NestedMessages(),
      repeated_foreign_message=ForeignMessages(),
      optional_nested_message=None)

  self.assertEqual(24, proto.optional_int32)
  self.assertEqual('optional_string', proto.optional_string)
  self.assertEqual([1.23, 54.321], list(proto.repeated_double))
  self.assertEqual([True, False, False], list(proto.repeated_bool))
  self.assertEqual(NestedMessages(), list(proto.repeated_nested_message))
  self.assertEqual(ForeignMessages(), list(proto.repeated_foreign_message))
  self.assertFalse(proto.HasField("optional_nested_message"))
def testHasBitsWithSinglyNestedScalar(self):
  """Checks "has" bits of a singular composite field and a scalar inside it.

  For each (composite field, scalar field) pair listed at the bottom, the
  nested helper verifies, in order, that:
    * reading the scalar yields 0 and sets no "has" bit, neither on the
      composite nor on its parent;
    * assigning the scalar sets the "has" bit on both;
    * ClearField() on the parent resets both bits and the scalar value, and
      the parent then exposes a different composite object;
    * writing to the composite object obtained before the clear does not
      affect the parent any more.
  """

  def TestCompositeHasBits(composite_field_name, scalar_field_name):
    parent = unittest_pb2.TestAllTypes()
    assigned_value = 20

    # Reading the default must not flip any "has" bit.
    child = getattr(parent, composite_field_name)
    self.assertEqual(0, getattr(child, scalar_field_name))
    self.assertFalse(child.HasField(scalar_field_name))
    self.assertFalse(parent.HasField(composite_field_name))

    # Writing the scalar must be visible and must set both "has" bits.
    setattr(child, scalar_field_name, assigned_value)
    self.assertEqual(assigned_value, getattr(child, scalar_field_name))
    self.assertTrue(child.HasField(scalar_field_name))
    self.assertTrue(parent.HasField(composite_field_name))

    # Keep the pre-clear object around, then clear through the parent.
    stale_child = child
    parent.ClearField(composite_field_name)

    # Everything is back to the unset state with the default value.
    child = getattr(parent, composite_field_name)
    self.assertFalse(child.HasField(scalar_field_name))
    self.assertFalse(parent.HasField(composite_field_name))
    self.assertEqual(0, getattr(child, scalar_field_name))

    # The old object is detached: writing to it leaves the parent untouched.
    self.assertIsNot(stale_child, child)
    setattr(stale_child, scalar_field_name, assigned_value)
    self.assertFalse(child.HasField(scalar_field_name))
    self.assertFalse(parent.HasField(composite_field_name))
    self.assertEqual(0, getattr(child, scalar_field_name))

  # Simple, single-level nesting where a scalar is set.
  field_pairs = (
      ('optionalgroup', 'a'),
      ('optional_nested_message', 'bb'),
      ('optional_foreign_message', 'c'),
      ('optional_import_message', 'd'),
  )
  for composite_name, scalar_name in field_pairs:
    TestCompositeHasBits(composite_name, scalar_name)
def testDisconnectingNestedMessageAfterSettingField(self):
  """Clearing a set sub-message detaches the old object from its parent.

  The detached object keeps its own value, and later writes to it must not
  make the parent report the field as present again.
  """
  field_name = 'optional_nested_message'
  proto = unittest_pb2.TestAllTypes()
  detached = proto.optional_nested_message
  detached.bb = 5
  self.assertTrue(proto.HasField(field_name))

  # Clearing should disconnect `detached` from the parent.
  proto.ClearField(field_name)
  self.assertEqual(5, detached.bb)
  self.assertEqual(0, proto.optional_nested_message.bb)
  self.assertIsNot(detached, proto.optional_nested_message)

  detached.bb = 23
  self.assertFalse(proto.HasField(field_name))
  self.assertEqual(0, proto.optional_nested_message.bb)
def testDisconnectingLazyNestedMessage(self):
  """Releases a lazy nested message, then frees its parent.

  This only exercises real code in the C++ implementation, since Python does
  not support lazy parsing; however the current C++ implementation results
  in memory corruption and a crash, so the body runs under the pure-Python
  implementation only.  Other implementations are reported as skipped
  rather than silently passing without having executed anything.
  """
  if api_implementation.Type() != 'python':
    # skipTest() raises, so nothing below runs for other implementations.
    self.skipTest(
        'releasing a lazy nested message crashes the C++ implementation')
  proto = unittest_pb2.TestAllTypes()
  proto.optional_lazy_message.bb = 5
  proto.ClearField('optional_lazy_message')
  del proto
  # Force collection so the released message is actually freed here.
  gc.collect()
def testSingularListFields(self):
  """ListFields() reports set singular fields in the order asserted below.

  A sub-message that has only been read is not listed; it appears once one
  of its own fields has been assigned.
  """
  proto = unittest_pb2.TestAllTypes()
  proto.optional_fixed32 = 1
  proto.optional_int32 = 5
  proto.optional_string = 'foo'
  # Merely reading the sub-message must not make it show up in ListFields().
  nested_message = proto.optional_nested_message

  by_name = proto.DESCRIPTOR.fields_by_name
  expected = [
      (by_name['optional_int32'], 5),
      (by_name['optional_fixed32'], 1),
      (by_name['optional_string'], 'foo'),
  ]
  self.assertEqual(expected, proto.ListFields())

  proto.optional_nested_message.bb = 123
  expected.append((by_name['optional_nested_message'], nested_message))
  self.assertEqual(expected, proto.ListFields())
def testListFieldsAndExtensions(self):
  """ListFields() returns regular fields and extensions in a single list.

  After test_util.SetAllFieldsAndExtensions() has populated the message,
  the result must interleave the fields and extensions exactly as listed
  in the expected value below.
  """
  proto = unittest_pb2.TestFieldOrderings()
  test_util.SetAllFieldsAndExtensions(proto)
  # (A bare `unittest_pb2.my_extension_int` expression statement used to sit
  # here; it had no effect and was removed.)
  fields = proto.DESCRIPTOR.fields_by_name
  self.assertEqual(
      [(fields['my_int'], 1),
       (unittest_pb2.my_extension_int, 23),
       (fields['my_string'], 'foo'),
       (unittest_pb2.my_extension_string, 'bar'),
       (fields['my_float'], 1.0)],
      proto.ListFields())
def testClearRemovesChildren(self):
  """CopyFrom() an empty message must drop every repeated child.

  Guards against implementation bugs that only partially clear the message,
  which can happen in the more complex C++ implementation that keeps
  parallel message lists.
  """
  proto = unittest_pb2.TestRequiredForeign()
  for _ in range(10):
    proto.repeated_message.add()

  empty = unittest_pb2.TestRequiredForeign()
  proto.CopyFrom(empty)

  # No element may survive, so indexing has to fail.
  with self.assertRaises(IndexError):
    proto.repeated_message[5]
def testIntegerTypes(self):
  """Checks the Python type returned for integer fields.

  Each value is checked twice: right after assignment, and again after a
  serialize/parse round trip into a fresh message.
  """

  def TestGetAndDeserialize(field_name, value, expected_type):
    proto = unittest_pb2.TestAllTypes()
    setattr(proto, field_name, value)
    self.assertIsInstance(getattr(proto, field_name), expected_type)
    proto2 = unittest_pb2.TestAllTypes()
    proto2.ParseFromString(proto.SerializeToString())
    self.assertIsInstance(getattr(proto2, field_name), expected_type)

  TestGetAndDeserialize('optional_int32', 1, int)
  TestGetAndDeserialize('optional_int32', 1 << 30, int)
  TestGetAndDeserialize('optional_uint32', 1 << 30, int)
  try:
    integer_64 = long
  except NameError:  # Python3
    integer_64 = int
  if struct.calcsize('L') == 4:
    # Python only has signed ints, so 32-bit python can't fit an uint32
    # in an int.  Use the integer_64 alias rather than the bare name `long`:
    # `long` does not exist on Python 3, where this branch is still taken on
    # platforms whose C long is 4 bytes.
    TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64)
  else:
    # 64-bit python can fit uint32 inside an int
    TestGetAndDeserialize('optional_uint32', 1 << 31, int)
  TestGetAndDeserialize('optional_int64', 1 << 30, integer_64)
  TestGetAndDeserialize('optional_int64', 1 << 60, integer_64)
  TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
  TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
def testSingleScalarClearField(self):
  """ClearField() on a scalar resets its value and its "has" bit."""
  field_name = 'optional_int32'
  proto = unittest_pb2.TestAllTypes()

  # Clearing a field that was never set is allowed and is a no-op.
  proto.ClearField(field_name)

  proto.optional_int32 = 1
  self.assertTrue(proto.HasField(field_name))

  proto.ClearField(field_name)
  self.assertEqual(0, proto.optional_int32)
  self.assertFalse(proto.HasField(field_name))
  # TODO(robinson): Test all other scalar field types.
def testEnum_Name(self):
  """Name() maps enum numbers to their names; unknown numbers raise.

  Covers the top-level ForeignEnum, and NestedEnum reached both through a
  message instance and through the message class.
  """
  unknown_number = 11312

  foreign_enum = unittest_pb2.ForeignEnum
  for name in ('FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'):
    number = getattr(unittest_pb2, name)
    self.assertEqual(name, foreign_enum.Name(number))
  self.assertRaises(ValueError, foreign_enum.Name, unknown_number)

  proto = unittest_pb2.TestAllTypes()
  nested_enums = (proto.NestedEnum, unittest_pb2.TestAllTypes.NestedEnum)
  for name in ('FOO', 'BAR', 'BAZ'):
    number = getattr(proto, name)
    for nested_enum in nested_enums:
      self.assertEqual(name, nested_enum.Name(number))
  for nested_enum in nested_enums:
    self.assertRaises(ValueError, nested_enum.Name, unknown_number)
757 self.assertEqual(proto.BAR, 758 proto.NestedEnum.Value('BAR')) 759 self.assertEqual(proto.BAR, 760 unittest_pb2.TestAllTypes.NestedEnum.Value('BAR')) 761 self.assertEqual(proto.BAZ, 762 proto.NestedEnum.Value('BAZ')) 763 self.assertEqual(proto.BAZ, 764 unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ')) 765 self.assertRaises(ValueError, 766 proto.NestedEnum.Value, 'Foo') 767 self.assertRaises(ValueError, 768 unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo') 769 770 def testEnum_KeysAndValues(self): 771 self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'], 772 list(unittest_pb2.ForeignEnum.keys())) 773 self.assertEqual([4, 5, 6], 774 list(unittest_pb2.ForeignEnum.values())) 775 self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), 776 ('FOREIGN_BAZ', 6)], 777 list(unittest_pb2.ForeignEnum.items())) 778 779 proto = unittest_pb2.TestAllTypes() 780 self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], list(proto.NestedEnum.keys())) 781 self.assertEqual([1, 2, 3, -1], list(proto.NestedEnum.values())) 782 self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], 783 list(proto.NestedEnum.items())) 784 785 def testRepeatedScalars(self): 786 proto = unittest_pb2.TestAllTypes() 787 788 self.assertTrue(not proto.repeated_int32) 789 self.assertEqual(0, len(proto.repeated_int32)) 790 proto.repeated_int32.append(5) 791 proto.repeated_int32.append(10) 792 proto.repeated_int32.append(15) 793 self.assertTrue(proto.repeated_int32) 794 self.assertEqual(3, len(proto.repeated_int32)) 795 796 self.assertEqual([5, 10, 15], proto.repeated_int32) 797 798 # Test single retrieval. 799 self.assertEqual(5, proto.repeated_int32[0]) 800 self.assertEqual(15, proto.repeated_int32[-1]) 801 # Test out-of-bounds indices. 802 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234) 803 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234) 804 # Test incorrect types passed to __getitem__. 
805 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo') 806 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None) 807 808 # Test single assignment. 809 proto.repeated_int32[1] = 20 810 self.assertEqual([5, 20, 15], proto.repeated_int32) 811 812 # Test insertion. 813 proto.repeated_int32.insert(1, 25) 814 self.assertEqual([5, 25, 20, 15], proto.repeated_int32) 815 816 # Test slice retrieval. 817 proto.repeated_int32.append(30) 818 self.assertEqual([25, 20, 15], proto.repeated_int32[1:4]) 819 self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) 820 821 # Test slice assignment with an iterator 822 proto.repeated_int32[1:4] = (i for i in range(3)) 823 self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) 824 825 # Test slice assignment. 826 proto.repeated_int32[1:4] = [35, 40, 45] 827 self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32) 828 829 # Test that we can use the field as an iterator. 830 result = [] 831 for i in proto.repeated_int32: 832 result.append(i) 833 self.assertEqual([5, 35, 40, 45, 30], result) 834 835 # Test single deletion. 836 del proto.repeated_int32[2] 837 self.assertEqual([5, 35, 45, 30], proto.repeated_int32) 838 839 # Test slice deletion. 840 del proto.repeated_int32[2:] 841 self.assertEqual([5, 35], proto.repeated_int32) 842 843 # Test extending. 844 proto.repeated_int32.extend([3, 13]) 845 self.assertEqual([5, 35, 3, 13], proto.repeated_int32) 846 847 # Test clearing. 848 proto.ClearField('repeated_int32') 849 self.assertTrue(not proto.repeated_int32) 850 self.assertEqual(0, len(proto.repeated_int32)) 851 852 proto.repeated_int32.append(1) 853 self.assertEqual(1, proto.repeated_int32[-1]) 854 # Test assignment to a negative index. 855 proto.repeated_int32[-1] = 2 856 self.assertEqual(2, proto.repeated_int32[-1]) 857 858 # Test deletion at negative indices. 
859 proto.repeated_int32[:] = [0, 1, 2, 3] 860 del proto.repeated_int32[-1] 861 self.assertEqual([0, 1, 2], proto.repeated_int32) 862 863 del proto.repeated_int32[-2] 864 self.assertEqual([0, 2], proto.repeated_int32) 865 866 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3) 867 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300) 868 869 del proto.repeated_int32[-2:-1] 870 self.assertEqual([2], proto.repeated_int32) 871 872 del proto.repeated_int32[100:10000] 873 self.assertEqual([2], proto.repeated_int32) 874 875 def testRepeatedScalarsRemove(self): 876 proto = unittest_pb2.TestAllTypes() 877 878 self.assertTrue(not proto.repeated_int32) 879 self.assertEqual(0, len(proto.repeated_int32)) 880 proto.repeated_int32.append(5) 881 proto.repeated_int32.append(10) 882 proto.repeated_int32.append(5) 883 proto.repeated_int32.append(5) 884 885 self.assertEqual(4, len(proto.repeated_int32)) 886 proto.repeated_int32.remove(5) 887 self.assertEqual(3, len(proto.repeated_int32)) 888 self.assertEqual(10, proto.repeated_int32[0]) 889 self.assertEqual(5, proto.repeated_int32[1]) 890 self.assertEqual(5, proto.repeated_int32[2]) 891 892 proto.repeated_int32.remove(5) 893 self.assertEqual(2, len(proto.repeated_int32)) 894 self.assertEqual(10, proto.repeated_int32[0]) 895 self.assertEqual(5, proto.repeated_int32[1]) 896 897 proto.repeated_int32.remove(10) 898 self.assertEqual(1, len(proto.repeated_int32)) 899 self.assertEqual(5, proto.repeated_int32[0]) 900 901 # Remove a non-existent element. 
902 self.assertRaises(ValueError, proto.repeated_int32.remove, 123) 903 904 def testRepeatedComposites(self): 905 proto = unittest_pb2.TestAllTypes() 906 self.assertTrue(not proto.repeated_nested_message) 907 self.assertEqual(0, len(proto.repeated_nested_message)) 908 m0 = proto.repeated_nested_message.add() 909 m1 = proto.repeated_nested_message.add() 910 self.assertTrue(proto.repeated_nested_message) 911 self.assertEqual(2, len(proto.repeated_nested_message)) 912 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 913 self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage) 914 915 # Test out-of-bounds indices. 916 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 917 1234) 918 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 919 -1234) 920 921 # Test incorrect types passed to __getitem__. 922 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 923 'foo') 924 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 925 None) 926 927 # Test slice retrieval. 928 m2 = proto.repeated_nested_message.add() 929 m3 = proto.repeated_nested_message.add() 930 m4 = proto.repeated_nested_message.add() 931 self.assertListsEqual( 932 [m1, m2, m3], proto.repeated_nested_message[1:4]) 933 self.assertListsEqual( 934 [m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) 935 self.assertListsEqual( 936 [m0, m1], proto.repeated_nested_message[:2]) 937 self.assertListsEqual( 938 [m2, m3, m4], proto.repeated_nested_message[2:]) 939 self.assertEqual( 940 m0, proto.repeated_nested_message[0]) 941 self.assertListsEqual( 942 [m0], proto.repeated_nested_message[:1]) 943 944 # Test that we can use the field as an iterator. 945 result = [] 946 for i in proto.repeated_nested_message: 947 result.append(i) 948 self.assertListsEqual([m0, m1, m2, m3, m4], result) 949 950 # Test single deletion. 
951 del proto.repeated_nested_message[2] 952 self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message) 953 954 # Test slice deletion. 955 del proto.repeated_nested_message[2:] 956 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 957 958 # Test extending. 959 n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1) 960 n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2) 961 proto.repeated_nested_message.extend([n1,n2]) 962 self.assertEqual(4, len(proto.repeated_nested_message)) 963 self.assertEqual(n1, proto.repeated_nested_message[2]) 964 self.assertEqual(n2, proto.repeated_nested_message[3]) 965 966 # Test clearing. 967 proto.ClearField('repeated_nested_message') 968 self.assertTrue(not proto.repeated_nested_message) 969 self.assertEqual(0, len(proto.repeated_nested_message)) 970 971 # Test constructing an element while adding it. 972 proto.repeated_nested_message.add(bb=23) 973 self.assertEqual(1, len(proto.repeated_nested_message)) 974 self.assertEqual(23, proto.repeated_nested_message[0].bb) 975 self.assertRaises(TypeError, proto.repeated_nested_message.add, 23) 976 977 def testRepeatedCompositeRemove(self): 978 proto = unittest_pb2.TestAllTypes() 979 980 self.assertEqual(0, len(proto.repeated_nested_message)) 981 m0 = proto.repeated_nested_message.add() 982 # Need to set some differentiating variable so m0 != m1 != m2: 983 m0.bb = len(proto.repeated_nested_message) 984 m1 = proto.repeated_nested_message.add() 985 m1.bb = len(proto.repeated_nested_message) 986 self.assertTrue(m0 != m1) 987 m2 = proto.repeated_nested_message.add() 988 m2.bb = len(proto.repeated_nested_message) 989 self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) 990 991 self.assertEqual(3, len(proto.repeated_nested_message)) 992 proto.repeated_nested_message.remove(m0) 993 self.assertEqual(2, len(proto.repeated_nested_message)) 994 self.assertEqual(m1, proto.repeated_nested_message[0]) 995 self.assertEqual(m2, proto.repeated_nested_message[1]) 996 997 # 
Removing m0 again or removing None should raise error 998 self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0) 999 self.assertRaises(ValueError, proto.repeated_nested_message.remove, None) 1000 self.assertEqual(2, len(proto.repeated_nested_message)) 1001 1002 proto.repeated_nested_message.remove(m2) 1003 self.assertEqual(1, len(proto.repeated_nested_message)) 1004 self.assertEqual(m1, proto.repeated_nested_message[0]) 1005 1006 def testHandWrittenReflection(self): 1007 # Hand written extensions are only supported by the pure-Python 1008 # implementation of the API. 1009 if api_implementation.Type() != 'python': 1010 return 1011 1012 FieldDescriptor = descriptor.FieldDescriptor 1013 foo_field_descriptor = FieldDescriptor( 1014 name='foo_field', full_name='MyProto.foo_field', 1015 index=0, number=1, type=FieldDescriptor.TYPE_INT64, 1016 cpp_type=FieldDescriptor.CPPTYPE_INT64, 1017 label=FieldDescriptor.LABEL_OPTIONAL, default_value=0, 1018 containing_type=None, message_type=None, enum_type=None, 1019 is_extension=False, extension_scope=None, 1020 options=descriptor_pb2.FieldOptions()) 1021 mydescriptor = descriptor.Descriptor( 1022 name='MyProto', full_name='MyProto', filename='ignored', 1023 containing_type=None, nested_types=[], enum_types=[], 1024 fields=[foo_field_descriptor], extensions=[], 1025 options=descriptor_pb2.MessageOptions()) 1026 class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): 1027 DESCRIPTOR = mydescriptor 1028 myproto_instance = MyProtoClass() 1029 self.assertEqual(0, myproto_instance.foo_field) 1030 self.assertTrue(not myproto_instance.HasField('foo_field')) 1031 myproto_instance.foo_field = 23 1032 self.assertEqual(23, myproto_instance.foo_field) 1033 self.assertTrue(myproto_instance.HasField('foo_field')) 1034 1035 def testDescriptorProtoSupport(self): 1036 # Hand written descriptors/reflection are only supported by the pure-Python 1037 # implementation of the API. 
1038 if api_implementation.Type() != 'python': 1039 return 1040 1041 def AddDescriptorField(proto, field_name, field_type): 1042 AddDescriptorField.field_index += 1 1043 new_field = proto.field.add() 1044 new_field.name = field_name 1045 new_field.type = field_type 1046 new_field.number = AddDescriptorField.field_index 1047 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL 1048 1049 AddDescriptorField.field_index = 0 1050 1051 desc_proto = descriptor_pb2.DescriptorProto() 1052 desc_proto.name = 'Car' 1053 fdp = descriptor_pb2.FieldDescriptorProto 1054 AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING) 1055 AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64) 1056 AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL) 1057 AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE) 1058 # Add a repeated field 1059 AddDescriptorField.field_index += 1 1060 new_field = desc_proto.field.add() 1061 new_field.name = 'owners' 1062 new_field.type = fdp.TYPE_STRING 1063 new_field.number = AddDescriptorField.field_index 1064 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED 1065 1066 desc = descriptor.MakeDescriptor(desc_proto) 1067 self.assertTrue('name' in desc.fields_by_name) 1068 self.assertTrue('year' in desc.fields_by_name) 1069 self.assertTrue('automatic' in desc.fields_by_name) 1070 self.assertTrue('price' in desc.fields_by_name) 1071 self.assertTrue('owners' in desc.fields_by_name) 1072 1073 class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): 1074 DESCRIPTOR = desc 1075 1076 prius = CarMessage() 1077 prius.name = 'prius' 1078 prius.year = 2010 1079 prius.automatic = True 1080 prius.price = 25134.75 1081 prius.owners.extend(['bob', 'susan']) 1082 1083 serialized_prius = prius.SerializeToString() 1084 new_prius = reflection.ParseMessage(desc, serialized_prius) 1085 self.assertTrue(new_prius is not prius) 1086 self.assertEqual(prius, new_prius) 1087 1088 # these are unnecessary 
assuming message equality works as advertised but 1089 # explicitly check to be safe since we're mucking about in metaclass foo 1090 self.assertEqual(prius.name, new_prius.name) 1091 self.assertEqual(prius.year, new_prius.year) 1092 self.assertEqual(prius.automatic, new_prius.automatic) 1093 self.assertEqual(prius.price, new_prius.price) 1094 self.assertEqual(prius.owners, new_prius.owners) 1095 1096 def testTopLevelExtensionsForOptionalScalar(self): 1097 extendee_proto = unittest_pb2.TestAllExtensions() 1098 extension = unittest_pb2.optional_int32_extension 1099 self.assertTrue(not extendee_proto.HasExtension(extension)) 1100 self.assertEqual(0, extendee_proto.Extensions[extension]) 1101 # As with normal scalar fields, just doing a read doesn't actually set the 1102 # "has" bit. 1103 self.assertTrue(not extendee_proto.HasExtension(extension)) 1104 # Actually set the thing. 1105 extendee_proto.Extensions[extension] = 23 1106 self.assertEqual(23, extendee_proto.Extensions[extension]) 1107 self.assertTrue(extendee_proto.HasExtension(extension)) 1108 # Ensure that clearing works as well. 
1109 extendee_proto.ClearExtension(extension) 1110 self.assertEqual(0, extendee_proto.Extensions[extension]) 1111 self.assertTrue(not extendee_proto.HasExtension(extension)) 1112 1113 def testTopLevelExtensionsForRepeatedScalar(self): 1114 extendee_proto = unittest_pb2.TestAllExtensions() 1115 extension = unittest_pb2.repeated_string_extension 1116 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1117 extendee_proto.Extensions[extension].append('foo') 1118 self.assertEqual(['foo'], extendee_proto.Extensions[extension]) 1119 string_list = extendee_proto.Extensions[extension] 1120 extendee_proto.ClearExtension(extension) 1121 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1122 self.assertTrue(string_list is not extendee_proto.Extensions[extension]) 1123 # Shouldn't be allowed to do Extensions[extension] = 'a' 1124 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1125 extension, 'a') 1126 1127 def testTopLevelExtensionsForOptionalMessage(self): 1128 extendee_proto = unittest_pb2.TestAllExtensions() 1129 extension = unittest_pb2.optional_foreign_message_extension 1130 self.assertTrue(not extendee_proto.HasExtension(extension)) 1131 self.assertEqual(0, extendee_proto.Extensions[extension].c) 1132 # As with normal (non-extension) fields, merely reading from the 1133 # thing shouldn't set the "has" bit. 1134 self.assertTrue(not extendee_proto.HasExtension(extension)) 1135 extendee_proto.Extensions[extension].c = 23 1136 self.assertEqual(23, extendee_proto.Extensions[extension].c) 1137 self.assertTrue(extendee_proto.HasExtension(extension)) 1138 # Save a reference here. 1139 foreign_message = extendee_proto.Extensions[extension] 1140 extendee_proto.ClearExtension(extension) 1141 self.assertTrue(foreign_message is not extendee_proto.Extensions[extension]) 1142 # Setting a field on foreign_message now shouldn't set 1143 # any "has" bits on extendee_proto. 
1144 foreign_message.c = 42 1145 self.assertEqual(42, foreign_message.c) 1146 self.assertTrue(foreign_message.HasField('c')) 1147 self.assertTrue(not extendee_proto.HasExtension(extension)) 1148 # Shouldn't be allowed to do Extensions[extension] = 'a' 1149 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1150 extension, 'a') 1151 1152 def testTopLevelExtensionsForRepeatedMessage(self): 1153 extendee_proto = unittest_pb2.TestAllExtensions() 1154 extension = unittest_pb2.repeatedgroup_extension 1155 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1156 group = extendee_proto.Extensions[extension].add() 1157 group.a = 23 1158 self.assertEqual(23, extendee_proto.Extensions[extension][0].a) 1159 group.a = 42 1160 self.assertEqual(42, extendee_proto.Extensions[extension][0].a) 1161 group_list = extendee_proto.Extensions[extension] 1162 extendee_proto.ClearExtension(extension) 1163 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1164 self.assertTrue(group_list is not extendee_proto.Extensions[extension]) 1165 # Shouldn't be allowed to do Extensions[extension] = 'a' 1166 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1167 extension, 'a') 1168 1169 def testNestedExtensions(self): 1170 extendee_proto = unittest_pb2.TestAllExtensions() 1171 extension = unittest_pb2.TestRequired.single 1172 1173 # We just test the non-repeated case. 
1174 self.assertTrue(not extendee_proto.HasExtension(extension)) 1175 required = extendee_proto.Extensions[extension] 1176 self.assertEqual(0, required.a) 1177 self.assertTrue(not extendee_proto.HasExtension(extension)) 1178 required.a = 23 1179 self.assertEqual(23, extendee_proto.Extensions[extension].a) 1180 self.assertTrue(extendee_proto.HasExtension(extension)) 1181 extendee_proto.ClearExtension(extension) 1182 self.assertTrue(required is not extendee_proto.Extensions[extension]) 1183 self.assertTrue(not extendee_proto.HasExtension(extension)) 1184 1185 def testRegisteredExtensions(self): 1186 self.assertTrue('protobuf_unittest.optional_int32_extension' in 1187 unittest_pb2.TestAllExtensions._extensions_by_name) 1188 self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) 1189 # Make sure extensions haven't been registered into types that shouldn't 1190 # have any. 1191 self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) 1192 1193 # If message A directly contains message B, and 1194 # a.HasField('b') is currently False, then mutating any 1195 # extension in B should change a.HasField('b') to True 1196 # (and so on up the object tree). 1197 def testHasBitsForAncestorsOfExtendedMessage(self): 1198 # Optional scalar extension. 1199 toplevel = more_extensions_pb2.TopLevelMessage() 1200 self.assertTrue(not toplevel.HasField('submessage')) 1201 self.assertEqual(0, toplevel.submessage.Extensions[ 1202 more_extensions_pb2.optional_int_extension]) 1203 self.assertTrue(not toplevel.HasField('submessage')) 1204 toplevel.submessage.Extensions[ 1205 more_extensions_pb2.optional_int_extension] = 23 1206 self.assertEqual(23, toplevel.submessage.Extensions[ 1207 more_extensions_pb2.optional_int_extension]) 1208 self.assertTrue(toplevel.HasField('submessage')) 1209 1210 # Repeated scalar extension. 
1211 toplevel = more_extensions_pb2.TopLevelMessage() 1212 self.assertTrue(not toplevel.HasField('submessage')) 1213 self.assertEqual([], toplevel.submessage.Extensions[ 1214 more_extensions_pb2.repeated_int_extension]) 1215 self.assertTrue(not toplevel.HasField('submessage')) 1216 toplevel.submessage.Extensions[ 1217 more_extensions_pb2.repeated_int_extension].append(23) 1218 self.assertEqual([23], toplevel.submessage.Extensions[ 1219 more_extensions_pb2.repeated_int_extension]) 1220 self.assertTrue(toplevel.HasField('submessage')) 1221 1222 # Optional message extension. 1223 toplevel = more_extensions_pb2.TopLevelMessage() 1224 self.assertTrue(not toplevel.HasField('submessage')) 1225 self.assertEqual(0, toplevel.submessage.Extensions[ 1226 more_extensions_pb2.optional_message_extension].foreign_message_int) 1227 self.assertTrue(not toplevel.HasField('submessage')) 1228 toplevel.submessage.Extensions[ 1229 more_extensions_pb2.optional_message_extension].foreign_message_int = 23 1230 self.assertEqual(23, toplevel.submessage.Extensions[ 1231 more_extensions_pb2.optional_message_extension].foreign_message_int) 1232 self.assertTrue(toplevel.HasField('submessage')) 1233 1234 # Repeated message extension. 
1235 toplevel = more_extensions_pb2.TopLevelMessage() 1236 self.assertTrue(not toplevel.HasField('submessage')) 1237 self.assertEqual(0, len(toplevel.submessage.Extensions[ 1238 more_extensions_pb2.repeated_message_extension])) 1239 self.assertTrue(not toplevel.HasField('submessage')) 1240 foreign = toplevel.submessage.Extensions[ 1241 more_extensions_pb2.repeated_message_extension].add() 1242 self.assertEqual(foreign, toplevel.submessage.Extensions[ 1243 more_extensions_pb2.repeated_message_extension][0]) 1244 self.assertTrue(toplevel.HasField('submessage')) 1245 1246 def testDisconnectionAfterClearingEmptyMessage(self): 1247 toplevel = more_extensions_pb2.TopLevelMessage() 1248 extendee_proto = toplevel.submessage 1249 extension = more_extensions_pb2.optional_message_extension 1250 extension_proto = extendee_proto.Extensions[extension] 1251 extendee_proto.ClearExtension(extension) 1252 extension_proto.foreign_message_int = 23 1253 1254 self.assertTrue(extension_proto is not extendee_proto.Extensions[extension]) 1255 1256 def testExtensionFailureModes(self): 1257 extendee_proto = unittest_pb2.TestAllExtensions() 1258 1259 # Try non-extension-handle arguments to HasExtension, 1260 # ClearExtension(), and Extensions[]... 1261 self.assertRaises(KeyError, extendee_proto.HasExtension, 1234) 1262 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234) 1263 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234) 1264 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5) 1265 1266 # Try something that *is* an extension handle, just not for 1267 # this message... 
1268 for unknown_handle in (more_extensions_pb2.optional_int_extension, 1269 more_extensions_pb2.optional_message_extension, 1270 more_extensions_pb2.repeated_int_extension, 1271 more_extensions_pb2.repeated_message_extension): 1272 self.assertRaises(KeyError, extendee_proto.HasExtension, 1273 unknown_handle) 1274 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1275 unknown_handle) 1276 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1277 unknown_handle) 1278 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1279 unknown_handle, 5) 1280 1281 # Try call HasExtension() with a valid handle, but for a 1282 # *repeated* field. (Just as with non-extension repeated 1283 # fields, Has*() isn't supported for extension repeated fields). 1284 self.assertRaises(KeyError, extendee_proto.HasExtension, 1285 unittest_pb2.repeated_string_extension) 1286 1287 def testStaticParseFrom(self): 1288 proto1 = unittest_pb2.TestAllTypes() 1289 test_util.SetAllFields(proto1) 1290 1291 string1 = proto1.SerializeToString() 1292 proto2 = unittest_pb2.TestAllTypes.FromString(string1) 1293 1294 # Messages should be equal. 1295 self.assertEqual(proto2, proto1) 1296 1297 def testMergeFromSingularField(self): 1298 # Test merge with just a singular field. 1299 proto1 = unittest_pb2.TestAllTypes() 1300 proto1.optional_int32 = 1 1301 1302 proto2 = unittest_pb2.TestAllTypes() 1303 # This shouldn't get overwritten. 1304 proto2.optional_string = 'value' 1305 1306 proto2.MergeFrom(proto1) 1307 self.assertEqual(1, proto2.optional_int32) 1308 self.assertEqual('value', proto2.optional_string) 1309 1310 def testMergeFromRepeatedField(self): 1311 # Test merge with just a repeated field. 
1312 proto1 = unittest_pb2.TestAllTypes() 1313 proto1.repeated_int32.append(1) 1314 proto1.repeated_int32.append(2) 1315 1316 proto2 = unittest_pb2.TestAllTypes() 1317 proto2.repeated_int32.append(0) 1318 proto2.MergeFrom(proto1) 1319 1320 self.assertEqual(0, proto2.repeated_int32[0]) 1321 self.assertEqual(1, proto2.repeated_int32[1]) 1322 self.assertEqual(2, proto2.repeated_int32[2]) 1323 1324 def testMergeFromOptionalGroup(self): 1325 # Test merge with an optional group. 1326 proto1 = unittest_pb2.TestAllTypes() 1327 proto1.optionalgroup.a = 12 1328 proto2 = unittest_pb2.TestAllTypes() 1329 proto2.MergeFrom(proto1) 1330 self.assertEqual(12, proto2.optionalgroup.a) 1331 1332 def testMergeFromRepeatedNestedMessage(self): 1333 # Test merge with a repeated nested message. 1334 proto1 = unittest_pb2.TestAllTypes() 1335 m = proto1.repeated_nested_message.add() 1336 m.bb = 123 1337 m = proto1.repeated_nested_message.add() 1338 m.bb = 321 1339 1340 proto2 = unittest_pb2.TestAllTypes() 1341 m = proto2.repeated_nested_message.add() 1342 m.bb = 999 1343 proto2.MergeFrom(proto1) 1344 self.assertEqual(999, proto2.repeated_nested_message[0].bb) 1345 self.assertEqual(123, proto2.repeated_nested_message[1].bb) 1346 self.assertEqual(321, proto2.repeated_nested_message[2].bb) 1347 1348 proto3 = unittest_pb2.TestAllTypes() 1349 proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message) 1350 self.assertEqual(999, proto3.repeated_nested_message[0].bb) 1351 self.assertEqual(123, proto3.repeated_nested_message[1].bb) 1352 self.assertEqual(321, proto3.repeated_nested_message[2].bb) 1353 1354 def testMergeFromAllFields(self): 1355 # With all fields set. 1356 proto1 = unittest_pb2.TestAllTypes() 1357 test_util.SetAllFields(proto1) 1358 proto2 = unittest_pb2.TestAllTypes() 1359 proto2.MergeFrom(proto1) 1360 1361 # Messages should be equal. 1362 self.assertEqual(proto2, proto1) 1363 1364 # Serialized string should be equal too. 
1365 string1 = proto1.SerializeToString() 1366 string2 = proto2.SerializeToString() 1367 self.assertEqual(string1, string2) 1368 1369 def testMergeFromExtensionsSingular(self): 1370 proto1 = unittest_pb2.TestAllExtensions() 1371 proto1.Extensions[unittest_pb2.optional_int32_extension] = 1 1372 1373 proto2 = unittest_pb2.TestAllExtensions() 1374 proto2.MergeFrom(proto1) 1375 self.assertEqual( 1376 1, proto2.Extensions[unittest_pb2.optional_int32_extension]) 1377 1378 def testMergeFromExtensionsRepeated(self): 1379 proto1 = unittest_pb2.TestAllExtensions() 1380 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1) 1381 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2) 1382 1383 proto2 = unittest_pb2.TestAllExtensions() 1384 proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0) 1385 proto2.MergeFrom(proto1) 1386 self.assertEqual( 1387 3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension])) 1388 self.assertEqual( 1389 0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0]) 1390 self.assertEqual( 1391 1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1]) 1392 self.assertEqual( 1393 2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2]) 1394 1395 def testMergeFromExtensionsNestedMessage(self): 1396 proto1 = unittest_pb2.TestAllExtensions() 1397 ext1 = proto1.Extensions[ 1398 unittest_pb2.repeated_nested_message_extension] 1399 m = ext1.add() 1400 m.bb = 222 1401 m = ext1.add() 1402 m.bb = 333 1403 1404 proto2 = unittest_pb2.TestAllExtensions() 1405 ext2 = proto2.Extensions[ 1406 unittest_pb2.repeated_nested_message_extension] 1407 m = ext2.add() 1408 m.bb = 111 1409 1410 proto2.MergeFrom(proto1) 1411 ext2 = proto2.Extensions[ 1412 unittest_pb2.repeated_nested_message_extension] 1413 self.assertEqual(3, len(ext2)) 1414 self.assertEqual(111, ext2[0].bb) 1415 self.assertEqual(222, ext2[1].bb) 1416 self.assertEqual(333, ext2[2].bb) 1417 1418 def testMergeFromBug(self): 1419 message1 = 
unittest_pb2.TestAllTypes() 1420 message2 = unittest_pb2.TestAllTypes() 1421 1422 # Cause optional_nested_message to be instantiated within message1, even 1423 # though it is not considered to be "present". 1424 message1.optional_nested_message 1425 self.assertFalse(message1.HasField('optional_nested_message')) 1426 1427 # Merge into message2. This should not instantiate the field is message2. 1428 message2.MergeFrom(message1) 1429 self.assertFalse(message2.HasField('optional_nested_message')) 1430 1431 def testCopyFromSingularField(self): 1432 # Test copy with just a singular field. 1433 proto1 = unittest_pb2.TestAllTypes() 1434 proto1.optional_int32 = 1 1435 proto1.optional_string = 'important-text' 1436 1437 proto2 = unittest_pb2.TestAllTypes() 1438 proto2.optional_string = 'value' 1439 1440 proto2.CopyFrom(proto1) 1441 self.assertEqual(1, proto2.optional_int32) 1442 self.assertEqual('important-text', proto2.optional_string) 1443 1444 def testCopyFromRepeatedField(self): 1445 # Test copy with a repeated field. 1446 proto1 = unittest_pb2.TestAllTypes() 1447 proto1.repeated_int32.append(1) 1448 proto1.repeated_int32.append(2) 1449 1450 proto2 = unittest_pb2.TestAllTypes() 1451 proto2.repeated_int32.append(0) 1452 proto2.CopyFrom(proto1) 1453 1454 self.assertEqual(1, proto2.repeated_int32[0]) 1455 self.assertEqual(2, proto2.repeated_int32[1]) 1456 1457 def testCopyFromAllFields(self): 1458 # With all fields set. 1459 proto1 = unittest_pb2.TestAllTypes() 1460 test_util.SetAllFields(proto1) 1461 proto2 = unittest_pb2.TestAllTypes() 1462 proto2.CopyFrom(proto1) 1463 1464 # Messages should be equal. 1465 self.assertEqual(proto2, proto1) 1466 1467 # Serialized string should be equal too. 
1468 string1 = proto1.SerializeToString() 1469 string2 = proto2.SerializeToString() 1470 self.assertEqual(string1, string2) 1471 1472 def testCopyFromSelf(self): 1473 proto1 = unittest_pb2.TestAllTypes() 1474 proto1.repeated_int32.append(1) 1475 proto1.optional_int32 = 2 1476 proto1.optional_string = 'important-text' 1477 1478 proto1.CopyFrom(proto1) 1479 self.assertEqual(1, proto1.repeated_int32[0]) 1480 self.assertEqual(2, proto1.optional_int32) 1481 self.assertEqual('important-text', proto1.optional_string) 1482 1483 def testCopyFromBadType(self): 1484 # The python implementation doesn't raise an exception in this 1485 # case. In theory it should. 1486 if api_implementation.Type() == 'python': 1487 return 1488 proto1 = unittest_pb2.TestAllTypes() 1489 proto2 = unittest_pb2.TestAllExtensions() 1490 self.assertRaises(TypeError, proto1.CopyFrom, proto2) 1491 1492 def testDeepCopy(self): 1493 proto1 = unittest_pb2.TestAllTypes() 1494 proto1.optional_int32 = 1 1495 proto2 = copy.deepcopy(proto1) 1496 self.assertEqual(1, proto2.optional_int32) 1497 1498 proto1.repeated_int32.append(2) 1499 proto1.repeated_int32.append(3) 1500 container = copy.deepcopy(proto1.repeated_int32) 1501 self.assertEqual([2, 3], container) 1502 1503 # TODO(anuraag): Implement deepcopy for repeated composite / extension dict 1504 1505 def testClear(self): 1506 proto = unittest_pb2.TestAllTypes() 1507 # C++ implementation does not support lazy fields right now so leave it 1508 # out for now. 1509 if api_implementation.Type() == 'python': 1510 test_util.SetAllFields(proto) 1511 else: 1512 test_util.SetAllNonLazyFields(proto) 1513 # Clear the message. 1514 proto.Clear() 1515 self.assertEqual(proto.ByteSize(), 0) 1516 empty_proto = unittest_pb2.TestAllTypes() 1517 self.assertEqual(proto, empty_proto) 1518 1519 # Test if extensions which were set are cleared. 1520 proto = unittest_pb2.TestAllExtensions() 1521 test_util.SetAllExtensions(proto) 1522 # Clear the message. 
1523 proto.Clear() 1524 self.assertEqual(proto.ByteSize(), 0) 1525 empty_proto = unittest_pb2.TestAllExtensions() 1526 self.assertEqual(proto, empty_proto) 1527 1528 def testDisconnectingBeforeClear(self): 1529 proto = unittest_pb2.TestAllTypes() 1530 nested = proto.optional_nested_message 1531 proto.Clear() 1532 self.assertTrue(nested is not proto.optional_nested_message) 1533 nested.bb = 23 1534 self.assertTrue(not proto.HasField('optional_nested_message')) 1535 self.assertEqual(0, proto.optional_nested_message.bb) 1536 1537 proto = unittest_pb2.TestAllTypes() 1538 nested = proto.optional_nested_message 1539 nested.bb = 5 1540 foreign = proto.optional_foreign_message 1541 foreign.c = 6 1542 1543 proto.Clear() 1544 self.assertTrue(nested is not proto.optional_nested_message) 1545 self.assertTrue(foreign is not proto.optional_foreign_message) 1546 self.assertEqual(5, nested.bb) 1547 self.assertEqual(6, foreign.c) 1548 nested.bb = 15 1549 foreign.c = 16 1550 self.assertFalse(proto.HasField('optional_nested_message')) 1551 self.assertEqual(0, proto.optional_nested_message.bb) 1552 self.assertFalse(proto.HasField('optional_foreign_message')) 1553 self.assertEqual(0, proto.optional_foreign_message.c) 1554 1555 def testOneOf(self): 1556 proto = unittest_pb2.TestAllTypes() 1557 proto.oneof_uint32 = 10 1558 proto.oneof_nested_message.bb = 11 1559 self.assertEqual(11, proto.oneof_nested_message.bb) 1560 self.assertFalse(proto.HasField('oneof_uint32')) 1561 nested = proto.oneof_nested_message 1562 proto.oneof_string = 'abc' 1563 self.assertEqual('abc', proto.oneof_string) 1564 self.assertEqual(11, nested.bb) 1565 self.assertFalse(proto.HasField('oneof_nested_message')) 1566 1567 def assertInitialized(self, proto): 1568 self.assertTrue(proto.IsInitialized()) 1569 # Neither method should raise an exception. 
def testIsInitialized(self):
  """Walk IsInitialized() through optional, required and nested cases.

  Relies on the assertInitialized / assertNotInitialized helpers of this
  test case, which also exercise SerializeToString() and
  SerializePartialToString() on the message they are given.
  """
  # Trivial cases - all optional fields and extensions.
  for optional_only_class in (unittest_pb2.TestAllTypes,
                              unittest_pb2.TestAllExtensions):
    self.assertInitialized(optional_only_class())

  # The case of uninitialized required fields.
  required = unittest_pb2.TestRequired()
  self.assertNotInitialized(required)
  for field_name in ('a', 'b', 'c'):
    setattr(required, field_name, 2)
  self.assertInitialized(required)

  # The case of uninitialized submessage: untouched it does not count, but
  # once one of its fields is set all of its required fields are needed.
  holder = unittest_pb2.TestRequiredForeign()
  self.assertInitialized(holder)
  holder.optional_message.a = 1
  self.assertNotInitialized(holder)
  holder.optional_message.b = 0
  holder.optional_message.c = 0
  self.assertInitialized(holder)

  # Uninitialized repeated submessage.
  element = holder.repeated_message.add()
  self.assertNotInitialized(holder)
  for field_name in ('a', 'b', 'c'):
    setattr(element, field_name, 0)
  self.assertInitialized(holder)

  # Uninitialized repeated group in an extension.
  extended = unittest_pb2.TestAllExtensions()
  multi_extension = unittest_pb2.TestRequired.multi
  first_element = extended.Extensions[multi_extension].add()
  second_element = extended.Extensions[multi_extension].add()
  self.assertNotInitialized(extended)
  first_element.a = 1
  first_element.b = 1
  first_element.c = 1
  # Only one of the two elements is complete at this point.
  self.assertNotInitialized(extended)
  second_element.a = 2
  second_element.b = 2
  second_element.c = 2
  self.assertInitialized(extended)

  # Uninitialized nonrepeated message in an extension.
  extended = unittest_pb2.TestAllExtensions()
  single_extension = unittest_pb2.TestRequired.single
  extended.Extensions[single_extension].a = 1
  self.assertNotInitialized(extended)
  extended.Extensions[single_extension].b = 2
  extended.Extensions[single_extension].c = 3
  self.assertInitialized(extended)

  # Try passing an errors list: it is filled with the missing field names.
  missing = []
  required = unittest_pb2.TestRequired()
  self.assertFalse(required.IsInitialized(missing))
  self.assertEqual(missing, ['a', 'b', 'c'])
def testStringUTF8Serialization(self):
  """Round-trip a non-ASCII string through the MessageSet wire format.

  Sets a unicode value on a TestMessageSetExtension2 extension, serializes
  the enclosing TestMessageSet, re-reads it as a RawMessageSet and checks
  the raw item, then verifies how invalid UTF-8 payloads are handled.
  """
  extension = (
      message_set_extensions_pb2.TestMessageSetExtension2
      .message_set_extension)
  test_utf8 = u'Тест'
  test_utf8_bytes = test_utf8.encode('utf-8')

  # 'Test' in another language, using UTF-8 charset.
  message_set = message_set_extensions_pb2.TestMessageSet()
  message_set.Extensions[extension].str = test_utf8

  # Serialize using the MessageSet wire format (this is specified in the
  # .proto file).
  serialized = message_set.SerializeToString()

  # Check byte size.
  self.assertEqual(message_set.ByteSize(), len(serialized))

  raw = unittest_mset_pb2.RawMessageSet()
  consumed = raw.MergeFromString(serialized)
  self.assertEqual(len(serialized), consumed)

  self.assertEqual(1, len(raw.item))
  # Check that the type_id is the same as the tag ID in the .proto file.
  self.assertEqual(raw.item[0].type_id, 98418634)

  # Check the actual bytes on the wire.
  payload = raw.item[0].message
  self.assertTrue(payload.endswith(test_utf8_bytes))
  decoded = message_set_extensions_pb2.TestMessageSetExtension2()
  consumed = decoded.MergeFromString(payload)
  self.assertEqual(len(payload), consumed)

  self.assertEqual(type(decoded.str), six.text_type)
  self.assertEqual(decoded.str, test_utf8)

  # The pure Python API throws an exception on MergeFromString(),
  # if any of the string fields of the message can't be UTF-8 decoded.
  # The C++ implementation of the API has no way to check that on
  # MergeFromString and thus has no way to throw the exception.
  #
  # The pure Python API always returns objects of type 'unicode' (UTF-8
  # encoded), or 'bytes' (in 7 bit ASCII).
  invalid_run = len(test_utf8_bytes) * b'\xff'
  corrupted = payload.replace(test_utf8_bytes, invalid_run)

  try:
    decoded.MergeFromString(corrupted)
  except UnicodeDecodeError:
    decode_rejected = True
  else:
    decode_rejected = False
  # Read the field unconditionally; either outcome above is acceptable as
  # long as an accepted payload surfaces as raw bytes.
  string_field = decoded.str
  self.assertTrue(decode_rejected or type(string_field) is bytes)
class FullProtosEqualityTest(unittest.TestCase):

  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    """Create two TestAllTypes messages and fill both via SetAllFields."""
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    for populated in (self.first_proto, self.second_proto):
      test_util.SetAllFields(populated)

  def testNotHashable(self):
    """hash() on a message must raise TypeError."""
    with self.assertRaises(TypeError):
      hash(self.first_proto)

  def testNoneNotEqual(self):
    """A message never equals None, whichever side None is on."""
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    """A message never equals a message of a different type."""
    other_type = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, other_type)
    self.assertNotEqual(other_type, self.second_proto)

  def testAllFieldsFilledEquality(self):
    """Two identically filled messages compare equal."""
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    """Changing or clearing an optional scalar breaks equality."""
    first, second = self.first_proto, self.second_proto
    # Nonrepeated scalar field change should cause inequality.
    first.optional_int32 = first.optional_int32 + 1
    self.assertNotEqual(first, second)
    # ...as should clearing a field.
    first.ClearField('optional_int32')
    self.assertNotEqual(first, second)

  def testNonRepeatedComposite(self):
    """Equality follows changes made inside an optional submessage."""
    first, second = self.first_proto, self.second_proto
    # Change a nonrepeated composite field, then change it back.
    first.optional_nested_message.bb += 1
    self.assertNotEqual(first, second)
    first.optional_nested_message.bb -= 1
    self.assertEqual(first, second)
    # Clear a field in the nested message, then restore its value.
    first.optional_nested_message.ClearField('bb')
    self.assertNotEqual(first, second)
    first.optional_nested_message.bb = second.optional_nested_message.bb
    self.assertEqual(first, second)
    # Remove the nested message entirely.
    first.ClearField('optional_nested_message')
    self.assertNotEqual(first, second)

  def testRepeatedScalar(self):
    """Appending to or clearing a repeated scalar breaks equality."""
    first, second = self.first_proto, self.second_proto
    # Change a repeated scalar field.
    first.repeated_int32.append(5)
    self.assertNotEqual(first, second)
    first.ClearField('repeated_int32')
    self.assertNotEqual(first, second)

  def testRepeatedComposite(self):
    """Equality follows element edits and additions in a repeated message."""
    first, second = self.first_proto, self.second_proto
    # Change value within a repeated composite field, then change it back.
    element = first.repeated_nested_message[0]
    element.bb = element.bb + 1
    self.assertNotEqual(first, second)
    element = first.repeated_nested_message[0]
    element.bb = element.bb - 1
    self.assertEqual(first, second)
    # Add a value to a repeated composite field, then mirror the addition.
    first.repeated_nested_message.add()
    self.assertNotEqual(first, second)
    second.repeated_nested_message.add()
    self.assertEqual(first, second)

  def testNonRepeatedScalarHasBits(self):
    """A cleared scalar differs from one explicitly set to zero."""
    # Ensure that we test "has" bits as well as value for
    # nonrepeated scalar field.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    """A cleared submessage differs from a present one with no fields set."""
    first, second = self.first_proto, self.second_proto
    # Ensure that we test "has" bits as well as value for
    # nonrepeated composite field.
    first.ClearField('optional_nested_message')
    second.optional_nested_message.ClearField('bb')
    self.assertNotEqual(first, second)
    # Setting and then clearing 'bb' leaves the submessage present, which
    # is the state the second message is in.
    first.optional_nested_message.bb = 0
    first.optional_nested_message.ClearField('bb')
    self.assertEqual(first, second)
class ByteSizeTest(unittest.TestCase):
  """Checks the values reported by ByteSize() for a range of field kinds."""

  def setUp(self):
    """Create the fresh, empty messages the tests below mutate."""
    self.proto = unittest_pb2.TestAllTypes()
    self.extended_proto = more_extensions_pb2.ExtendedMessage()
    self.packed_proto = unittest_pb2.TestPackedTypes()
    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()

  def Size(self):
    """Return the current ByteSize() of self.proto."""
    return self.proto.ByteSize()

  def testEmptyMessage(self):
    """An untouched message has size zero."""
    self.assertEqual(0, self.proto.ByteSize())

  def testSizedOnKwargs(self):
    """Fields passed as constructor keyword arguments count immediately."""
    # Use a separate message to ensure testing right after creation.
    empty = unittest_pb2.TestAllTypes()
    self.assertEqual(0, empty.ByteSize())
    from_kwargs = unittest_pb2.TestAllTypes(optional_int64=1)
    # One byte for the tag, one to encode varint 1.
    self.assertEqual(2, from_kwargs.ByteSize())

  def testVarints(self):
    """optional_int64 takes one tag byte plus the varint-encoded value."""

    def CheckVarintSize(value, expected_varint_size):
      self.proto.Clear()
      self.proto.optional_int64 = value
      # Add one to the varint size for the tag info
      # for tag 1.
      self.assertEqual(expected_varint_size + 1, self.Size())

    CheckVarintSize(0, 1)
    CheckVarintSize(1, 1)
    # The largest value with 7, 14, ..., 56 bits needs 1, 2, ..., 8 bytes.
    for num_bytes, bit_count in enumerate(range(7, 63, 7), 1):
      CheckVarintSize((1 << bit_count) - 1, num_bytes)
    # Negative values always take the full ten bytes.
    CheckVarintSize(-1, 10)
    CheckVarintSize(-2, 10)
    CheckVarintSize(-(1 << 63), 10)

  def testStrings(self):
    """optional_string takes tag + length prefix + the string itself."""
    cases = (
        # Need one byte for tag info (tag #14), and one byte for length.
        ('', 2),
        # Need one byte for tag info (tag #14), and one byte for length.
        ('abc', 2),
        # Need one byte for tag info (tag #14), and TWO bytes for length.
        ('x' * 128, 3),
    )
    for text, overhead in cases:
      self.proto.optional_string = text
      self.assertEqual(overhead + len(self.proto.optional_string),
                       self.Size())

  def testOtherNumerics(self):
    """Fixed-width, floating point and zig-zag fields have known sizes."""
    cases = (
        # One byte for tag and 4 bytes for fixed32.
        ('optional_fixed32', 1234, 5),
        # One byte for tag and 8 bytes for fixed64.
        ('optional_fixed64', 1234, 9),
        # One byte for tag and 4 bytes for float.
        ('optional_float', 1.234, 5),
        # One byte for tag and 8 bytes for double.
        ('optional_double', 1.234, 9),
        # One byte for tag and 2 bytes for zig-zag-encoded 64.
        ('optional_sint32', 64, 3),
    )
    for field_name, value, expected_size in cases:
      setattr(self.proto, field_name, value)
      self.assertEqual(expected_size, self.Size())
      # Start the next case from an empty message.
      self.proto = unittest_pb2.TestAllTypes()

  def testComposites(self):
    """A nested message adds its own tag and length prefix."""
    # 3 bytes.
    self.proto.optional_nested_message.bb = (1 << 14)
    # Plus one byte for bb tag.
    # Plus 1 byte for optional_nested_message serialized size.
    # Plus two bytes for optional_nested_message tag.
    self.assertEqual(3 + 1 + 1 + 2, self.Size())

  def testGroups(self):
    """A group is delimited by START_GROUP and END_GROUP tags."""
    # 4 bytes.
    self.proto.optionalgroup.a = (1 << 21)
    # Plus two bytes for |a| tag.
    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
    self.assertEqual(4 + 2 + 2*2, self.Size())

  def testRepeatedScalars(self):
    """Each appended element carries its own tag."""
    values = self.proto.repeated_int32
    values.append(10)  # 1 byte.
    values.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsExtend(self):
    """extend() is sized the same as the equivalent appends."""
    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsRemove(self):
    """remove() shrinks the reported size."""
    values = self.proto.repeated_int32
    values.append(10)  # 1 byte.
    values.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())
    values.remove(128)
    # What remains is the 1-byte value 10 and its 2-byte tag.
    self.assertEqual(1 + 2, self.Size())

  def testRepeatedComposites(self):
    """Empty and non-empty repeated submessages are both counted."""
    # Empty message.  2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    with_value = self.proto.repeated_nested_message.add()
    with_value.bb = 7
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

  def testRepeatedCompositesDelete(self):
    """Deleting and re-adding repeated submessages keeps the size current."""
    messages = self.proto.repeated_nested_message
    # Empty message.  2 bytes tag plus 1 byte length.
    messages.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    second = messages.add()
    second.bb = 9
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[0]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    # Now add a new message.
    third = self.proto.repeated_nested_message.add()
    third.bb = 12

    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())

    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[1]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    del self.proto.repeated_nested_message[0]
    self.assertEqual(0, self.Size())

  def testRepeatedGroups(self):
    """Each repeated group carries its own start and end tags."""
    # 2-byte START_GROUP plus 2-byte END_GROUP.
    self.proto.repeatedgroup.add()
    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
    # plus 2-byte END_GROUP.
    with_value = self.proto.repeatedgroup.add()
    with_value.a = 7
    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())

  def testExtensions(self):
    """A scalar extension is sized like a regular scalar field."""
    extended = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, extended.ByteSize())
    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
    extended.Extensions[extension] = 23
    # 1 byte for tag, 1 byte for value.
    self.assertEqual(2, extended.ByteSize())

  def testCacheInvalidationForNonrepeatedScalar(self):
    """Setting or clearing a scalar refreshes the reported size."""
    # Test non-extension.
    plain = self.proto
    plain.optional_int32 = 1
    self.assertEqual(2, plain.ByteSize())
    plain.optional_int32 = 128
    self.assertEqual(3, plain.ByteSize())
    plain.ClearField('optional_int32')
    self.assertEqual(0, plain.ByteSize())

    # Test within extension.
    extended = self.extended_proto
    extension = more_extensions_pb2.optional_int_extension
    extended.Extensions[extension] = 1
    self.assertEqual(2, extended.ByteSize())
    extended.Extensions[extension] = 128
    self.assertEqual(3, extended.ByteSize())
    extended.ClearExtension(extension)
    self.assertEqual(0, extended.ByteSize())

  def testCacheInvalidationForRepeatedScalar(self):
    """Appending, assigning or clearing elements refreshes the size."""
    # Test non-extension.
    plain = self.proto
    plain.repeated_int32.append(1)
    self.assertEqual(3, plain.ByteSize())
    plain.repeated_int32.append(1)
    self.assertEqual(6, plain.ByteSize())
    plain.repeated_int32[1] = 128
    self.assertEqual(7, plain.ByteSize())
    plain.ClearField('repeated_int32')
    self.assertEqual(0, plain.ByteSize())

    # Test within extension.
    extended = self.extended_proto
    extension = more_extensions_pb2.repeated_int_extension
    values = extended.Extensions[extension]
    values.append(1)
    self.assertEqual(2, extended.ByteSize())
    values.append(1)
    self.assertEqual(4, extended.ByteSize())
    values[1] = 128
    self.assertEqual(5, extended.ByteSize())
    extended.ClearExtension(extension)
    self.assertEqual(0, extended.ByteSize())

  def testCacheInvalidationForNonrepeatedMessage(self):
    """Changes inside a submessage refresh the parent's reported size."""
    # Test non-extension.
    plain = self.proto
    plain.optional_foreign_message.c = 1
    self.assertEqual(5, plain.ByteSize())
    plain.optional_foreign_message.c = 128
    self.assertEqual(6, plain.ByteSize())
    plain.optional_foreign_message.ClearField('c')
    self.assertEqual(3, plain.ByteSize())
    plain.ClearField('optional_foreign_message')
    self.assertEqual(0, plain.ByteSize())

    if api_implementation.Type() == 'python':
      # This is only possible in pure-Python implementation of the API.
      # A child obtained before the field is cleared must not affect the
      # parent's size when it is modified afterwards.
      detached = plain.optional_foreign_message
      plain.ClearField('optional_foreign_message')
      detached.c = 128
      self.assertEqual(0, plain.ByteSize())

    # Test within extension.
    extended = self.extended_proto
    extension = more_extensions_pb2.optional_message_extension
    child = extended.Extensions[extension]
    self.assertEqual(0, extended.ByteSize())
    child.foreign_message_int = 1
    self.assertEqual(4, extended.ByteSize())
    child.foreign_message_int = 128
    self.assertEqual(5, extended.ByteSize())
    extended.ClearExtension(extension)
    self.assertEqual(0, extended.ByteSize())

  def testCacheInvalidationForRepeatedMessage(self):
    """Changes inside repeated submessages refresh the parent's size."""
    # Test non-extension.
    plain = self.proto
    first_child = plain.repeated_foreign_message.add()
    self.assertEqual(3, plain.ByteSize())
    plain.repeated_foreign_message.add()
    self.assertEqual(6, plain.ByteSize())
    first_child.c = 1
    self.assertEqual(8, plain.ByteSize())
    plain.ClearField('repeated_foreign_message')
    self.assertEqual(0, plain.ByteSize())

    # Test within extension.
    extended = self.extended_proto
    extension = more_extensions_pb2.repeated_message_extension
    children = extended.Extensions[extension]
    first_child = children.add()
    self.assertEqual(2, extended.ByteSize())
    children.add()
    self.assertEqual(4, extended.ByteSize())
    first_child.foreign_message_int = 1
    self.assertEqual(6, extended.ByteSize())
    first_child.ClearField('foreign_message_int')
    self.assertEqual(4, extended.ByteSize())
    extended.ClearExtension(extension)
    self.assertEqual(0, extended.ByteSize())

  def testPackedRepeatedScalars(self):
    """Packed fields share a single tag and length prefix."""
    packed = self.packed_proto
    self.assertEqual(0, packed.ByteSize())

    packed.packed_int32.append(10)  # 1 byte.
    packed.packed_int32.append(128)  # 2 bytes.
    # The tag is 2 bytes (the field number is 90), and the varint
    # storing the length is 1 byte.
    int_size = 1 + 2 + 3
    self.assertEqual(int_size, packed.ByteSize())

    packed.packed_double.append(4.2)  # 8 bytes
    packed.packed_double.append(3.25)  # 8 bytes
    # 2 more tag bytes, 1 more length byte.
    double_size = 8 + 8 + 3
    self.assertEqual(int_size+double_size, packed.ByteSize())

    packed.ClearField('packed_int32')
    self.assertEqual(double_size, packed.ByteSize())

  def testPackedExtensions(self):
    """A packed repeated extension is sized as payload plus tag bytes."""
    packed_extended = self.packed_extended_proto
    self.assertEqual(0, packed_extended.ByteSize())
    values = packed_extended.Extensions[unittest_pb2.packed_fixed32_extension]
    values.extend([1, 2, 3, 4])  # 16 bytes
    # Tag is 3 bytes.
    self.assertEqual(19, packed_extended.ByteSize())
def testCanonicalSerializationOrder(self):
  """Fields and extensions must be serialized in tag order.

  Populates an OutOfOrderFields message in reverse tag order, serializes
  it, and then walks the bytes with _MiniDecoder, expecting tags 1 to 5
  in ascending order, each as a varint whose value equals its tag number.
  """
  proto = more_messages_pb2.OutOfOrderFields()
  # These are also their tag numbers.  Even though we're setting these in
  # reverse-tag order AND they're listed in reverse tag-order in the .proto
  # file, they should nonetheless be serialized in tag order.
  proto.optional_sint32 = 5
  proto.Extensions[more_messages_pb2.optional_uint64] = 4
  proto.optional_uint32 = 3
  proto.Extensions[more_messages_pb2.optional_int64] = 2
  proto.optional_int32 = 1

  serialized = proto.SerializeToString()
  self.assertEqual(proto.ByteSize(), len(serialized))

  stream = _MiniDecoder(serialized)
  # One reader per expected tag, listed in the order the tags must appear.
  readers_by_tag = (
      (1, stream.ReadInt32),
      (2, stream.ReadInt64),
      (3, stream.ReadUInt32),
      (4, stream.ReadUInt64),
      (5, stream.ReadSInt32),
  )
  for tag_number, read_value in readers_by_tag:
    self.assertEqual((tag_number, wire_format.WIRETYPE_VARINT),
                     stream.ReadFieldNumberAndWireType())
    self.assertEqual(tag_number, read_value())
2371 proto = unittest_pb2.TestFieldOrderings() 2372 test_util.SetAllFieldsAndExtensions(proto) 2373 serialized = proto.SerializeToString() 2374 test_util.ExpectAllFieldsAndExtensionsInOrder(serialized) 2375 2376 def testMergeFromStringWhenFieldsAlreadySet(self): 2377 first_proto = unittest_pb2.TestAllTypes() 2378 first_proto.repeated_string.append('foobar') 2379 first_proto.optional_int32 = 23 2380 first_proto.optional_nested_message.bb = 42 2381 serialized = first_proto.SerializeToString() 2382 2383 second_proto = unittest_pb2.TestAllTypes() 2384 second_proto.repeated_string.append('baz') 2385 second_proto.optional_int32 = 100 2386 second_proto.optional_nested_message.bb = 999 2387 2388 bytes_parsed = second_proto.MergeFromString(serialized) 2389 self.assertEqual(len(serialized), bytes_parsed) 2390 2391 # Ensure that we append to repeated fields. 2392 self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) 2393 # Ensure that we overwrite nonrepeatd scalars. 2394 self.assertEqual(23, second_proto.optional_int32) 2395 # Ensure that we recursively call MergeFromString() on 2396 # submessages. 2397 self.assertEqual(42, second_proto.optional_nested_message.bb) 2398 2399 def testMessageSetWireFormat(self): 2400 proto = message_set_extensions_pb2.TestMessageSet() 2401 extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 2402 extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2 2403 extension1 = extension_message1.message_set_extension 2404 extension2 = extension_message2.message_set_extension 2405 extension3 = message_set_extensions_pb2.message_set_extension3 2406 proto.Extensions[extension1].i = 123 2407 proto.Extensions[extension2].str = 'foo' 2408 proto.Extensions[extension3].text = 'bar' 2409 2410 # Serialize using the MessageSet wire format (this is specified in the 2411 # .proto file). 
2412 serialized = proto.SerializeToString() 2413 2414 raw = unittest_mset_pb2.RawMessageSet() 2415 self.assertEqual(False, 2416 raw.DESCRIPTOR.GetOptions().message_set_wire_format) 2417 self.assertEqual( 2418 len(serialized), 2419 raw.MergeFromString(serialized)) 2420 self.assertEqual(3, len(raw.item)) 2421 2422 message1 = message_set_extensions_pb2.TestMessageSetExtension1() 2423 self.assertEqual( 2424 len(raw.item[0].message), 2425 message1.MergeFromString(raw.item[0].message)) 2426 self.assertEqual(123, message1.i) 2427 2428 message2 = message_set_extensions_pb2.TestMessageSetExtension2() 2429 self.assertEqual( 2430 len(raw.item[1].message), 2431 message2.MergeFromString(raw.item[1].message)) 2432 self.assertEqual('foo', message2.str) 2433 2434 message3 = message_set_extensions_pb2.TestMessageSetExtension3() 2435 self.assertEqual( 2436 len(raw.item[2].message), 2437 message3.MergeFromString(raw.item[2].message)) 2438 self.assertEqual('bar', message3.text) 2439 2440 # Deserialize using the MessageSet wire format. 2441 proto2 = message_set_extensions_pb2.TestMessageSet() 2442 self.assertEqual( 2443 len(serialized), 2444 proto2.MergeFromString(serialized)) 2445 self.assertEqual(123, proto2.Extensions[extension1].i) 2446 self.assertEqual('foo', proto2.Extensions[extension2].str) 2447 self.assertEqual('bar', proto2.Extensions[extension3].text) 2448 2449 # Check byte size. 2450 self.assertEqual(proto2.ByteSize(), len(serialized)) 2451 self.assertEqual(proto.ByteSize(), len(serialized)) 2452 2453 def testMessageSetWireFormatUnknownExtension(self): 2454 # Create a message using the message set wire format with an unknown 2455 # message. 2456 raw = unittest_mset_pb2.RawMessageSet() 2457 2458 # Add an item. 
2459 item = raw.item.add() 2460 item.type_id = 98418603 2461 extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 2462 message1 = message_set_extensions_pb2.TestMessageSetExtension1() 2463 message1.i = 12345 2464 item.message = message1.SerializeToString() 2465 2466 # Add a second, unknown extension. 2467 item = raw.item.add() 2468 item.type_id = 98418604 2469 extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 2470 message1 = message_set_extensions_pb2.TestMessageSetExtension1() 2471 message1.i = 12346 2472 item.message = message1.SerializeToString() 2473 2474 # Add another unknown extension. 2475 item = raw.item.add() 2476 item.type_id = 98418605 2477 message1 = message_set_extensions_pb2.TestMessageSetExtension2() 2478 message1.str = 'foo' 2479 item.message = message1.SerializeToString() 2480 2481 serialized = raw.SerializeToString() 2482 2483 # Parse message using the message set wire format. 2484 proto = message_set_extensions_pb2.TestMessageSet() 2485 self.assertEqual( 2486 len(serialized), 2487 proto.MergeFromString(serialized)) 2488 2489 # Check that the message parsed well. 2490 extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 2491 extension1 = extension_message1.message_set_extension 2492 self.assertEqual(12345, proto.Extensions[extension1].i) 2493 2494 def testUnknownFields(self): 2495 proto = unittest_pb2.TestAllTypes() 2496 test_util.SetAllFields(proto) 2497 2498 serialized = proto.SerializeToString() 2499 2500 # The empty message should be parsable with all of the fields 2501 # unknown. 2502 proto2 = unittest_pb2.TestEmptyMessage() 2503 2504 # Parsing this message should succeed. 2505 self.assertEqual( 2506 len(serialized), 2507 proto2.MergeFromString(serialized)) 2508 2509 # Now test with a int64 field set. 
2510 proto = unittest_pb2.TestAllTypes() 2511 proto.optional_int64 = 0x0fffffffffffffff 2512 serialized = proto.SerializeToString() 2513 # The empty message should be parsable with all of the fields 2514 # unknown. 2515 proto2 = unittest_pb2.TestEmptyMessage() 2516 # Parsing this message should succeed. 2517 self.assertEqual( 2518 len(serialized), 2519 proto2.MergeFromString(serialized)) 2520 2521 def _CheckRaises(self, exc_class, callable_obj, exception): 2522 """This method checks if the excpetion type and message are as expected.""" 2523 try: 2524 callable_obj() 2525 except exc_class as ex: 2526 # Check if the exception message is the right one. 2527 self.assertEqual(exception, str(ex)) 2528 return 2529 else: 2530 raise self.failureException('%s not raised' % str(exc_class)) 2531 2532 def testSerializeUninitialized(self): 2533 proto = unittest_pb2.TestRequired() 2534 self._CheckRaises( 2535 message.EncodeError, 2536 proto.SerializeToString, 2537 'Message protobuf_unittest.TestRequired is missing required fields: ' 2538 'a,b,c') 2539 # Shouldn't raise exceptions. 2540 partial = proto.SerializePartialToString() 2541 2542 proto2 = unittest_pb2.TestRequired() 2543 self.assertFalse(proto2.HasField('a')) 2544 # proto2 ParseFromString does not check that required fields are set. 2545 proto2.ParseFromString(partial) 2546 self.assertFalse(proto2.HasField('a')) 2547 2548 proto.a = 1 2549 self._CheckRaises( 2550 message.EncodeError, 2551 proto.SerializeToString, 2552 'Message protobuf_unittest.TestRequired is missing required fields: b,c') 2553 # Shouldn't raise exceptions. 2554 partial = proto.SerializePartialToString() 2555 2556 proto.b = 2 2557 self._CheckRaises( 2558 message.EncodeError, 2559 proto.SerializeToString, 2560 'Message protobuf_unittest.TestRequired is missing required fields: c') 2561 # Shouldn't raise exceptions. 
2562 partial = proto.SerializePartialToString() 2563 2564 proto.c = 3 2565 serialized = proto.SerializeToString() 2566 # Shouldn't raise exceptions. 2567 partial = proto.SerializePartialToString() 2568 2569 proto2 = unittest_pb2.TestRequired() 2570 self.assertEqual( 2571 len(serialized), 2572 proto2.MergeFromString(serialized)) 2573 self.assertEqual(1, proto2.a) 2574 self.assertEqual(2, proto2.b) 2575 self.assertEqual(3, proto2.c) 2576 self.assertEqual( 2577 len(partial), 2578 proto2.MergeFromString(partial)) 2579 self.assertEqual(1, proto2.a) 2580 self.assertEqual(2, proto2.b) 2581 self.assertEqual(3, proto2.c) 2582 2583 def testSerializeUninitializedSubMessage(self): 2584 proto = unittest_pb2.TestRequiredForeign() 2585 2586 # Sub-message doesn't exist yet, so this succeeds. 2587 proto.SerializeToString() 2588 2589 proto.optional_message.a = 1 2590 self._CheckRaises( 2591 message.EncodeError, 2592 proto.SerializeToString, 2593 'Message protobuf_unittest.TestRequiredForeign ' 2594 'is missing required fields: ' 2595 'optional_message.b,optional_message.c') 2596 2597 proto.optional_message.b = 2 2598 proto.optional_message.c = 3 2599 proto.SerializeToString() 2600 2601 proto.repeated_message.add().a = 1 2602 proto.repeated_message.add().b = 2 2603 self._CheckRaises( 2604 message.EncodeError, 2605 proto.SerializeToString, 2606 'Message protobuf_unittest.TestRequiredForeign is missing required fields: ' 2607 'repeated_message[0].b,repeated_message[0].c,' 2608 'repeated_message[1].a,repeated_message[1].c') 2609 2610 proto.repeated_message[0].b = 2 2611 proto.repeated_message[0].c = 3 2612 proto.repeated_message[1].a = 1 2613 proto.repeated_message[1].c = 3 2614 proto.SerializeToString() 2615 2616 def testSerializeAllPackedFields(self): 2617 first_proto = unittest_pb2.TestPackedTypes() 2618 second_proto = unittest_pb2.TestPackedTypes() 2619 test_util.SetAllPackedFields(first_proto) 2620 serialized = first_proto.SerializeToString() 2621 
self.assertEqual(first_proto.ByteSize(), len(serialized)) 2622 bytes_read = second_proto.MergeFromString(serialized) 2623 self.assertEqual(second_proto.ByteSize(), bytes_read) 2624 self.assertEqual(first_proto, second_proto) 2625 2626 def testSerializeAllPackedExtensions(self): 2627 first_proto = unittest_pb2.TestPackedExtensions() 2628 second_proto = unittest_pb2.TestPackedExtensions() 2629 test_util.SetAllPackedExtensions(first_proto) 2630 serialized = first_proto.SerializeToString() 2631 bytes_read = second_proto.MergeFromString(serialized) 2632 self.assertEqual(second_proto.ByteSize(), bytes_read) 2633 self.assertEqual(first_proto, second_proto) 2634 2635 def testMergePackedFromStringWhenSomeFieldsAlreadySet(self): 2636 first_proto = unittest_pb2.TestPackedTypes() 2637 first_proto.packed_int32.extend([1, 2]) 2638 first_proto.packed_double.append(3.0) 2639 serialized = first_proto.SerializeToString() 2640 2641 second_proto = unittest_pb2.TestPackedTypes() 2642 second_proto.packed_int32.append(3) 2643 second_proto.packed_double.extend([1.0, 2.0]) 2644 second_proto.packed_sint32.append(4) 2645 2646 self.assertEqual( 2647 len(serialized), 2648 second_proto.MergeFromString(serialized)) 2649 self.assertEqual([3, 1, 2], second_proto.packed_int32) 2650 self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) 2651 self.assertEqual([4], second_proto.packed_sint32) 2652 2653 def testPackedFieldsWireFormat(self): 2654 proto = unittest_pb2.TestPackedTypes() 2655 proto.packed_int32.extend([1, 2, 150, 3]) # 1 + 1 + 2 + 1 bytes 2656 proto.packed_double.extend([1.0, 1000.0]) # 8 + 8 bytes 2657 proto.packed_float.append(2.0) # 4 bytes, will be before double 2658 serialized = proto.SerializeToString() 2659 self.assertEqual(proto.ByteSize(), len(serialized)) 2660 d = _MiniDecoder(serialized) 2661 ReadTag = d.ReadFieldNumberAndWireType 2662 self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) 2663 self.assertEqual(1+1+1+2, d.ReadInt32()) 2664 
self.assertEqual(1, d.ReadInt32()) 2665 self.assertEqual(2, d.ReadInt32()) 2666 self.assertEqual(150, d.ReadInt32()) 2667 self.assertEqual(3, d.ReadInt32()) 2668 self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) 2669 self.assertEqual(4, d.ReadInt32()) 2670 self.assertEqual(2.0, d.ReadFloat()) 2671 self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) 2672 self.assertEqual(8+8, d.ReadInt32()) 2673 self.assertEqual(1.0, d.ReadDouble()) 2674 self.assertEqual(1000.0, d.ReadDouble()) 2675 self.assertTrue(d.EndOfStream()) 2676 2677 def testParsePackedFromUnpacked(self): 2678 unpacked = unittest_pb2.TestUnpackedTypes() 2679 test_util.SetAllUnpackedFields(unpacked) 2680 packed = unittest_pb2.TestPackedTypes() 2681 serialized = unpacked.SerializeToString() 2682 self.assertEqual( 2683 len(serialized), 2684 packed.MergeFromString(serialized)) 2685 expected = unittest_pb2.TestPackedTypes() 2686 test_util.SetAllPackedFields(expected) 2687 self.assertEqual(expected, packed) 2688 2689 def testParseUnpackedFromPacked(self): 2690 packed = unittest_pb2.TestPackedTypes() 2691 test_util.SetAllPackedFields(packed) 2692 unpacked = unittest_pb2.TestUnpackedTypes() 2693 serialized = packed.SerializeToString() 2694 self.assertEqual( 2695 len(serialized), 2696 unpacked.MergeFromString(serialized)) 2697 expected = unittest_pb2.TestUnpackedTypes() 2698 test_util.SetAllUnpackedFields(expected) 2699 self.assertEqual(expected, unpacked) 2700 2701 def testFieldNumbers(self): 2702 proto = unittest_pb2.TestAllTypes() 2703 self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1) 2704 self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1) 2705 self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16) 2706 self.assertEqual( 2707 unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18) 2708 self.assertEqual( 2709 unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21) 2710 
self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31) 2711 self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46) 2712 self.assertEqual( 2713 unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48) 2714 self.assertEqual( 2715 unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51) 2716 2717 def testExtensionFieldNumbers(self): 2718 self.assertEqual(unittest_pb2.TestRequired.single.number, 1000) 2719 self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000) 2720 self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001) 2721 self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001) 2722 self.assertEqual(unittest_pb2.optional_int32_extension.number, 1) 2723 self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1) 2724 self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16) 2725 self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16) 2726 self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18) 2727 self.assertEqual( 2728 unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18) 2729 self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21) 2730 self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 2731 21) 2732 self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31) 2733 self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31) 2734 self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46) 2735 self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46) 2736 self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48) 2737 self.assertEqual( 2738 unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48) 2739 self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51) 2740 self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 2741 51) 2742 2743 def testInitKwargs(self): 2744 
proto = unittest_pb2.TestAllTypes( 2745 optional_int32=1, 2746 optional_string='foo', 2747 optional_bool=True, 2748 optional_bytes=b'bar', 2749 optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), 2750 optional_foreign_message=unittest_pb2.ForeignMessage(c=1), 2751 optional_nested_enum=unittest_pb2.TestAllTypes.FOO, 2752 optional_foreign_enum=unittest_pb2.FOREIGN_FOO, 2753 repeated_int32=[1, 2, 3]) 2754 self.assertTrue(proto.IsInitialized()) 2755 self.assertTrue(proto.HasField('optional_int32')) 2756 self.assertTrue(proto.HasField('optional_string')) 2757 self.assertTrue(proto.HasField('optional_bool')) 2758 self.assertTrue(proto.HasField('optional_bytes')) 2759 self.assertTrue(proto.HasField('optional_nested_message')) 2760 self.assertTrue(proto.HasField('optional_foreign_message')) 2761 self.assertTrue(proto.HasField('optional_nested_enum')) 2762 self.assertTrue(proto.HasField('optional_foreign_enum')) 2763 self.assertEqual(1, proto.optional_int32) 2764 self.assertEqual('foo', proto.optional_string) 2765 self.assertEqual(True, proto.optional_bool) 2766 self.assertEqual(b'bar', proto.optional_bytes) 2767 self.assertEqual(1, proto.optional_nested_message.bb) 2768 self.assertEqual(1, proto.optional_foreign_message.c) 2769 self.assertEqual(unittest_pb2.TestAllTypes.FOO, 2770 proto.optional_nested_enum) 2771 self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum) 2772 self.assertEqual([1, 2, 3], proto.repeated_int32) 2773 2774 def testInitArgsUnknownFieldName(self): 2775 def InitalizeEmptyMessageWithExtraKeywordArg(): 2776 unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown') 2777 self._CheckRaises( 2778 ValueError, 2779 InitalizeEmptyMessageWithExtraKeywordArg, 2780 'Protocol message TestEmptyMessage has no "unknown" field.') 2781 2782 def testInitRequiredKwargs(self): 2783 proto = unittest_pb2.TestRequired(a=1, b=1, c=1) 2784 self.assertTrue(proto.IsInitialized()) 2785 self.assertTrue(proto.HasField('a')) 2786 
self.assertTrue(proto.HasField('b')) 2787 self.assertTrue(proto.HasField('c')) 2788 self.assertTrue(not proto.HasField('dummy2')) 2789 self.assertEqual(1, proto.a) 2790 self.assertEqual(1, proto.b) 2791 self.assertEqual(1, proto.c) 2792 2793 def testInitRequiredForeignKwargs(self): 2794 proto = unittest_pb2.TestRequiredForeign( 2795 optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1)) 2796 self.assertTrue(proto.IsInitialized()) 2797 self.assertTrue(proto.HasField('optional_message')) 2798 self.assertTrue(proto.optional_message.IsInitialized()) 2799 self.assertTrue(proto.optional_message.HasField('a')) 2800 self.assertTrue(proto.optional_message.HasField('b')) 2801 self.assertTrue(proto.optional_message.HasField('c')) 2802 self.assertTrue(not proto.optional_message.HasField('dummy2')) 2803 self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1), 2804 proto.optional_message) 2805 self.assertEqual(1, proto.optional_message.a) 2806 self.assertEqual(1, proto.optional_message.b) 2807 self.assertEqual(1, proto.optional_message.c) 2808 2809 def testInitRepeatedKwargs(self): 2810 proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3]) 2811 self.assertTrue(proto.IsInitialized()) 2812 self.assertEqual(1, proto.repeated_int32[0]) 2813 self.assertEqual(2, proto.repeated_int32[1]) 2814 self.assertEqual(3, proto.repeated_int32[2]) 2815 2816 2817class OptionsTest(unittest.TestCase): 2818 2819 def testMessageOptions(self): 2820 proto = message_set_extensions_pb2.TestMessageSet() 2821 self.assertEqual(True, 2822 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 2823 proto = unittest_pb2.TestAllTypes() 2824 self.assertEqual(False, 2825 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 2826 2827 def testPackedOptions(self): 2828 proto = unittest_pb2.TestAllTypes() 2829 proto.optional_int32 = 1 2830 proto.optional_double = 3.0 2831 for field_descriptor, _ in proto.ListFields(): 2832 self.assertEqual(False, field_descriptor.GetOptions().packed) 2833 2834 proto = 
unittest_pb2.TestPackedTypes() 2835 proto.packed_int32.append(1) 2836 proto.packed_double.append(3.0) 2837 for field_descriptor, _ in proto.ListFields(): 2838 self.assertEqual(True, field_descriptor.GetOptions().packed) 2839 self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED, 2840 field_descriptor.label) 2841 2842 2843 2844class ClassAPITest(unittest.TestCase): 2845 2846 @unittest.skipIf( 2847 api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, 2848 'C++ implementation requires a call to MakeDescriptor()') 2849 def testMakeClassWithNestedDescriptor(self): 2850 leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', 2851 containing_type=None, fields=[], 2852 nested_types=[], enum_types=[], 2853 extensions=[]) 2854 child_desc = descriptor.Descriptor('child', 'package.parent.child', '', 2855 containing_type=None, fields=[], 2856 nested_types=[leaf_desc], enum_types=[], 2857 extensions=[]) 2858 sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling', 2859 '', containing_type=None, fields=[], 2860 nested_types=[], enum_types=[], 2861 extensions=[]) 2862 parent_desc = descriptor.Descriptor('parent', 'package.parent', '', 2863 containing_type=None, fields=[], 2864 nested_types=[child_desc, sibling_desc], 2865 enum_types=[], extensions=[]) 2866 message_class = reflection.MakeClass(parent_desc) 2867 self.assertIn('child', message_class.__dict__) 2868 self.assertIn('sibling', message_class.__dict__) 2869 self.assertIn('leaf', message_class.child.__dict__) 2870 2871 def _GetSerializedFileDescriptor(self, name): 2872 """Get a serialized representation of a test FileDescriptorProto. 2873 2874 Args: 2875 name: All calls to this must use a unique message name, to avoid 2876 collisions in the cpp descriptor pool. 2877 Returns: 2878 A string containing the serialized form of a test FileDescriptorProto. 
2879 """ 2880 file_descriptor_str = ( 2881 'message_type {' 2882 ' name: "' + name + '"' 2883 ' field {' 2884 ' name: "flat"' 2885 ' number: 1' 2886 ' label: LABEL_REPEATED' 2887 ' type: TYPE_UINT32' 2888 ' }' 2889 ' field {' 2890 ' name: "bar"' 2891 ' number: 2' 2892 ' label: LABEL_OPTIONAL' 2893 ' type: TYPE_MESSAGE' 2894 ' type_name: "Bar"' 2895 ' }' 2896 ' nested_type {' 2897 ' name: "Bar"' 2898 ' field {' 2899 ' name: "baz"' 2900 ' number: 3' 2901 ' label: LABEL_OPTIONAL' 2902 ' type: TYPE_MESSAGE' 2903 ' type_name: "Baz"' 2904 ' }' 2905 ' nested_type {' 2906 ' name: "Baz"' 2907 ' enum_type {' 2908 ' name: "deep_enum"' 2909 ' value {' 2910 ' name: "VALUE_A"' 2911 ' number: 0' 2912 ' }' 2913 ' }' 2914 ' field {' 2915 ' name: "deep"' 2916 ' number: 4' 2917 ' label: LABEL_OPTIONAL' 2918 ' type: TYPE_UINT32' 2919 ' }' 2920 ' }' 2921 ' }' 2922 '}') 2923 file_descriptor = descriptor_pb2.FileDescriptorProto() 2924 text_format.Merge(file_descriptor_str, file_descriptor) 2925 return file_descriptor.SerializeToString() 2926 2927 def testParsingFlatClassWithExplicitClassDeclaration(self): 2928 """Test that the generated class can parse a flat message.""" 2929 # TODO(xiaofeng): This test fails with cpp implemetnation in the call 2930 # of six.with_metaclass(). The other two callsites of with_metaclass 2931 # in this file are both excluded from cpp test, so it might be expected 2932 # to fail. Need someone more familiar with the python code to take a 2933 # look at this. 
2934 if api_implementation.Type() != 'python': 2935 return 2936 file_descriptor = descriptor_pb2.FileDescriptorProto() 2937 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) 2938 msg_descriptor = descriptor.MakeDescriptor( 2939 file_descriptor.message_type[0]) 2940 2941 class MessageClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): 2942 DESCRIPTOR = msg_descriptor 2943 msg = MessageClass() 2944 msg_str = ( 2945 'flat: 0 ' 2946 'flat: 1 ' 2947 'flat: 2 ') 2948 text_format.Merge(msg_str, msg) 2949 self.assertEqual(msg.flat, [0, 1, 2]) 2950 2951 def testParsingFlatClass(self): 2952 """Test that the generated class can parse a flat message.""" 2953 file_descriptor = descriptor_pb2.FileDescriptorProto() 2954 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) 2955 msg_descriptor = descriptor.MakeDescriptor( 2956 file_descriptor.message_type[0]) 2957 msg_class = reflection.MakeClass(msg_descriptor) 2958 msg = msg_class() 2959 msg_str = ( 2960 'flat: 0 ' 2961 'flat: 1 ' 2962 'flat: 2 ') 2963 text_format.Merge(msg_str, msg) 2964 self.assertEqual(msg.flat, [0, 1, 2]) 2965 2966 def testParsingNestedClass(self): 2967 """Test that the generated class can parse a nested message.""" 2968 file_descriptor = descriptor_pb2.FileDescriptorProto() 2969 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) 2970 msg_descriptor = descriptor.MakeDescriptor( 2971 file_descriptor.message_type[0]) 2972 msg_class = reflection.MakeClass(msg_descriptor) 2973 msg = msg_class() 2974 msg_str = ( 2975 'bar {' 2976 ' baz {' 2977 ' deep: 4' 2978 ' }' 2979 '}') 2980 text_format.Merge(msg_str, msg) 2981 self.assertEqual(msg.bar.baz.deep, 4) 2982 2983if __name__ == '__main__': 2984 unittest.main() 2985