#! /usr/bin/python
# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Unittest for reflection.py, which also indirectly tests the output of the
pure-Python protocol compiler.
"""

__author__ = 'robinson@google.com (Will Robinson)'

import copy
import gc
import operator
import struct

from google.apputils import basetest
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import reflection
from google.protobuf import text_format
from google.protobuf.internal import api_implementation
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import wire_format
from google.protobuf.internal import test_util
from google.protobuf.internal import decoder


class _MiniDecoder(object):
  """Decodes a stream of values from a string.

  Once upon a time we actually had a class called decoder.Decoder.  Then we
  got rid of it during a redesign that made decoding much, much faster overall.
  But a couple tests in this file used it to check that the serialized form of
  a message was correct.  So, this class implements just the methods that were
  used by said tests, so that we don't have to rewrite the tests.
  """

  def __init__(self, bytes):
    # The parameter name shadows the builtin but is kept for caller
    # compatibility; it is only read here.
    self._bytes = bytes
    self._pos = 0

  def ReadVarint(self):
    """Decodes one varint at the cursor and advances past it."""
    result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
    return result

  ReadInt32 = ReadVarint
  ReadInt64 = ReadVarint
  ReadUInt32 = ReadVarint
  ReadUInt64 = ReadVarint

  def ReadSInt64(self):
    """Decodes one ZigZag-encoded varint at the cursor."""
    return wire_format.ZigZagDecode(self.ReadVarint())

  ReadSInt32 = ReadSInt64

  def ReadFieldNumberAndWireType(self):
    """Decodes a tag varint and returns wire_format.UnpackTag's result."""
    return wire_format.UnpackTag(self.ReadVarint())

  def _ReadFixed(self, fmt, size):
    """Unpacks one little-endian fixed-width value and advances the cursor.

    Raises:
      struct.error: if fewer than `size` bytes remain.
    """
    result = struct.unpack(fmt, self._bytes[self._pos:self._pos + size])[0]
    self._pos += size
    return result

  def ReadFloat(self):
    """Decodes a 4-byte little-endian float at the cursor."""
    return self._ReadFixed("<f", 4)

  def ReadDouble(self):
    """Decodes an 8-byte little-endian double at the cursor."""
    return self._ReadFixed("<d", 8)

  def EndOfStream(self):
    """Returns True once every byte of the input has been consumed."""
    return self._pos == len(self._bytes)


class ReflectionTest(basetest.TestCase):

  def assertListsEqual(self, values, others):
    """Asserts that two sequences have equal length and equal elements."""
    self.assertEqual(len(values), len(others))
    # Lengths are already known to match, so zip() cannot truncate.
    for value, other in zip(values, others):
      self.assertEqual(value, other)

  def testScalarConstructor(self):
    # Constructor with only scalar types should succeed.
    proto = unittest_pb2.TestAllTypes(
        optional_int32=24,
        optional_double=54.321,
        optional_string='optional_string')

    self.assertEqual(24, proto.optional_int32)
    self.assertEqual(54.321, proto.optional_double)
    self.assertEqual('optional_string', proto.optional_string)

  def testRepeatedScalarConstructor(self):
    # Constructor with only repeated scalar types should succeed.
    proto = unittest_pb2.TestAllTypes(
        repeated_int32=[1, 2, 3, 4],
        repeated_double=[1.23, 54.321],
        repeated_bool=[True, False, False],
        repeated_string=["optional_string"])

    self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
    self.assertEqual([True, False, False], list(proto.repeated_bool))
    self.assertEqual(["optional_string"], list(proto.repeated_string))

  def testRepeatedCompositeConstructor(self):
    # Constructor with only repeated composite types should succeed.
    proto = unittest_pb2.TestAllTypes(
        repeated_nested_message=[
            unittest_pb2.TestAllTypes.NestedMessage(
                bb=unittest_pb2.TestAllTypes.FOO),
            unittest_pb2.TestAllTypes.NestedMessage(
                bb=unittest_pb2.TestAllTypes.BAR)],
        repeated_foreign_message=[
            unittest_pb2.ForeignMessage(c=-43),
            unittest_pb2.ForeignMessage(c=45324),
            unittest_pb2.ForeignMessage(c=12)],
        repeatedgroup=[
            unittest_pb2.TestAllTypes.RepeatedGroup(),
            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])

    self.assertEqual(
        [unittest_pb2.TestAllTypes.NestedMessage(
            bb=unittest_pb2.TestAllTypes.FOO),
         unittest_pb2.TestAllTypes.NestedMessage(
             bb=unittest_pb2.TestAllTypes.BAR)],
        list(proto.repeated_nested_message))
    self.assertEqual(
        [unittest_pb2.ForeignMessage(c=-43),
         unittest_pb2.ForeignMessage(c=45324),
         unittest_pb2.ForeignMessage(c=12)],
        list(proto.repeated_foreign_message))
    self.assertEqual(
        [unittest_pb2.TestAllTypes.RepeatedGroup(),
         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
        list(proto.repeatedgroup))

  def testMixedConstructor(self):
    # Constructor with a mix of scalar, repeated scalar and repeated
    # composite fields should succeed.
    proto = unittest_pb2.TestAllTypes(
        optional_int32=24,
        optional_string='optional_string',
        repeated_double=[1.23, 54.321],
        repeated_bool=[True, False, False],
        repeated_nested_message=[
            unittest_pb2.TestAllTypes.NestedMessage(
                bb=unittest_pb2.TestAllTypes.FOO),
            unittest_pb2.TestAllTypes.NestedMessage(
                bb=unittest_pb2.TestAllTypes.BAR)],
        repeated_foreign_message=[
            unittest_pb2.ForeignMessage(c=-43),
            unittest_pb2.ForeignMessage(c=45324),
            unittest_pb2.ForeignMessage(c=12)])

    self.assertEqual(24, proto.optional_int32)
    self.assertEqual('optional_string', proto.optional_string)
    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
    self.assertEqual([True, False, False], list(proto.repeated_bool))
    self.assertEqual(
        [unittest_pb2.TestAllTypes.NestedMessage(
            bb=unittest_pb2.TestAllTypes.FOO),
         unittest_pb2.TestAllTypes.NestedMessage(
             bb=unittest_pb2.TestAllTypes.BAR)],
        list(proto.repeated_nested_message))
    self.assertEqual(
        [unittest_pb2.ForeignMessage(c=-43),
         unittest_pb2.ForeignMessage(c=45324),
         unittest_pb2.ForeignMessage(c=12)],
        list(proto.repeated_foreign_message))

  def testConstructorTypeError(self):
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, optional_int32="foo")
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, optional_string=1234)
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234)
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234)
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"])
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, repeated_string=1234)
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234])
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234)
    self.assertRaises(
        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234])

  def testConstructorInvalidatesCachedByteSize(self):
    # Named `msg` rather than `message` so the imported `message` module is
    # not shadowed inside this test.
    msg = unittest_pb2.TestAllTypes(optional_int32=12)
    self.assertEqual(2, msg.ByteSize())

    msg = unittest_pb2.TestAllTypes(
        optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage())
    self.assertEqual(3, msg.ByteSize())

    msg = unittest_pb2.TestAllTypes(repeated_int32=[12])
    self.assertEqual(3, msg.ByteSize())

    msg = unittest_pb2.TestAllTypes(
        repeated_nested_message=[unittest_pb2.TestAllTypes.NestedMessage()])
    self.assertEqual(3, msg.ByteSize())

  def testSimpleHasBits(self):
    # Test a scalar.
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.HasField('optional_int32'))
    self.assertEqual(0, proto.optional_int32)
    # HasField() shouldn't be true if all we've done is
    # read the default value.
    self.assertFalse(proto.HasField('optional_int32'))
    proto.optional_int32 = 1
    # Setting a value however *should* set the "has" bit.
    self.assertTrue(proto.HasField('optional_int32'))
    proto.ClearField('optional_int32')
    # And clearing that value should unset the "has" bit.
    self.assertFalse(proto.HasField('optional_int32'))

  def testHasBitsWithSinglyNestedScalar(self):
    # Helper used to test foreign messages and groups.
    #
    # composite_field_name should be the name of a non-repeated
    # composite (i.e., foreign or group) field in TestAllTypes,
    # and scalar_field_name should be the name of an integer-valued
    # scalar field within that composite.
    #
    # I never thought I'd miss C++ macros and templates so much.  :(
    # This helper is semantically just:
    #
    #   assert proto.composite_field.scalar_field == 0
    #   assert not proto.composite_field.HasField('scalar_field')
    #   assert not proto.HasField('composite_field')
    #
    #   proto.composite_field.scalar_field = 20
    #   old_composite_field = proto.composite_field
    #
    #   assert proto.composite_field.scalar_field == 20
    #   assert proto.composite_field.HasField('scalar_field')
    #   assert proto.HasField('composite_field')
    #
    #   proto.ClearField('composite_field')
    #
    #   assert not proto.composite_field.HasField('scalar_field')
    #   assert not proto.HasField('composite_field')
    #   assert proto.composite_field.scalar_field == 0
    #
    #   # Now ensure that ClearField('composite_field') disconnected
    #   # the old field object from the object tree...
    #   assert old_composite_field is not proto.composite_field
    #   old_composite_field.scalar_field = 20
    #   assert not proto.composite_field.HasField('scalar_field')
    #   assert not proto.HasField('composite_field')
    def TestCompositeHasBits(composite_field_name, scalar_field_name):
      proto = unittest_pb2.TestAllTypes()
      # First, check that we can get the scalar value, and see that it's the
      # default (0), but that proto.HasField('composite') and
      # proto.composite.HasField('scalar') will still return False.
      composite_field = getattr(proto, composite_field_name)
      original_scalar_value = getattr(composite_field, scalar_field_name)
      self.assertEqual(0, original_scalar_value)
      # Assert that the composite object does not "have" the scalar.
      self.assertFalse(composite_field.HasField(scalar_field_name))
      # Assert that proto does not "have" the composite field.
      self.assertFalse(proto.HasField(composite_field_name))

      # Now set the scalar within the composite field.  Ensure that the setting
      # is reflected, and that proto.HasField('composite') and
      # proto.composite.HasField('scalar') now both return True.
      new_val = 20
      setattr(composite_field, scalar_field_name, new_val)
      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
      # Hold on to a reference to the current composite_field object.
      old_composite_field = composite_field
      # Assert that the has methods now return true.
      self.assertTrue(composite_field.HasField(scalar_field_name))
      self.assertTrue(proto.HasField(composite_field_name))

      # Now call the clear method...
      proto.ClearField(composite_field_name)

      # ...and ensure that the "has" bits are all back to False...
      composite_field = getattr(proto, composite_field_name)
      self.assertFalse(composite_field.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))
      # ...and ensure that the scalar field has returned to its default.
      self.assertEqual(0, getattr(composite_field, scalar_field_name))

      # Writing through the stale reference must not affect the parent.
      self.assertTrue(old_composite_field is not composite_field)
      setattr(old_composite_field, scalar_field_name, new_val)
      self.assertFalse(composite_field.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))
      self.assertEqual(0, getattr(composite_field, scalar_field_name))

    # Test simple, single-level nesting when we set a scalar.
    TestCompositeHasBits('optionalgroup', 'a')
    TestCompositeHasBits('optional_nested_message', 'bb')
    TestCompositeHasBits('optional_foreign_message', 'c')
    TestCompositeHasBits('optional_import_message', 'd')

  def testReferencesToNestedMessage(self):
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    del proto
    # A previous version had a bug where this would raise an exception when
    # hitting a now-dead weak reference.
    nested.bb = 23

  def testDisconnectingNestedMessageBeforeSettingField(self):
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    proto.ClearField('optional_nested_message')  # Should disconnect from parent
    self.assertTrue(nested is not proto.optional_nested_message)
    nested.bb = 23
    self.assertFalse(proto.HasField('optional_nested_message'))
    self.assertEqual(0, proto.optional_nested_message.bb)

  def testGetDefaultMessageAfterDisconnectingDefaultMessage(self):
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    proto.ClearField('optional_nested_message')
    del proto
    del nested
    # Force a garbage collect so that the underlying CMessages are freed along
    # with the Messages they point to. This is to make sure we're not deleting
    # default message instances.
    gc.collect()
    proto = unittest_pb2.TestAllTypes()
    # Only the access itself is under test; the value is deliberately unused.
    proto.optional_nested_message

  def testDisconnectingNestedMessageAfterSettingField(self):
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    nested.bb = 5
    self.assertTrue(proto.HasField('optional_nested_message'))
    proto.ClearField('optional_nested_message')  # Should disconnect from parent
    self.assertEqual(5, nested.bb)
    self.assertEqual(0, proto.optional_nested_message.bb)
    self.assertTrue(nested is not proto.optional_nested_message)
    nested.bb = 23
    self.assertFalse(proto.HasField('optional_nested_message'))
    self.assertEqual(0, proto.optional_nested_message.bb)

  def testDisconnectingNestedMessageBeforeGettingField(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.HasField('optional_nested_message'))
    proto.ClearField('optional_nested_message')
    self.assertFalse(proto.HasField('optional_nested_message'))

  def testDisconnectingNestedMessageAfterMerge(self):
    # This test exercises the code path that does not use ReleaseMessage().
    # The underlying fear is that if we use ReleaseMessage() incorrectly,
    # we will have memory leaks.  It's hard to check that that doesn't happen,
    # but at least we can exercise that code path to make sure it works.
    proto1 = unittest_pb2.TestAllTypes()
    proto2 = unittest_pb2.TestAllTypes()
    proto2.optional_nested_message.bb = 5
    proto1.MergeFrom(proto2)
    self.assertTrue(proto1.HasField('optional_nested_message'))
    proto1.ClearField('optional_nested_message')
    self.assertFalse(proto1.HasField('optional_nested_message'))

  def testDisconnectingLazyNestedMessage(self):
    # This test exercises releasing a nested message that is lazy. This test
    # only exercises real code in the C++ implementation as Python does not
    # support lazy parsing, but the current C++ implementation results in
    # memory corruption and a crash.
    if api_implementation.Type() != 'python':
      return
    proto = unittest_pb2.TestAllTypes()
    proto.optional_lazy_message.bb = 5
    proto.ClearField('optional_lazy_message')
    del proto
    gc.collect()

  def testHasBitsWhenModifyingRepeatedFields(self):
    # Test nesting when we add an element to a repeated field in a submessage.
    proto = unittest_pb2.TestNestedMessageHasBits()
    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
    self.assertEqual(
        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
    self.assertTrue(proto.HasField('optional_nested_message'))

    # Do the same test, but with a repeated composite field within the
    # submessage.
    proto.ClearField('optional_nested_message')
    self.assertFalse(proto.HasField('optional_nested_message'))
    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
    self.assertTrue(proto.HasField('optional_nested_message'))

  def testHasBitsForManyLevelsOfNesting(self):
    # Test nesting many levels deep.
    recursive_proto = unittest_pb2.TestMutualRecursionA()
    self.assertFalse(recursive_proto.HasField('bb'))
    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
    self.assertFalse(recursive_proto.HasField('bb'))
    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
    self.assertTrue(recursive_proto.HasField('bb'))
    self.assertTrue(recursive_proto.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
    self.assertFalse(recursive_proto.bb.a.bb.a.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))

  def testSingularListFields(self):
    proto = unittest_pb2.TestAllTypes()
    proto.optional_fixed32 = 1
    proto.optional_int32 = 5
    proto.optional_string = 'foo'
    # Access sub-message but don't set it yet.
    nested_message = proto.optional_nested_message
    fields = proto.DESCRIPTOR.fields_by_name
    self.assertEqual(
        [(fields['optional_int32'], 5),
         (fields['optional_fixed32'], 1),
         (fields['optional_string'], 'foo')],
        proto.ListFields())

    proto.optional_nested_message.bb = 123
    self.assertEqual(
        [(fields['optional_int32'], 5),
         (fields['optional_fixed32'], 1),
         (fields['optional_string'], 'foo'),
         (fields['optional_nested_message'], nested_message)],
        proto.ListFields())

  def testRepeatedListFields(self):
    proto = unittest_pb2.TestAllTypes()
    proto.repeated_fixed32.append(1)
    proto.repeated_int32.append(5)
    proto.repeated_int32.append(11)
    proto.repeated_string.extend(['foo', 'bar'])
    proto.repeated_string.extend([])
    proto.repeated_string.append('baz')
    proto.repeated_string.extend(str(x) for x in xrange(2))
    proto.optional_int32 = 21
    proto.repeated_bool  # Access but don't set anything; should not be listed.
    fields = proto.DESCRIPTOR.fields_by_name
    self.assertEqual(
        [(fields['optional_int32'], 21),
         (fields['repeated_int32'], [5, 11]),
         (fields['repeated_fixed32'], [1]),
         (fields['repeated_string'], ['foo', 'bar', 'baz', '0', '1'])],
        proto.ListFields())

  def testSingularListExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
    proto.Extensions[unittest_pb2.optional_int32_extension] = 5
    proto.Extensions[unittest_pb2.optional_string_extension] = 'foo'
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 5),
         (unittest_pb2.optional_fixed32_extension, 1),
         (unittest_pb2.optional_string_extension, 'foo')],
        proto.ListFields())

  def testRepeatedListExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
    proto.Extensions[unittest_pb2.repeated_int32_extension].append(5)
    proto.Extensions[unittest_pb2.repeated_int32_extension].append(11)
    proto.Extensions[unittest_pb2.repeated_string_extension].append('foo')
    proto.Extensions[unittest_pb2.repeated_string_extension].append('bar')
    proto.Extensions[unittest_pb2.repeated_string_extension].append('baz')
    proto.Extensions[unittest_pb2.optional_int32_extension] = 21
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 21),
         (unittest_pb2.repeated_int32_extension, [5, 11]),
         (unittest_pb2.repeated_fixed32_extension, [1]),
         (unittest_pb2.repeated_string_extension, ['foo', 'bar', 'baz'])],
        proto.ListFields())

  def testListFieldsAndExtensions(self):
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    # ListFields() must interleave regular fields and extensions in the order
    # asserted below (presumably field-number order -- see TestFieldOrderings
    # in unittest.proto).
    fields = proto.DESCRIPTOR.fields_by_name
    self.assertEqual(
        [(fields['my_int'], 1),
         (unittest_pb2.my_extension_int, 23),
         (fields['my_string'], 'foo'),
         (unittest_pb2.my_extension_string, 'bar'),
         (fields['my_float'], 1.0)],
        proto.ListFields())

  def testDefaultValues(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.optional_int32)
    self.assertEqual(0, proto.optional_int64)
    self.assertEqual(0, proto.optional_uint32)
    self.assertEqual(0, proto.optional_uint64)
    self.assertEqual(0, proto.optional_sint32)
    self.assertEqual(0, proto.optional_sint64)
    self.assertEqual(0, proto.optional_fixed32)
    self.assertEqual(0, proto.optional_fixed64)
    self.assertEqual(0, proto.optional_sfixed32)
    self.assertEqual(0, proto.optional_sfixed64)
    self.assertEqual(0.0, proto.optional_float)
    self.assertEqual(0.0, proto.optional_double)
    self.assertEqual(False, proto.optional_bool)
    self.assertEqual('', proto.optional_string)
    self.assertEqual(b'', proto.optional_bytes)

    self.assertEqual(41, proto.default_int32)
    self.assertEqual(42, proto.default_int64)
    self.assertEqual(43, proto.default_uint32)
    self.assertEqual(44, proto.default_uint64)
    self.assertEqual(-45, proto.default_sint32)
    self.assertEqual(46, proto.default_sint64)
    self.assertEqual(47, proto.default_fixed32)
    self.assertEqual(48, proto.default_fixed64)
    self.assertEqual(49, proto.default_sfixed32)
    self.assertEqual(-50, proto.default_sfixed64)
    self.assertEqual(51.5, proto.default_float)
    self.assertEqual(52e3, proto.default_double)
    self.assertEqual(True, proto.default_bool)
    self.assertEqual('hello', proto.default_string)
    self.assertEqual(b'world', proto.default_bytes)
    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
                     proto.default_import_enum)

    proto = unittest_pb2.TestExtremeDefaultValues()
    self.assertEqual(u'\u1234', proto.utf8_string)

  def testHasFieldWithUnknownFieldName(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')

  def testClearFieldWithUnknownFieldName(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')

  def testClearRemovesChildren(self):
    # Make sure there aren't any implementation bugs that are only partially
    # clearing the message (which can happen in the more complex C++
    # implementation which has parallel message lists).
    proto = unittest_pb2.TestRequiredForeign()
    for _ in range(10):
      proto.repeated_message.add()
    proto2 = unittest_pb2.TestRequiredForeign()
    proto.CopyFrom(proto2)
    self.assertRaises(IndexError, lambda: proto.repeated_message[5])

  def testDisallowedAssignments(self):
    # It's illegal to assign values directly to repeated fields
    # or to nonrepeated composite fields.  Ensure that this fails.
    proto = unittest_pb2.TestAllTypes()
    # Repeated fields.
    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
    # Lists shouldn't work, either.
    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
    # Composite fields.
    self.assertRaises(AttributeError, setattr, proto,
                      'optional_nested_message', 23)
    # Assignment to a repeated nested message field without specifying
    # the index in the array of nested messages.
    self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
                      'bb', 34)
    # Assignment to an attribute of a repeated field.
    self.assertRaises(AttributeError, setattr, proto.repeated_float,
                      'some_attribute', 34)
    # proto.nonexistent_field = 23 should fail as well.
    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)

  def testSingleScalarTypeSafety(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)

  def testIntegerTypes(self):
    def TestGetAndDeserialize(field_name, value, expected_type):
      proto = unittest_pb2.TestAllTypes()
      setattr(proto, field_name, value)
      self.assertTrue(isinstance(getattr(proto, field_name), expected_type))
      proto2 = unittest_pb2.TestAllTypes()
      proto2.ParseFromString(proto.SerializeToString())
      self.assertTrue(isinstance(getattr(proto2, field_name), expected_type))

    TestGetAndDeserialize('optional_int32', 1, int)
    TestGetAndDeserialize('optional_int32', 1 << 30, int)
    TestGetAndDeserialize('optional_uint32', 1 << 30, int)
    if struct.calcsize('L') == 4:
      # Python only has signed ints, so 32-bit python can't fit an uint32
      # in an int.
      TestGetAndDeserialize('optional_uint32', 1 << 31, long)
    else:
      # 64-bit python can fit uint32 inside an int
      TestGetAndDeserialize('optional_uint32', 1 << 31, int)
    TestGetAndDeserialize('optional_int64', 1 << 30, long)
    TestGetAndDeserialize('optional_int64', 1 << 60, long)
    TestGetAndDeserialize('optional_uint64', 1 << 30, long)
    TestGetAndDeserialize('optional_uint64', 1 << 60, long)

  def testSingleScalarBoundsChecking(self):
    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
      pb = unittest_pb2.TestAllTypes()
      setattr(pb, field_name, expected_min)
      self.assertEqual(expected_min, getattr(pb, field_name))
      setattr(pb, field_name, expected_max)
      self.assertEqual(expected_max, getattr(pb, field_name))
      self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
      self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)

    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)

    pb = unittest_pb2.TestAllTypes()
    pb.optional_nested_enum = 1
    self.assertEqual(1, pb.optional_nested_enum)

  def testRepeatedScalarTypeSafety(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
    # Must go through append(): calling the container object itself would
    # raise TypeError merely because it is not callable, so the assertion
    # would pass without exercising the type checker at all.
    self.assertRaises(TypeError, proto.repeated_string.append, 10)
    self.assertRaises(TypeError, proto.repeated_bytes.append, 10)

    proto.repeated_int32.append(10)
    proto.repeated_int32[0] = 23
    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')

    # Repeated enums tests.
    #proto.repeated_nested_enum.append(0)

  def testSingleScalarGettersAndSetters(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.optional_int32)
    proto.optional_int32 = 1
    self.assertEqual(1, proto.optional_int32)

    proto.optional_uint64 = 0xffffffffffff
    self.assertEqual(0xffffffffffff, proto.optional_uint64)
    proto.optional_uint64 = 0xffffffffffffffff
    self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
    # TODO(robinson): Test all other scalar field types.

  def testSingleScalarClearField(self):
    proto = unittest_pb2.TestAllTypes()
    # Should be allowed to clear something that's not there (a no-op).
    proto.ClearField('optional_int32')
    proto.optional_int32 = 1
    self.assertTrue(proto.HasField('optional_int32'))
    proto.ClearField('optional_int32')
    self.assertEqual(0, proto.optional_int32)
    self.assertFalse(proto.HasField('optional_int32'))
    # TODO(robinson): Test all other scalar field types.
690 691 def testEnums(self): 692 proto = unittest_pb2.TestAllTypes() 693 self.assertEqual(1, proto.FOO) 694 self.assertEqual(1, unittest_pb2.TestAllTypes.FOO) 695 self.assertEqual(2, proto.BAR) 696 self.assertEqual(2, unittest_pb2.TestAllTypes.BAR) 697 self.assertEqual(3, proto.BAZ) 698 self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) 699 700 def testEnum_Name(self): 701 self.assertEqual('FOREIGN_FOO', 702 unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO)) 703 self.assertEqual('FOREIGN_BAR', 704 unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR)) 705 self.assertEqual('FOREIGN_BAZ', 706 unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ)) 707 self.assertRaises(ValueError, 708 unittest_pb2.ForeignEnum.Name, 11312) 709 710 proto = unittest_pb2.TestAllTypes() 711 self.assertEqual('FOO', 712 proto.NestedEnum.Name(proto.FOO)) 713 self.assertEqual('FOO', 714 unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO)) 715 self.assertEqual('BAR', 716 proto.NestedEnum.Name(proto.BAR)) 717 self.assertEqual('BAR', 718 unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR)) 719 self.assertEqual('BAZ', 720 proto.NestedEnum.Name(proto.BAZ)) 721 self.assertEqual('BAZ', 722 unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ)) 723 self.assertRaises(ValueError, 724 proto.NestedEnum.Name, 11312) 725 self.assertRaises(ValueError, 726 unittest_pb2.TestAllTypes.NestedEnum.Name, 11312) 727 728 def testEnum_Value(self): 729 self.assertEqual(unittest_pb2.FOREIGN_FOO, 730 unittest_pb2.ForeignEnum.Value('FOREIGN_FOO')) 731 self.assertEqual(unittest_pb2.FOREIGN_BAR, 732 unittest_pb2.ForeignEnum.Value('FOREIGN_BAR')) 733 self.assertEqual(unittest_pb2.FOREIGN_BAZ, 734 unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ')) 735 self.assertRaises(ValueError, 736 unittest_pb2.ForeignEnum.Value, 'FO') 737 738 proto = unittest_pb2.TestAllTypes() 739 self.assertEqual(proto.FOO, 740 proto.NestedEnum.Value('FOO')) 741 self.assertEqual(proto.FOO, 742 unittest_pb2.TestAllTypes.NestedEnum.Value('FOO')) 
743 self.assertEqual(proto.BAR, 744 proto.NestedEnum.Value('BAR')) 745 self.assertEqual(proto.BAR, 746 unittest_pb2.TestAllTypes.NestedEnum.Value('BAR')) 747 self.assertEqual(proto.BAZ, 748 proto.NestedEnum.Value('BAZ')) 749 self.assertEqual(proto.BAZ, 750 unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ')) 751 self.assertRaises(ValueError, 752 proto.NestedEnum.Value, 'Foo') 753 self.assertRaises(ValueError, 754 unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo') 755 756 def testEnum_KeysAndValues(self): 757 self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'], 758 unittest_pb2.ForeignEnum.keys()) 759 self.assertEqual([4, 5, 6], 760 unittest_pb2.ForeignEnum.values()) 761 self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), 762 ('FOREIGN_BAZ', 6)], 763 unittest_pb2.ForeignEnum.items()) 764 765 proto = unittest_pb2.TestAllTypes() 766 self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys()) 767 self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values()) 768 self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], 769 proto.NestedEnum.items()) 770 771 def testRepeatedScalars(self): 772 proto = unittest_pb2.TestAllTypes() 773 774 self.assertTrue(not proto.repeated_int32) 775 self.assertEqual(0, len(proto.repeated_int32)) 776 proto.repeated_int32.append(5) 777 proto.repeated_int32.append(10) 778 proto.repeated_int32.append(15) 779 self.assertTrue(proto.repeated_int32) 780 self.assertEqual(3, len(proto.repeated_int32)) 781 782 self.assertEqual([5, 10, 15], proto.repeated_int32) 783 784 # Test single retrieval. 785 self.assertEqual(5, proto.repeated_int32[0]) 786 self.assertEqual(15, proto.repeated_int32[-1]) 787 # Test out-of-bounds indices. 788 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234) 789 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234) 790 # Test incorrect types passed to __getitem__. 
791 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo') 792 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None) 793 794 # Test single assignment. 795 proto.repeated_int32[1] = 20 796 self.assertEqual([5, 20, 15], proto.repeated_int32) 797 798 # Test insertion. 799 proto.repeated_int32.insert(1, 25) 800 self.assertEqual([5, 25, 20, 15], proto.repeated_int32) 801 802 # Test slice retrieval. 803 proto.repeated_int32.append(30) 804 self.assertEqual([25, 20, 15], proto.repeated_int32[1:4]) 805 self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) 806 807 # Test slice assignment with an iterator 808 proto.repeated_int32[1:4] = (i for i in xrange(3)) 809 self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) 810 811 # Test slice assignment. 812 proto.repeated_int32[1:4] = [35, 40, 45] 813 self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32) 814 815 # Test that we can use the field as an iterator. 816 result = [] 817 for i in proto.repeated_int32: 818 result.append(i) 819 self.assertEqual([5, 35, 40, 45, 30], result) 820 821 # Test single deletion. 822 del proto.repeated_int32[2] 823 self.assertEqual([5, 35, 45, 30], proto.repeated_int32) 824 825 # Test slice deletion. 826 del proto.repeated_int32[2:] 827 self.assertEqual([5, 35], proto.repeated_int32) 828 829 # Test extending. 830 proto.repeated_int32.extend([3, 13]) 831 self.assertEqual([5, 35, 3, 13], proto.repeated_int32) 832 833 # Test clearing. 834 proto.ClearField('repeated_int32') 835 self.assertTrue(not proto.repeated_int32) 836 self.assertEqual(0, len(proto.repeated_int32)) 837 838 proto.repeated_int32.append(1) 839 self.assertEqual(1, proto.repeated_int32[-1]) 840 # Test assignment to a negative index. 841 proto.repeated_int32[-1] = 2 842 self.assertEqual(2, proto.repeated_int32[-1]) 843 844 # Test deletion at negative indices. 
845 proto.repeated_int32[:] = [0, 1, 2, 3] 846 del proto.repeated_int32[-1] 847 self.assertEqual([0, 1, 2], proto.repeated_int32) 848 849 del proto.repeated_int32[-2] 850 self.assertEqual([0, 2], proto.repeated_int32) 851 852 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3) 853 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300) 854 855 del proto.repeated_int32[-2:-1] 856 self.assertEqual([2], proto.repeated_int32) 857 858 del proto.repeated_int32[100:10000] 859 self.assertEqual([2], proto.repeated_int32) 860 861 def testRepeatedScalarsRemove(self): 862 proto = unittest_pb2.TestAllTypes() 863 864 self.assertTrue(not proto.repeated_int32) 865 self.assertEqual(0, len(proto.repeated_int32)) 866 proto.repeated_int32.append(5) 867 proto.repeated_int32.append(10) 868 proto.repeated_int32.append(5) 869 proto.repeated_int32.append(5) 870 871 self.assertEqual(4, len(proto.repeated_int32)) 872 proto.repeated_int32.remove(5) 873 self.assertEqual(3, len(proto.repeated_int32)) 874 self.assertEqual(10, proto.repeated_int32[0]) 875 self.assertEqual(5, proto.repeated_int32[1]) 876 self.assertEqual(5, proto.repeated_int32[2]) 877 878 proto.repeated_int32.remove(5) 879 self.assertEqual(2, len(proto.repeated_int32)) 880 self.assertEqual(10, proto.repeated_int32[0]) 881 self.assertEqual(5, proto.repeated_int32[1]) 882 883 proto.repeated_int32.remove(10) 884 self.assertEqual(1, len(proto.repeated_int32)) 885 self.assertEqual(5, proto.repeated_int32[0]) 886 887 # Remove a non-existent element. 
888 self.assertRaises(ValueError, proto.repeated_int32.remove, 123) 889 890 def testRepeatedComposites(self): 891 proto = unittest_pb2.TestAllTypes() 892 self.assertTrue(not proto.repeated_nested_message) 893 self.assertEqual(0, len(proto.repeated_nested_message)) 894 m0 = proto.repeated_nested_message.add() 895 m1 = proto.repeated_nested_message.add() 896 self.assertTrue(proto.repeated_nested_message) 897 self.assertEqual(2, len(proto.repeated_nested_message)) 898 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 899 self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage)) 900 901 # Test out-of-bounds indices. 902 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 903 1234) 904 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 905 -1234) 906 907 # Test incorrect types passed to __getitem__. 908 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 909 'foo') 910 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 911 None) 912 913 # Test slice retrieval. 914 m2 = proto.repeated_nested_message.add() 915 m3 = proto.repeated_nested_message.add() 916 m4 = proto.repeated_nested_message.add() 917 self.assertListsEqual( 918 [m1, m2, m3], proto.repeated_nested_message[1:4]) 919 self.assertListsEqual( 920 [m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) 921 self.assertListsEqual( 922 [m0, m1], proto.repeated_nested_message[:2]) 923 self.assertListsEqual( 924 [m2, m3, m4], proto.repeated_nested_message[2:]) 925 self.assertEqual( 926 m0, proto.repeated_nested_message[0]) 927 self.assertListsEqual( 928 [m0], proto.repeated_nested_message[:1]) 929 930 # Test that we can use the field as an iterator. 931 result = [] 932 for i in proto.repeated_nested_message: 933 result.append(i) 934 self.assertListsEqual([m0, m1, m2, m3, m4], result) 935 936 # Test single deletion. 
937 del proto.repeated_nested_message[2] 938 self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message) 939 940 # Test slice deletion. 941 del proto.repeated_nested_message[2:] 942 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 943 944 # Test extending. 945 n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1) 946 n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2) 947 proto.repeated_nested_message.extend([n1,n2]) 948 self.assertEqual(4, len(proto.repeated_nested_message)) 949 self.assertEqual(n1, proto.repeated_nested_message[2]) 950 self.assertEqual(n2, proto.repeated_nested_message[3]) 951 952 # Test clearing. 953 proto.ClearField('repeated_nested_message') 954 self.assertTrue(not proto.repeated_nested_message) 955 self.assertEqual(0, len(proto.repeated_nested_message)) 956 957 # Test constructing an element while adding it. 958 proto.repeated_nested_message.add(bb=23) 959 self.assertEqual(1, len(proto.repeated_nested_message)) 960 self.assertEqual(23, proto.repeated_nested_message[0].bb) 961 962 def testRepeatedCompositeRemove(self): 963 proto = unittest_pb2.TestAllTypes() 964 965 self.assertEqual(0, len(proto.repeated_nested_message)) 966 m0 = proto.repeated_nested_message.add() 967 # Need to set some differentiating variable so m0 != m1 != m2: 968 m0.bb = len(proto.repeated_nested_message) 969 m1 = proto.repeated_nested_message.add() 970 m1.bb = len(proto.repeated_nested_message) 971 self.assertTrue(m0 != m1) 972 m2 = proto.repeated_nested_message.add() 973 m2.bb = len(proto.repeated_nested_message) 974 self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) 975 976 self.assertEqual(3, len(proto.repeated_nested_message)) 977 proto.repeated_nested_message.remove(m0) 978 self.assertEqual(2, len(proto.repeated_nested_message)) 979 self.assertEqual(m1, proto.repeated_nested_message[0]) 980 self.assertEqual(m2, proto.repeated_nested_message[1]) 981 982 # Removing m0 again or removing None should raise error 983 
self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0) 984 self.assertRaises(ValueError, proto.repeated_nested_message.remove, None) 985 self.assertEqual(2, len(proto.repeated_nested_message)) 986 987 proto.repeated_nested_message.remove(m2) 988 self.assertEqual(1, len(proto.repeated_nested_message)) 989 self.assertEqual(m1, proto.repeated_nested_message[0]) 990 991 def testHandWrittenReflection(self): 992 # Hand written extensions are only supported by the pure-Python 993 # implementation of the API. 994 if api_implementation.Type() != 'python': 995 return 996 997 FieldDescriptor = descriptor.FieldDescriptor 998 foo_field_descriptor = FieldDescriptor( 999 name='foo_field', full_name='MyProto.foo_field', 1000 index=0, number=1, type=FieldDescriptor.TYPE_INT64, 1001 cpp_type=FieldDescriptor.CPPTYPE_INT64, 1002 label=FieldDescriptor.LABEL_OPTIONAL, default_value=0, 1003 containing_type=None, message_type=None, enum_type=None, 1004 is_extension=False, extension_scope=None, 1005 options=descriptor_pb2.FieldOptions()) 1006 mydescriptor = descriptor.Descriptor( 1007 name='MyProto', full_name='MyProto', filename='ignored', 1008 containing_type=None, nested_types=[], enum_types=[], 1009 fields=[foo_field_descriptor], extensions=[], 1010 options=descriptor_pb2.MessageOptions()) 1011 class MyProtoClass(message.Message): 1012 DESCRIPTOR = mydescriptor 1013 __metaclass__ = reflection.GeneratedProtocolMessageType 1014 myproto_instance = MyProtoClass() 1015 self.assertEqual(0, myproto_instance.foo_field) 1016 self.assertTrue(not myproto_instance.HasField('foo_field')) 1017 myproto_instance.foo_field = 23 1018 self.assertEqual(23, myproto_instance.foo_field) 1019 self.assertTrue(myproto_instance.HasField('foo_field')) 1020 1021 def testDescriptorProtoSupport(self): 1022 # Hand written descriptors/reflection are only supported by the pure-Python 1023 # implementation of the API. 
1024 if api_implementation.Type() != 'python': 1025 return 1026 1027 def AddDescriptorField(proto, field_name, field_type): 1028 AddDescriptorField.field_index += 1 1029 new_field = proto.field.add() 1030 new_field.name = field_name 1031 new_field.type = field_type 1032 new_field.number = AddDescriptorField.field_index 1033 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL 1034 1035 AddDescriptorField.field_index = 0 1036 1037 desc_proto = descriptor_pb2.DescriptorProto() 1038 desc_proto.name = 'Car' 1039 fdp = descriptor_pb2.FieldDescriptorProto 1040 AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING) 1041 AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64) 1042 AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL) 1043 AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE) 1044 # Add a repeated field 1045 AddDescriptorField.field_index += 1 1046 new_field = desc_proto.field.add() 1047 new_field.name = 'owners' 1048 new_field.type = fdp.TYPE_STRING 1049 new_field.number = AddDescriptorField.field_index 1050 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED 1051 1052 desc = descriptor.MakeDescriptor(desc_proto) 1053 self.assertTrue(desc.fields_by_name.has_key('name')) 1054 self.assertTrue(desc.fields_by_name.has_key('year')) 1055 self.assertTrue(desc.fields_by_name.has_key('automatic')) 1056 self.assertTrue(desc.fields_by_name.has_key('price')) 1057 self.assertTrue(desc.fields_by_name.has_key('owners')) 1058 1059 class CarMessage(message.Message): 1060 __metaclass__ = reflection.GeneratedProtocolMessageType 1061 DESCRIPTOR = desc 1062 1063 prius = CarMessage() 1064 prius.name = 'prius' 1065 prius.year = 2010 1066 prius.automatic = True 1067 prius.price = 25134.75 1068 prius.owners.extend(['bob', 'susan']) 1069 1070 serialized_prius = prius.SerializeToString() 1071 new_prius = reflection.ParseMessage(desc, serialized_prius) 1072 self.assertTrue(new_prius is not prius) 1073 self.assertEqual(prius, new_prius) 1074 
1075 # these are unnecessary assuming message equality works as advertised but 1076 # explicitly check to be safe since we're mucking about in metaclass foo 1077 self.assertEqual(prius.name, new_prius.name) 1078 self.assertEqual(prius.year, new_prius.year) 1079 self.assertEqual(prius.automatic, new_prius.automatic) 1080 self.assertEqual(prius.price, new_prius.price) 1081 self.assertEqual(prius.owners, new_prius.owners) 1082 1083 def testTopLevelExtensionsForOptionalScalar(self): 1084 extendee_proto = unittest_pb2.TestAllExtensions() 1085 extension = unittest_pb2.optional_int32_extension 1086 self.assertTrue(not extendee_proto.HasExtension(extension)) 1087 self.assertEqual(0, extendee_proto.Extensions[extension]) 1088 # As with normal scalar fields, just doing a read doesn't actually set the 1089 # "has" bit. 1090 self.assertTrue(not extendee_proto.HasExtension(extension)) 1091 # Actually set the thing. 1092 extendee_proto.Extensions[extension] = 23 1093 self.assertEqual(23, extendee_proto.Extensions[extension]) 1094 self.assertTrue(extendee_proto.HasExtension(extension)) 1095 # Ensure that clearing works as well. 
1096 extendee_proto.ClearExtension(extension) 1097 self.assertEqual(0, extendee_proto.Extensions[extension]) 1098 self.assertTrue(not extendee_proto.HasExtension(extension)) 1099 1100 def testTopLevelExtensionsForRepeatedScalar(self): 1101 extendee_proto = unittest_pb2.TestAllExtensions() 1102 extension = unittest_pb2.repeated_string_extension 1103 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1104 extendee_proto.Extensions[extension].append('foo') 1105 self.assertEqual(['foo'], extendee_proto.Extensions[extension]) 1106 string_list = extendee_proto.Extensions[extension] 1107 extendee_proto.ClearExtension(extension) 1108 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1109 self.assertTrue(string_list is not extendee_proto.Extensions[extension]) 1110 # Shouldn't be allowed to do Extensions[extension] = 'a' 1111 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1112 extension, 'a') 1113 1114 def testTopLevelExtensionsForOptionalMessage(self): 1115 extendee_proto = unittest_pb2.TestAllExtensions() 1116 extension = unittest_pb2.optional_foreign_message_extension 1117 self.assertTrue(not extendee_proto.HasExtension(extension)) 1118 self.assertEqual(0, extendee_proto.Extensions[extension].c) 1119 # As with normal (non-extension) fields, merely reading from the 1120 # thing shouldn't set the "has" bit. 1121 self.assertTrue(not extendee_proto.HasExtension(extension)) 1122 extendee_proto.Extensions[extension].c = 23 1123 self.assertEqual(23, extendee_proto.Extensions[extension].c) 1124 self.assertTrue(extendee_proto.HasExtension(extension)) 1125 # Save a reference here. 1126 foreign_message = extendee_proto.Extensions[extension] 1127 extendee_proto.ClearExtension(extension) 1128 self.assertTrue(foreign_message is not extendee_proto.Extensions[extension]) 1129 # Setting a field on foreign_message now shouldn't set 1130 # any "has" bits on extendee_proto. 
1131 foreign_message.c = 42 1132 self.assertEqual(42, foreign_message.c) 1133 self.assertTrue(foreign_message.HasField('c')) 1134 self.assertTrue(not extendee_proto.HasExtension(extension)) 1135 # Shouldn't be allowed to do Extensions[extension] = 'a' 1136 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1137 extension, 'a') 1138 1139 def testTopLevelExtensionsForRepeatedMessage(self): 1140 extendee_proto = unittest_pb2.TestAllExtensions() 1141 extension = unittest_pb2.repeatedgroup_extension 1142 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1143 group = extendee_proto.Extensions[extension].add() 1144 group.a = 23 1145 self.assertEqual(23, extendee_proto.Extensions[extension][0].a) 1146 group.a = 42 1147 self.assertEqual(42, extendee_proto.Extensions[extension][0].a) 1148 group_list = extendee_proto.Extensions[extension] 1149 extendee_proto.ClearExtension(extension) 1150 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1151 self.assertTrue(group_list is not extendee_proto.Extensions[extension]) 1152 # Shouldn't be allowed to do Extensions[extension] = 'a' 1153 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1154 extension, 'a') 1155 1156 def testNestedExtensions(self): 1157 extendee_proto = unittest_pb2.TestAllExtensions() 1158 extension = unittest_pb2.TestRequired.single 1159 1160 # We just test the non-repeated case. 
1161 self.assertTrue(not extendee_proto.HasExtension(extension)) 1162 required = extendee_proto.Extensions[extension] 1163 self.assertEqual(0, required.a) 1164 self.assertTrue(not extendee_proto.HasExtension(extension)) 1165 required.a = 23 1166 self.assertEqual(23, extendee_proto.Extensions[extension].a) 1167 self.assertTrue(extendee_proto.HasExtension(extension)) 1168 extendee_proto.ClearExtension(extension) 1169 self.assertTrue(required is not extendee_proto.Extensions[extension]) 1170 self.assertTrue(not extendee_proto.HasExtension(extension)) 1171 1172 def testRegisteredExtensions(self): 1173 self.assertTrue('protobuf_unittest.optional_int32_extension' in 1174 unittest_pb2.TestAllExtensions._extensions_by_name) 1175 self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) 1176 # Make sure extensions haven't been registered into types that shouldn't 1177 # have any. 1178 self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) 1179 1180 # If message A directly contains message B, and 1181 # a.HasField('b') is currently False, then mutating any 1182 # extension in B should change a.HasField('b') to True 1183 # (and so on up the object tree). 1184 def testHasBitsForAncestorsOfExtendedMessage(self): 1185 # Optional scalar extension. 1186 toplevel = more_extensions_pb2.TopLevelMessage() 1187 self.assertTrue(not toplevel.HasField('submessage')) 1188 self.assertEqual(0, toplevel.submessage.Extensions[ 1189 more_extensions_pb2.optional_int_extension]) 1190 self.assertTrue(not toplevel.HasField('submessage')) 1191 toplevel.submessage.Extensions[ 1192 more_extensions_pb2.optional_int_extension] = 23 1193 self.assertEqual(23, toplevel.submessage.Extensions[ 1194 more_extensions_pb2.optional_int_extension]) 1195 self.assertTrue(toplevel.HasField('submessage')) 1196 1197 # Repeated scalar extension. 
1198 toplevel = more_extensions_pb2.TopLevelMessage() 1199 self.assertTrue(not toplevel.HasField('submessage')) 1200 self.assertEqual([], toplevel.submessage.Extensions[ 1201 more_extensions_pb2.repeated_int_extension]) 1202 self.assertTrue(not toplevel.HasField('submessage')) 1203 toplevel.submessage.Extensions[ 1204 more_extensions_pb2.repeated_int_extension].append(23) 1205 self.assertEqual([23], toplevel.submessage.Extensions[ 1206 more_extensions_pb2.repeated_int_extension]) 1207 self.assertTrue(toplevel.HasField('submessage')) 1208 1209 # Optional message extension. 1210 toplevel = more_extensions_pb2.TopLevelMessage() 1211 self.assertTrue(not toplevel.HasField('submessage')) 1212 self.assertEqual(0, toplevel.submessage.Extensions[ 1213 more_extensions_pb2.optional_message_extension].foreign_message_int) 1214 self.assertTrue(not toplevel.HasField('submessage')) 1215 toplevel.submessage.Extensions[ 1216 more_extensions_pb2.optional_message_extension].foreign_message_int = 23 1217 self.assertEqual(23, toplevel.submessage.Extensions[ 1218 more_extensions_pb2.optional_message_extension].foreign_message_int) 1219 self.assertTrue(toplevel.HasField('submessage')) 1220 1221 # Repeated message extension. 
1222 toplevel = more_extensions_pb2.TopLevelMessage() 1223 self.assertTrue(not toplevel.HasField('submessage')) 1224 self.assertEqual(0, len(toplevel.submessage.Extensions[ 1225 more_extensions_pb2.repeated_message_extension])) 1226 self.assertTrue(not toplevel.HasField('submessage')) 1227 foreign = toplevel.submessage.Extensions[ 1228 more_extensions_pb2.repeated_message_extension].add() 1229 self.assertEqual(foreign, toplevel.submessage.Extensions[ 1230 more_extensions_pb2.repeated_message_extension][0]) 1231 self.assertTrue(toplevel.HasField('submessage')) 1232 1233 def testDisconnectionAfterClearingEmptyMessage(self): 1234 toplevel = more_extensions_pb2.TopLevelMessage() 1235 extendee_proto = toplevel.submessage 1236 extension = more_extensions_pb2.optional_message_extension 1237 extension_proto = extendee_proto.Extensions[extension] 1238 extendee_proto.ClearExtension(extension) 1239 extension_proto.foreign_message_int = 23 1240 1241 self.assertTrue(extension_proto is not extendee_proto.Extensions[extension]) 1242 1243 def testExtensionFailureModes(self): 1244 extendee_proto = unittest_pb2.TestAllExtensions() 1245 1246 # Try non-extension-handle arguments to HasExtension, 1247 # ClearExtension(), and Extensions[]... 1248 self.assertRaises(KeyError, extendee_proto.HasExtension, 1234) 1249 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234) 1250 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234) 1251 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5) 1252 1253 # Try something that *is* an extension handle, just not for 1254 # this message... 
1255 unknown_handle = more_extensions_pb2.optional_int_extension 1256 self.assertRaises(KeyError, extendee_proto.HasExtension, 1257 unknown_handle) 1258 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1259 unknown_handle) 1260 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1261 unknown_handle) 1262 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1263 unknown_handle, 5) 1264 1265 # Try call HasExtension() with a valid handle, but for a 1266 # *repeated* field. (Just as with non-extension repeated 1267 # fields, Has*() isn't supported for extension repeated fields). 1268 self.assertRaises(KeyError, extendee_proto.HasExtension, 1269 unittest_pb2.repeated_string_extension) 1270 1271 def testStaticParseFrom(self): 1272 proto1 = unittest_pb2.TestAllTypes() 1273 test_util.SetAllFields(proto1) 1274 1275 string1 = proto1.SerializeToString() 1276 proto2 = unittest_pb2.TestAllTypes.FromString(string1) 1277 1278 # Messages should be equal. 1279 self.assertEqual(proto2, proto1) 1280 1281 def testMergeFromSingularField(self): 1282 # Test merge with just a singular field. 1283 proto1 = unittest_pb2.TestAllTypes() 1284 proto1.optional_int32 = 1 1285 1286 proto2 = unittest_pb2.TestAllTypes() 1287 # This shouldn't get overwritten. 1288 proto2.optional_string = 'value' 1289 1290 proto2.MergeFrom(proto1) 1291 self.assertEqual(1, proto2.optional_int32) 1292 self.assertEqual('value', proto2.optional_string) 1293 1294 def testMergeFromRepeatedField(self): 1295 # Test merge with just a repeated field. 
1296 proto1 = unittest_pb2.TestAllTypes() 1297 proto1.repeated_int32.append(1) 1298 proto1.repeated_int32.append(2) 1299 1300 proto2 = unittest_pb2.TestAllTypes() 1301 proto2.repeated_int32.append(0) 1302 proto2.MergeFrom(proto1) 1303 1304 self.assertEqual(0, proto2.repeated_int32[0]) 1305 self.assertEqual(1, proto2.repeated_int32[1]) 1306 self.assertEqual(2, proto2.repeated_int32[2]) 1307 1308 def testMergeFromOptionalGroup(self): 1309 # Test merge with an optional group. 1310 proto1 = unittest_pb2.TestAllTypes() 1311 proto1.optionalgroup.a = 12 1312 proto2 = unittest_pb2.TestAllTypes() 1313 proto2.MergeFrom(proto1) 1314 self.assertEqual(12, proto2.optionalgroup.a) 1315 1316 def testMergeFromRepeatedNestedMessage(self): 1317 # Test merge with a repeated nested message. 1318 proto1 = unittest_pb2.TestAllTypes() 1319 m = proto1.repeated_nested_message.add() 1320 m.bb = 123 1321 m = proto1.repeated_nested_message.add() 1322 m.bb = 321 1323 1324 proto2 = unittest_pb2.TestAllTypes() 1325 m = proto2.repeated_nested_message.add() 1326 m.bb = 999 1327 proto2.MergeFrom(proto1) 1328 self.assertEqual(999, proto2.repeated_nested_message[0].bb) 1329 self.assertEqual(123, proto2.repeated_nested_message[1].bb) 1330 self.assertEqual(321, proto2.repeated_nested_message[2].bb) 1331 1332 proto3 = unittest_pb2.TestAllTypes() 1333 proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message) 1334 self.assertEqual(999, proto3.repeated_nested_message[0].bb) 1335 self.assertEqual(123, proto3.repeated_nested_message[1].bb) 1336 self.assertEqual(321, proto3.repeated_nested_message[2].bb) 1337 1338 def testMergeFromAllFields(self): 1339 # With all fields set. 1340 proto1 = unittest_pb2.TestAllTypes() 1341 test_util.SetAllFields(proto1) 1342 proto2 = unittest_pb2.TestAllTypes() 1343 proto2.MergeFrom(proto1) 1344 1345 # Messages should be equal. 1346 self.assertEqual(proto2, proto1) 1347 1348 # Serialized string should be equal too. 
1349 string1 = proto1.SerializeToString() 1350 string2 = proto2.SerializeToString() 1351 self.assertEqual(string1, string2) 1352 1353 def testMergeFromExtensionsSingular(self): 1354 proto1 = unittest_pb2.TestAllExtensions() 1355 proto1.Extensions[unittest_pb2.optional_int32_extension] = 1 1356 1357 proto2 = unittest_pb2.TestAllExtensions() 1358 proto2.MergeFrom(proto1) 1359 self.assertEqual( 1360 1, proto2.Extensions[unittest_pb2.optional_int32_extension]) 1361 1362 def testMergeFromExtensionsRepeated(self): 1363 proto1 = unittest_pb2.TestAllExtensions() 1364 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1) 1365 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2) 1366 1367 proto2 = unittest_pb2.TestAllExtensions() 1368 proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0) 1369 proto2.MergeFrom(proto1) 1370 self.assertEqual( 1371 3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension])) 1372 self.assertEqual( 1373 0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0]) 1374 self.assertEqual( 1375 1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1]) 1376 self.assertEqual( 1377 2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2]) 1378 1379 def testMergeFromExtensionsNestedMessage(self): 1380 proto1 = unittest_pb2.TestAllExtensions() 1381 ext1 = proto1.Extensions[ 1382 unittest_pb2.repeated_nested_message_extension] 1383 m = ext1.add() 1384 m.bb = 222 1385 m = ext1.add() 1386 m.bb = 333 1387 1388 proto2 = unittest_pb2.TestAllExtensions() 1389 ext2 = proto2.Extensions[ 1390 unittest_pb2.repeated_nested_message_extension] 1391 m = ext2.add() 1392 m.bb = 111 1393 1394 proto2.MergeFrom(proto1) 1395 ext2 = proto2.Extensions[ 1396 unittest_pb2.repeated_nested_message_extension] 1397 self.assertEqual(3, len(ext2)) 1398 self.assertEqual(111, ext2[0].bb) 1399 self.assertEqual(222, ext2[1].bb) 1400 self.assertEqual(333, ext2[2].bb) 1401 1402 def testMergeFromBug(self): 1403 message1 = 
unittest_pb2.TestAllTypes() 1404 message2 = unittest_pb2.TestAllTypes() 1405 1406 # Cause optional_nested_message to be instantiated within message1, even 1407 # though it is not considered to be "present". 1408 message1.optional_nested_message 1409 self.assertFalse(message1.HasField('optional_nested_message')) 1410 1411 # Merge into message2. This should not instantiate the field is message2. 1412 message2.MergeFrom(message1) 1413 self.assertFalse(message2.HasField('optional_nested_message')) 1414 1415 def testCopyFromSingularField(self): 1416 # Test copy with just a singular field. 1417 proto1 = unittest_pb2.TestAllTypes() 1418 proto1.optional_int32 = 1 1419 proto1.optional_string = 'important-text' 1420 1421 proto2 = unittest_pb2.TestAllTypes() 1422 proto2.optional_string = 'value' 1423 1424 proto2.CopyFrom(proto1) 1425 self.assertEqual(1, proto2.optional_int32) 1426 self.assertEqual('important-text', proto2.optional_string) 1427 1428 def testCopyFromRepeatedField(self): 1429 # Test copy with a repeated field. 1430 proto1 = unittest_pb2.TestAllTypes() 1431 proto1.repeated_int32.append(1) 1432 proto1.repeated_int32.append(2) 1433 1434 proto2 = unittest_pb2.TestAllTypes() 1435 proto2.repeated_int32.append(0) 1436 proto2.CopyFrom(proto1) 1437 1438 self.assertEqual(1, proto2.repeated_int32[0]) 1439 self.assertEqual(2, proto2.repeated_int32[1]) 1440 1441 def testCopyFromAllFields(self): 1442 # With all fields set. 1443 proto1 = unittest_pb2.TestAllTypes() 1444 test_util.SetAllFields(proto1) 1445 proto2 = unittest_pb2.TestAllTypes() 1446 proto2.CopyFrom(proto1) 1447 1448 # Messages should be equal. 1449 self.assertEqual(proto2, proto1) 1450 1451 # Serialized string should be equal too. 
1452 string1 = proto1.SerializeToString() 1453 string2 = proto2.SerializeToString() 1454 self.assertEqual(string1, string2) 1455 1456 def testCopyFromSelf(self): 1457 proto1 = unittest_pb2.TestAllTypes() 1458 proto1.repeated_int32.append(1) 1459 proto1.optional_int32 = 2 1460 proto1.optional_string = 'important-text' 1461 1462 proto1.CopyFrom(proto1) 1463 self.assertEqual(1, proto1.repeated_int32[0]) 1464 self.assertEqual(2, proto1.optional_int32) 1465 self.assertEqual('important-text', proto1.optional_string) 1466 1467 def testCopyFromBadType(self): 1468 # The python implementation doesn't raise an exception in this 1469 # case. In theory it should. 1470 if api_implementation.Type() == 'python': 1471 return 1472 proto1 = unittest_pb2.TestAllTypes() 1473 proto2 = unittest_pb2.TestAllExtensions() 1474 self.assertRaises(TypeError, proto1.CopyFrom, proto2) 1475 1476 def testDeepCopy(self): 1477 proto1 = unittest_pb2.TestAllTypes() 1478 proto1.optional_int32 = 1 1479 proto2 = copy.deepcopy(proto1) 1480 self.assertEqual(1, proto2.optional_int32) 1481 1482 proto1.repeated_int32.append(2) 1483 proto1.repeated_int32.append(3) 1484 container = copy.deepcopy(proto1.repeated_int32) 1485 self.assertEqual([2, 3], container) 1486 1487 # TODO(anuraag): Implement deepcopy for repeated composite / extension dict 1488 1489 def testClear(self): 1490 proto = unittest_pb2.TestAllTypes() 1491 # C++ implementation does not support lazy fields right now so leave it 1492 # out for now. 1493 if api_implementation.Type() == 'python': 1494 test_util.SetAllFields(proto) 1495 else: 1496 test_util.SetAllNonLazyFields(proto) 1497 # Clear the message. 1498 proto.Clear() 1499 self.assertEquals(proto.ByteSize(), 0) 1500 empty_proto = unittest_pb2.TestAllTypes() 1501 self.assertEquals(proto, empty_proto) 1502 1503 # Test if extensions which were set are cleared. 1504 proto = unittest_pb2.TestAllExtensions() 1505 test_util.SetAllExtensions(proto) 1506 # Clear the message. 
1507 proto.Clear() 1508 self.assertEquals(proto.ByteSize(), 0) 1509 empty_proto = unittest_pb2.TestAllExtensions() 1510 self.assertEquals(proto, empty_proto) 1511 1512 def testDisconnectingBeforeClear(self): 1513 proto = unittest_pb2.TestAllTypes() 1514 nested = proto.optional_nested_message 1515 proto.Clear() 1516 self.assertTrue(nested is not proto.optional_nested_message) 1517 nested.bb = 23 1518 self.assertTrue(not proto.HasField('optional_nested_message')) 1519 self.assertEqual(0, proto.optional_nested_message.bb) 1520 1521 proto = unittest_pb2.TestAllTypes() 1522 nested = proto.optional_nested_message 1523 nested.bb = 5 1524 foreign = proto.optional_foreign_message 1525 foreign.c = 6 1526 1527 proto.Clear() 1528 self.assertTrue(nested is not proto.optional_nested_message) 1529 self.assertTrue(foreign is not proto.optional_foreign_message) 1530 self.assertEqual(5, nested.bb) 1531 self.assertEqual(6, foreign.c) 1532 nested.bb = 15 1533 foreign.c = 16 1534 self.assertFalse(proto.HasField('optional_nested_message')) 1535 self.assertEqual(0, proto.optional_nested_message.bb) 1536 self.assertFalse(proto.HasField('optional_foreign_message')) 1537 self.assertEqual(0, proto.optional_foreign_message.c) 1538 1539 def testOneOf(self): 1540 proto = unittest_pb2.TestAllTypes() 1541 proto.oneof_uint32 = 10 1542 proto.oneof_nested_message.bb = 11 1543 self.assertEqual(11, proto.oneof_nested_message.bb) 1544 self.assertFalse(proto.HasField('oneof_uint32')) 1545 nested = proto.oneof_nested_message 1546 proto.oneof_string = 'abc' 1547 self.assertEqual('abc', proto.oneof_string) 1548 self.assertEqual(11, nested.bb) 1549 self.assertFalse(proto.HasField('oneof_nested_message')) 1550 1551 def assertInitialized(self, proto): 1552 self.assertTrue(proto.IsInitialized()) 1553 # Neither method should raise an exception. 
1554 proto.SerializeToString() 1555 proto.SerializePartialToString() 1556 1557 def assertNotInitialized(self, proto): 1558 self.assertFalse(proto.IsInitialized()) 1559 self.assertRaises(message.EncodeError, proto.SerializeToString) 1560 # "Partial" serialization doesn't care if message is uninitialized. 1561 proto.SerializePartialToString() 1562 1563 def testIsInitialized(self): 1564 # Trivial cases - all optional fields and extensions. 1565 proto = unittest_pb2.TestAllTypes() 1566 self.assertInitialized(proto) 1567 proto = unittest_pb2.TestAllExtensions() 1568 self.assertInitialized(proto) 1569 1570 # The case of uninitialized required fields. 1571 proto = unittest_pb2.TestRequired() 1572 self.assertNotInitialized(proto) 1573 proto.a = proto.b = proto.c = 2 1574 self.assertInitialized(proto) 1575 1576 # The case of uninitialized submessage. 1577 proto = unittest_pb2.TestRequiredForeign() 1578 self.assertInitialized(proto) 1579 proto.optional_message.a = 1 1580 self.assertNotInitialized(proto) 1581 proto.optional_message.b = 0 1582 proto.optional_message.c = 0 1583 self.assertInitialized(proto) 1584 1585 # Uninitialized repeated submessage. 1586 message1 = proto.repeated_message.add() 1587 self.assertNotInitialized(proto) 1588 message1.a = message1.b = message1.c = 0 1589 self.assertInitialized(proto) 1590 1591 # Uninitialized repeated group in an extension. 1592 proto = unittest_pb2.TestAllExtensions() 1593 extension = unittest_pb2.TestRequired.multi 1594 message1 = proto.Extensions[extension].add() 1595 message2 = proto.Extensions[extension].add() 1596 self.assertNotInitialized(proto) 1597 message1.a = 1 1598 message1.b = 1 1599 message1.c = 1 1600 self.assertNotInitialized(proto) 1601 message2.a = 2 1602 message2.b = 2 1603 message2.c = 2 1604 self.assertInitialized(proto) 1605 1606 # Uninitialized nonrepeated message in an extension. 
1607 proto = unittest_pb2.TestAllExtensions() 1608 extension = unittest_pb2.TestRequired.single 1609 proto.Extensions[extension].a = 1 1610 self.assertNotInitialized(proto) 1611 proto.Extensions[extension].b = 2 1612 proto.Extensions[extension].c = 3 1613 self.assertInitialized(proto) 1614 1615 # Try passing an errors list. 1616 errors = [] 1617 proto = unittest_pb2.TestRequired() 1618 self.assertFalse(proto.IsInitialized(errors)) 1619 self.assertEqual(errors, ['a', 'b', 'c']) 1620 1621 @basetest.unittest.skipIf( 1622 api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, 1623 'Errors are only available from the most recent C++ implementation.') 1624 def testFileDescriptorErrors(self): 1625 file_name = 'test_file_descriptor_errors.proto' 1626 package_name = 'test_file_descriptor_errors.proto' 1627 file_descriptor_proto = descriptor_pb2.FileDescriptorProto() 1628 file_descriptor_proto.name = file_name 1629 file_descriptor_proto.package = package_name 1630 m1 = file_descriptor_proto.message_type.add() 1631 m1.name = 'msg1' 1632 # Compiles the proto into the C++ descriptor pool 1633 descriptor.FileDescriptor( 1634 file_name, 1635 package_name, 1636 serialized_pb=file_descriptor_proto.SerializeToString()) 1637 # Add a FileDescriptorProto that has duplicate symbols 1638 another_file_name = 'another_test_file_descriptor_errors.proto' 1639 file_descriptor_proto.name = another_file_name 1640 m2 = file_descriptor_proto.message_type.add() 1641 m2.name = 'msg2' 1642 with self.assertRaises(TypeError) as cm: 1643 descriptor.FileDescriptor( 1644 another_file_name, 1645 package_name, 1646 serialized_pb=file_descriptor_proto.SerializeToString()) 1647 self.assertTrue(hasattr(cm, 'exception'), '%s not raised' % 1648 getattr(cm.expected, '__name__', cm.expected)) 1649 self.assertIn('test_file_descriptor_errors.proto', str(cm.exception)) 1650 # Error message will say something about this definition being a 1651 # duplicate, though we don't check the message exactly 
to avoid a 1652 # dependency on the C++ logging code. 1653 self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) 1654 1655 def testStringUTF8Encoding(self): 1656 proto = unittest_pb2.TestAllTypes() 1657 1658 # Assignment of a unicode object to a field of type 'bytes' is not allowed. 1659 self.assertRaises(TypeError, 1660 setattr, proto, 'optional_bytes', u'unicode object') 1661 1662 # Check that the default value is of python's 'unicode' type. 1663 self.assertEqual(type(proto.optional_string), unicode) 1664 1665 proto.optional_string = unicode('Testing') 1666 self.assertEqual(proto.optional_string, str('Testing')) 1667 1668 # Assign a value of type 'str' which can be encoded in UTF-8. 1669 proto.optional_string = str('Testing') 1670 self.assertEqual(proto.optional_string, unicode('Testing')) 1671 1672 # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII. 1673 self.assertRaises(ValueError, 1674 setattr, proto, 'optional_string', b'a\x80a') 1675 if str is bytes: # PY2 1676 # Assign a 'str' object which contains a UTF-8 encoded string. 1677 self.assertRaises(ValueError, 1678 setattr, proto, 'optional_string', 'Тест') 1679 else: 1680 proto.optional_string = 'Тест' 1681 # No exception thrown. 1682 proto.optional_string = 'abc' 1683 1684 def testStringUTF8Serialization(self): 1685 proto = unittest_mset_pb2.TestMessageSet() 1686 extension_message = unittest_mset_pb2.TestMessageSetExtension2 1687 extension = extension_message.message_set_extension 1688 1689 test_utf8 = u'Тест' 1690 test_utf8_bytes = test_utf8.encode('utf-8') 1691 1692 # 'Test' in another language, using UTF-8 charset. 1693 proto.Extensions[extension].str = test_utf8 1694 1695 # Serialize using the MessageSet wire format (this is specified in the 1696 # .proto file). 1697 serialized = proto.SerializeToString() 1698 1699 # Check byte size. 
1700 self.assertEqual(proto.ByteSize(), len(serialized)) 1701 1702 raw = unittest_mset_pb2.RawMessageSet() 1703 bytes_read = raw.MergeFromString(serialized) 1704 self.assertEqual(len(serialized), bytes_read) 1705 1706 message2 = unittest_mset_pb2.TestMessageSetExtension2() 1707 1708 self.assertEqual(1, len(raw.item)) 1709 # Check that the type_id is the same as the tag ID in the .proto file. 1710 self.assertEqual(raw.item[0].type_id, 1547769) 1711 1712 # Check the actual bytes on the wire. 1713 self.assertTrue( 1714 raw.item[0].message.endswith(test_utf8_bytes)) 1715 bytes_read = message2.MergeFromString(raw.item[0].message) 1716 self.assertEqual(len(raw.item[0].message), bytes_read) 1717 1718 self.assertEqual(type(message2.str), unicode) 1719 self.assertEqual(message2.str, test_utf8) 1720 1721 # The pure Python API throws an exception on MergeFromString(), 1722 # if any of the string fields of the message can't be UTF-8 decoded. 1723 # The C++ implementation of the API has no way to check that on 1724 # MergeFromString and thus has no way to throw the exception. 1725 # 1726 # The pure Python API always returns objects of type 'unicode' (UTF-8 1727 # encoded), or 'bytes' (in 7 bit ASCII). 
1728 badbytes = raw.item[0].message.replace( 1729 test_utf8_bytes, len(test_utf8_bytes) * b'\xff') 1730 1731 unicode_decode_failed = False 1732 try: 1733 message2.MergeFromString(badbytes) 1734 except UnicodeDecodeError: 1735 unicode_decode_failed = True 1736 string_field = message2.str 1737 self.assertTrue(unicode_decode_failed or type(string_field) is bytes) 1738 1739 def testBytesInTextFormat(self): 1740 proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') 1741 self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', 1742 unicode(proto)) 1743 1744 def testEmptyNestedMessage(self): 1745 proto = unittest_pb2.TestAllTypes() 1746 proto.optional_nested_message.MergeFrom( 1747 unittest_pb2.TestAllTypes.NestedMessage()) 1748 self.assertTrue(proto.HasField('optional_nested_message')) 1749 1750 proto = unittest_pb2.TestAllTypes() 1751 proto.optional_nested_message.CopyFrom( 1752 unittest_pb2.TestAllTypes.NestedMessage()) 1753 self.assertTrue(proto.HasField('optional_nested_message')) 1754 1755 proto = unittest_pb2.TestAllTypes() 1756 bytes_read = proto.optional_nested_message.MergeFromString(b'') 1757 self.assertEqual(0, bytes_read) 1758 self.assertTrue(proto.HasField('optional_nested_message')) 1759 1760 proto = unittest_pb2.TestAllTypes() 1761 proto.optional_nested_message.ParseFromString(b'') 1762 self.assertTrue(proto.HasField('optional_nested_message')) 1763 1764 serialized = proto.SerializeToString() 1765 proto2 = unittest_pb2.TestAllTypes() 1766 self.assertEqual( 1767 len(serialized), 1768 proto2.MergeFromString(serialized)) 1769 self.assertTrue(proto2.HasField('optional_nested_message')) 1770 1771 def testSetInParent(self): 1772 proto = unittest_pb2.TestAllTypes() 1773 self.assertFalse(proto.HasField('optionalgroup')) 1774 proto.optionalgroup.SetInParent() 1775 self.assertTrue(proto.HasField('optionalgroup')) 1776 1777 1778# Since we had so many tests for protocol buffer equality, we broke these out 1779# into separate TestCase classes. 
class TestAllTypesEqualityTest(basetest.TestCase):

  """Equality and hashing tests starting from two empty TestAllTypes."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()

  def testNotHashable(self):
    # hash() on a message is expected to raise TypeError.
    self.assertRaises(TypeError, hash, self.first_proto)

  def testSelfEquality(self):
    self.assertEqual(self.first_proto, self.first_proto)

  def testEmptyProtosEqual(self):
    self.assertEqual(self.first_proto, self.second_proto)


class FullProtosEqualityTest(basetest.TestCase):

  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.first_proto)
    test_util.SetAllFields(self.second_proto)

  def testNotHashable(self):
    self.assertRaises(TypeError, hash, self.first_proto)

  def testNoneNotEqual(self):
    # Comparison against None must yield "not equal" from either side.
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    # A message of a different type never compares equal.
    third_proto = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, third_proto)
    self.assertNotEqual(third_proto, self.second_proto)

  def testAllFieldsFilledEquality(self):
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    # Nonrepeated scalar field change should cause inequality.
    self.first_proto.optional_int32 += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    # ...as should clearing a field.
    self.first_proto.ClearField('optional_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedComposite(self):
    # Change a nonrepeated composite field, then restore it.
    self.first_proto.optional_nested_message.bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Clear a field in the nested message, then restore it.
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb = (
        self.second_proto.optional_nested_message.bb)
    self.assertEqual(self.first_proto, self.second_proto)
    # Remove the nested message entirely.
    self.first_proto.ClearField('optional_nested_message')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedScalar(self):
    # Both appending to and clearing a repeated scalar field break equality.
    self.first_proto.repeated_int32.append(5)
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.ClearField('repeated_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedComposite(self):
    # Change value within a repeated composite field.
    self.first_proto.repeated_nested_message[0].bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.repeated_nested_message[0].bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Add a value to a repeated composite field; adding the same (empty)
    # element to the other message restores equality.
    self.first_proto.repeated_nested_message.add()
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.second_proto.repeated_nested_message.add()
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalarHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated scalar field: an unset field and a field explicitly set to
    # its default are not equal.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated composite field.
    self.first_proto.ClearField('optional_nested_message')
    self.second_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    # Setting then clearing a sub-field still leaves the submessage present.
    self.first_proto.optional_nested_message.bb = 0
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertEqual(self.first_proto, self.second_proto)


class ExtensionEqualityTest(basetest.TestCase):

  """Equality tests for messages carrying extensions."""

  def testExtensionEquality(self):
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(first_proto, second_proto)
    test_util.SetAllExtensions(first_proto)
    self.assertNotEqual(first_proto, second_proto)
    test_util.SetAllExtensions(second_proto)
    self.assertEqual(first_proto, second_proto)

    # Ensure that we check value equality.
    first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
    self.assertEqual(first_proto, second_proto)

    # Ensure that we also look at "has" bits.
    first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
    second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertEqual(first_proto, second_proto)

    # Ensure that differences in cached values
    # don't matter if "has" bits are both false: merely reading an unset
    # extension on one side must not make the messages unequal.
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(
        0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
    self.assertEqual(first_proto, second_proto)


class MutualRecursionEqualityTest(basetest.TestCase):

  """Equality tests on mutually recursive message types."""

  def testEqualityWithMutualRecursion(self):
    first_proto = unittest_pb2.TestMutualRecursionA()
    second_proto = unittest_pb2.TestMutualRecursionA()
    self.assertEqual(first_proto, second_proto)
    first_proto.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(first_proto, second_proto)
    second_proto.bb.a.bb.optional_int32 = 23
    self.assertEqual(first_proto, second_proto)


class ByteSizeTest(basetest.TestCase):

  """Checks ByteSize() against hand-computed wire-format sizes."""

  def setUp(self):
    self.proto = unittest_pb2.TestAllTypes()
    self.extended_proto = more_extensions_pb2.ExtendedMessage()
    self.packed_proto = unittest_pb2.TestPackedTypes()
    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()

  def Size(self):
    """Returns the serialized size of self.proto."""
    return self.proto.ByteSize()

  def testEmptyMessage(self):
    self.assertEqual(0, self.proto.ByteSize())

  def testSizedOnKwargs(self):
    # Use a separate message to ensure testing right after creation.
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.ByteSize())
    proto_kwargs = unittest_pb2.TestAllTypes(optional_int64=1)
    # One byte for the tag, one to encode varint 1.
    self.assertEqual(2, proto_kwargs.ByteSize())

  def testVarints(self):
    def Test(i, expected_varint_size):
      self.proto.Clear()
      self.proto.optional_int64 = i
      # Add one to the varint size for the single-byte tag of optional_int64.
      self.assertEqual(expected_varint_size + 1, self.Size())
    Test(0, 1)
    Test(1, 1)
    # Each varint byte carries 7 payload bits, so (1 << 7k) - 1 is the
    # largest value that still fits in k bytes.
    for num_bytes, i in enumerate(range(7, 63, 7), 1):
      Test((1 << i) - 1, num_bytes)
    # Negative int64 values always take the maximum 10 varint bytes.
    Test(-1, 10)
    Test(-2, 10)
    Test(-(1 << 63), 10)

  def testStrings(self):
    self.proto.optional_string = ''
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2, self.Size())

    self.proto.optional_string = 'abc'
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2 + len(self.proto.optional_string), self.Size())

    self.proto.optional_string = 'x' * 128
    # Need one byte for tag info (tag #14), and TWO bytes for length.
    self.assertEqual(3 + len(self.proto.optional_string), self.Size())

  def testOtherNumerics(self):
    self.proto.optional_fixed32 = 1234
    # One byte for tag and 4 bytes for fixed32.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_fixed64 = 1234
    # One byte for tag and 8 bytes for fixed64.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_float = 1.234
    # One byte for tag and 4 bytes for float.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_double = 1.234
    # One byte for tag and 8 bytes for double.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_sint32 = 64
    # One byte for tag and 2 bytes for zig-zag-encoded 64.
    self.assertEqual(3, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

  def testComposites(self):
    # 3 bytes.
    self.proto.optional_nested_message.bb = (1 << 14)
    # Plus one byte for bb tag.
    # Plus 1 byte for optional_nested_message serialized size.
    # Plus two bytes for optional_nested_message tag.
    self.assertEqual(3 + 1 + 1 + 2, self.Size())

  def testGroups(self):
    # 4 bytes.
    self.proto.optionalgroup.a = (1 << 21)
    # Plus two bytes for |a| tag.
    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
    self.assertEqual(4 + 2 + 2*2, self.Size())

  def testRepeatedScalars(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsExtend(self):
    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsRemove(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())
    self.proto.repeated_int32.remove(128)
    self.assertEqual(1 + 2, self.Size())

  def testRepeatedComposites(self):
    # Empty message. 2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    foreign_message_1 = self.proto.repeated_nested_message.add()
    foreign_message_1.bb = 7
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

  def testRepeatedCompositesDelete(self):
    # Empty message. 2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    foreign_message_1 = self.proto.repeated_nested_message.add()
    foreign_message_1.bb = 9
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

    # Deleting the empty message leaves only foreign_message_1:
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[0]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    # Now add a new message.
    foreign_message_2 = self.proto.repeated_nested_message.add()
    foreign_message_2.bb = 12

    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())

    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[1]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    del self.proto.repeated_nested_message[0]
    self.assertEqual(0, self.Size())

  def testRepeatedGroups(self):
    # 2-byte START_GROUP plus 2-byte END_GROUP.
    self.proto.repeatedgroup.add()
    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
    # plus 2-byte END_GROUP.
    group_1 = self.proto.repeatedgroup.add()
    group_1.a = 7
    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())

  def testExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, proto.ByteSize())
    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
    proto.Extensions[extension] = 23
    # 1 byte for tag, 1 byte for value.
    self.assertEqual(2, proto.ByteSize())

  def testCacheInvalidationForNonrepeatedScalar(self):
    # Test non-extension.
    self.proto.optional_int32 = 1
    self.assertEqual(2, self.proto.ByteSize())
    self.proto.optional_int32 = 128
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_int_extension
    self.extended_proto.Extensions[extension] = 1
    self.assertEqual(2, self.extended_proto.ByteSize())
    self.extended_proto.Extensions[extension] = 128
    self.assertEqual(3, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedScalar(self):
    # Test non-extension.
    self.proto.repeated_int32.append(1)
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_int32.append(1)
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.repeated_int32[1] = 128
    self.assertEqual(7, self.proto.ByteSize())
    self.proto.ClearField('repeated_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_int_extension
    repeated = self.extended_proto.Extensions[extension]
    repeated.append(1)
    self.assertEqual(2, self.extended_proto.ByteSize())
    repeated.append(1)
    self.assertEqual(4, self.extended_proto.ByteSize())
    repeated[1] = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForNonrepeatedMessage(self):
    # Test non-extension.
    self.proto.optional_foreign_message.c = 1
    self.assertEqual(5, self.proto.ByteSize())
    self.proto.optional_foreign_message.c = 128
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.optional_foreign_message.ClearField('c')
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    if api_implementation.Type() == 'python':
      # This is only possible in pure-Python implementation of the API:
      # mutating a child detached by ClearField must not affect the parent.
      child = self.proto.optional_foreign_message
      self.proto.ClearField('optional_foreign_message')
      child.c = 128
      self.assertEqual(0, self.proto.ByteSize())

    # Test within extension. Merely reading the extension must not make it
    # count towards the size.
    extension = more_extensions_pb2.optional_message_extension
    child = self.extended_proto.Extensions[extension]
    self.assertEqual(0, self.extended_proto.ByteSize())
    child.foreign_message_int = 1
    self.assertEqual(4, self.extended_proto.ByteSize())
    child.foreign_message_int = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedMessage(self):
    # Test non-extension.
    child0 = self.proto.repeated_foreign_message.add()
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_foreign_message.add()
    self.assertEqual(6, self.proto.ByteSize())
    child0.c = 1
    self.assertEqual(8, self.proto.ByteSize())
    self.proto.ClearField('repeated_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_message_extension
    child_list = self.extended_proto.Extensions[extension]
    child0 = child_list.add()
    self.assertEqual(2, self.extended_proto.ByteSize())
    child_list.add()
    self.assertEqual(4, self.extended_proto.ByteSize())
    child0.foreign_message_int = 1
    self.assertEqual(6, self.extended_proto.ByteSize())
    child0.ClearField('foreign_message_int')
    self.assertEqual(4, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testPackedRepeatedScalars(self):
    self.assertEqual(0, self.packed_proto.ByteSize())

    self.packed_proto.packed_int32.append(10)   # 1 byte.
    self.packed_proto.packed_int32.append(128)  # 2 bytes.
    # The tag is 2 bytes (the field number is 90), and the varint
    # storing the length is 1 byte.
    int_size = 1 + 2 + 3
    self.assertEqual(int_size, self.packed_proto.ByteSize())

    self.packed_proto.packed_double.append(4.2)   # 8 bytes
    self.packed_proto.packed_double.append(3.25)  # 8 bytes
    # 2 more tag bytes, 1 more length byte.
    double_size = 8 + 8 + 3
    self.assertEqual(int_size+double_size, self.packed_proto.ByteSize())

    self.packed_proto.ClearField('packed_int32')
    self.assertEqual(double_size, self.packed_proto.ByteSize())

  def testPackedExtensions(self):
    self.assertEqual(0, self.packed_extended_proto.ByteSize())
    extension = self.packed_extended_proto.Extensions[
        unittest_pb2.packed_fixed32_extension]
    extension.extend([1, 2, 3, 4])  # 16 bytes
    # 3 bytes of overhead: a 2-byte tag plus a 1-byte length prefix.
    self.assertEqual(19, self.packed_extended_proto.ByteSize())


# Issues to be sure to cover include:
#   * Handling of unrecognized tags ("uninterpreted_bytes").
#   * Handling of MessageSets.
#   * Consistent ordering of tags in the wire format,
#     including ordering between extensions and non-extension
#     fields.
#   * Consistent serialization of negative numbers, especially
#     negative int32s.
#   * Handling of empty submessages (with and without "has"
#     bits set).
2224 2225class SerializationTest(basetest.TestCase): 2226 2227 def testSerializeEmtpyMessage(self): 2228 first_proto = unittest_pb2.TestAllTypes() 2229 second_proto = unittest_pb2.TestAllTypes() 2230 serialized = first_proto.SerializeToString() 2231 self.assertEqual(first_proto.ByteSize(), len(serialized)) 2232 self.assertEqual( 2233 len(serialized), 2234 second_proto.MergeFromString(serialized)) 2235 self.assertEqual(first_proto, second_proto) 2236 2237 def testSerializeAllFields(self): 2238 first_proto = unittest_pb2.TestAllTypes() 2239 second_proto = unittest_pb2.TestAllTypes() 2240 test_util.SetAllFields(first_proto) 2241 serialized = first_proto.SerializeToString() 2242 self.assertEqual(first_proto.ByteSize(), len(serialized)) 2243 self.assertEqual( 2244 len(serialized), 2245 second_proto.MergeFromString(serialized)) 2246 self.assertEqual(first_proto, second_proto) 2247 2248 def testSerializeAllExtensions(self): 2249 first_proto = unittest_pb2.TestAllExtensions() 2250 second_proto = unittest_pb2.TestAllExtensions() 2251 test_util.SetAllExtensions(first_proto) 2252 serialized = first_proto.SerializeToString() 2253 self.assertEqual( 2254 len(serialized), 2255 second_proto.MergeFromString(serialized)) 2256 self.assertEqual(first_proto, second_proto) 2257 2258 def testSerializeWithOptionalGroup(self): 2259 first_proto = unittest_pb2.TestAllTypes() 2260 second_proto = unittest_pb2.TestAllTypes() 2261 first_proto.optionalgroup.a = 242 2262 serialized = first_proto.SerializeToString() 2263 self.assertEqual( 2264 len(serialized), 2265 second_proto.MergeFromString(serialized)) 2266 self.assertEqual(first_proto, second_proto) 2267 2268 def testSerializeNegativeValues(self): 2269 first_proto = unittest_pb2.TestAllTypes() 2270 2271 first_proto.optional_int32 = -1 2272 first_proto.optional_int64 = -(2 << 40) 2273 first_proto.optional_sint32 = -3 2274 first_proto.optional_sint64 = -(4 << 40) 2275 first_proto.optional_sfixed32 = -5 2276 first_proto.optional_sfixed64 = -(6 << 
40) 2277 2278 second_proto = unittest_pb2.TestAllTypes.FromString( 2279 first_proto.SerializeToString()) 2280 2281 self.assertEqual(first_proto, second_proto) 2282 2283 def testParseTruncated(self): 2284 # This test is only applicable for the Python implementation of the API. 2285 if api_implementation.Type() != 'python': 2286 return 2287 2288 first_proto = unittest_pb2.TestAllTypes() 2289 test_util.SetAllFields(first_proto) 2290 serialized = first_proto.SerializeToString() 2291 2292 for truncation_point in xrange(len(serialized) + 1): 2293 try: 2294 second_proto = unittest_pb2.TestAllTypes() 2295 unknown_fields = unittest_pb2.TestEmptyMessage() 2296 pos = second_proto._InternalParse(serialized, 0, truncation_point) 2297 # If we didn't raise an error then we read exactly the amount expected. 2298 self.assertEqual(truncation_point, pos) 2299 2300 # Parsing to unknown fields should not throw if parsing to known fields 2301 # did not. 2302 try: 2303 pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point) 2304 self.assertEqual(truncation_point, pos2) 2305 except message.DecodeError: 2306 self.fail('Parsing unknown fields failed when parsing known fields ' 2307 'did not.') 2308 except message.DecodeError: 2309 # Parsing unknown fields should also fail. 2310 self.assertRaises(message.DecodeError, unknown_fields._InternalParse, 2311 serialized, 0, truncation_point) 2312 2313 def testCanonicalSerializationOrder(self): 2314 proto = more_messages_pb2.OutOfOrderFields() 2315 # These are also their tag numbers. Even though we're setting these in 2316 # reverse-tag order AND they're listed in reverse tag-order in the .proto 2317 # file, they should nonetheless be serialized in tag order. 
2318 proto.optional_sint32 = 5 2319 proto.Extensions[more_messages_pb2.optional_uint64] = 4 2320 proto.optional_uint32 = 3 2321 proto.Extensions[more_messages_pb2.optional_int64] = 2 2322 proto.optional_int32 = 1 2323 serialized = proto.SerializeToString() 2324 self.assertEqual(proto.ByteSize(), len(serialized)) 2325 d = _MiniDecoder(serialized) 2326 ReadTag = d.ReadFieldNumberAndWireType 2327 self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag()) 2328 self.assertEqual(1, d.ReadInt32()) 2329 self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag()) 2330 self.assertEqual(2, d.ReadInt64()) 2331 self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag()) 2332 self.assertEqual(3, d.ReadUInt32()) 2333 self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag()) 2334 self.assertEqual(4, d.ReadUInt64()) 2335 self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag()) 2336 self.assertEqual(5, d.ReadSInt32()) 2337 2338 def testCanonicalSerializationOrderSameAsCpp(self): 2339 # Copy of the same test we use for C++. 2340 proto = unittest_pb2.TestFieldOrderings() 2341 test_util.SetAllFieldsAndExtensions(proto) 2342 serialized = proto.SerializeToString() 2343 test_util.ExpectAllFieldsAndExtensionsInOrder(serialized) 2344 2345 def testMergeFromStringWhenFieldsAlreadySet(self): 2346 first_proto = unittest_pb2.TestAllTypes() 2347 first_proto.repeated_string.append('foobar') 2348 first_proto.optional_int32 = 23 2349 first_proto.optional_nested_message.bb = 42 2350 serialized = first_proto.SerializeToString() 2351 2352 second_proto = unittest_pb2.TestAllTypes() 2353 second_proto.repeated_string.append('baz') 2354 second_proto.optional_int32 = 100 2355 second_proto.optional_nested_message.bb = 999 2356 2357 bytes_parsed = second_proto.MergeFromString(serialized) 2358 self.assertEqual(len(serialized), bytes_parsed) 2359 2360 # Ensure that we append to repeated fields. 
2361 self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) 2362 # Ensure that we overwrite nonrepeatd scalars. 2363 self.assertEqual(23, second_proto.optional_int32) 2364 # Ensure that we recursively call MergeFromString() on 2365 # submessages. 2366 self.assertEqual(42, second_proto.optional_nested_message.bb) 2367 2368 def testMessageSetWireFormat(self): 2369 proto = unittest_mset_pb2.TestMessageSet() 2370 extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 2371 extension_message2 = unittest_mset_pb2.TestMessageSetExtension2 2372 extension1 = extension_message1.message_set_extension 2373 extension2 = extension_message2.message_set_extension 2374 proto.Extensions[extension1].i = 123 2375 proto.Extensions[extension2].str = 'foo' 2376 2377 # Serialize using the MessageSet wire format (this is specified in the 2378 # .proto file). 2379 serialized = proto.SerializeToString() 2380 2381 raw = unittest_mset_pb2.RawMessageSet() 2382 self.assertEqual(False, 2383 raw.DESCRIPTOR.GetOptions().message_set_wire_format) 2384 self.assertEqual( 2385 len(serialized), 2386 raw.MergeFromString(serialized)) 2387 self.assertEqual(2, len(raw.item)) 2388 2389 message1 = unittest_mset_pb2.TestMessageSetExtension1() 2390 self.assertEqual( 2391 len(raw.item[0].message), 2392 message1.MergeFromString(raw.item[0].message)) 2393 self.assertEqual(123, message1.i) 2394 2395 message2 = unittest_mset_pb2.TestMessageSetExtension2() 2396 self.assertEqual( 2397 len(raw.item[1].message), 2398 message2.MergeFromString(raw.item[1].message)) 2399 self.assertEqual('foo', message2.str) 2400 2401 # Deserialize using the MessageSet wire format. 2402 proto2 = unittest_mset_pb2.TestMessageSet() 2403 self.assertEqual( 2404 len(serialized), 2405 proto2.MergeFromString(serialized)) 2406 self.assertEqual(123, proto2.Extensions[extension1].i) 2407 self.assertEqual('foo', proto2.Extensions[extension2].str) 2408 2409 # Check byte size. 
2410 self.assertEqual(proto2.ByteSize(), len(serialized)) 2411 self.assertEqual(proto.ByteSize(), len(serialized)) 2412 2413 def testMessageSetWireFormatUnknownExtension(self): 2414 # Create a message using the message set wire format with an unknown 2415 # message. 2416 raw = unittest_mset_pb2.RawMessageSet() 2417 2418 # Add an item. 2419 item = raw.item.add() 2420 item.type_id = 1545008 2421 extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 2422 message1 = unittest_mset_pb2.TestMessageSetExtension1() 2423 message1.i = 12345 2424 item.message = message1.SerializeToString() 2425 2426 # Add a second, unknown extension. 2427 item = raw.item.add() 2428 item.type_id = 1545009 2429 extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 2430 message1 = unittest_mset_pb2.TestMessageSetExtension1() 2431 message1.i = 12346 2432 item.message = message1.SerializeToString() 2433 2434 # Add another unknown extension. 2435 item = raw.item.add() 2436 item.type_id = 1545010 2437 message1 = unittest_mset_pb2.TestMessageSetExtension2() 2438 message1.str = 'foo' 2439 item.message = message1.SerializeToString() 2440 2441 serialized = raw.SerializeToString() 2442 2443 # Parse message using the message set wire format. 2444 proto = unittest_mset_pb2.TestMessageSet() 2445 self.assertEqual( 2446 len(serialized), 2447 proto.MergeFromString(serialized)) 2448 2449 # Check that the message parsed well. 2450 extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 2451 extension1 = extension_message1.message_set_extension 2452 self.assertEquals(12345, proto.Extensions[extension1].i) 2453 2454 def testUnknownFields(self): 2455 proto = unittest_pb2.TestAllTypes() 2456 test_util.SetAllFields(proto) 2457 2458 serialized = proto.SerializeToString() 2459 2460 # The empty message should be parsable with all of the fields 2461 # unknown. 2462 proto2 = unittest_pb2.TestEmptyMessage() 2463 2464 # Parsing this message should succeed. 
2465 self.assertEqual( 2466 len(serialized), 2467 proto2.MergeFromString(serialized)) 2468 2469 # Now test with a int64 field set. 2470 proto = unittest_pb2.TestAllTypes() 2471 proto.optional_int64 = 0x0fffffffffffffff 2472 serialized = proto.SerializeToString() 2473 # The empty message should be parsable with all of the fields 2474 # unknown. 2475 proto2 = unittest_pb2.TestEmptyMessage() 2476 # Parsing this message should succeed. 2477 self.assertEqual( 2478 len(serialized), 2479 proto2.MergeFromString(serialized)) 2480 2481 def _CheckRaises(self, exc_class, callable_obj, exception): 2482 """This method checks if the excpetion type and message are as expected.""" 2483 try: 2484 callable_obj() 2485 except exc_class as ex: 2486 # Check if the exception message is the right one. 2487 self.assertEqual(exception, str(ex)) 2488 return 2489 else: 2490 raise self.failureException('%s not raised' % str(exc_class)) 2491 2492 def testSerializeUninitialized(self): 2493 proto = unittest_pb2.TestRequired() 2494 self._CheckRaises( 2495 message.EncodeError, 2496 proto.SerializeToString, 2497 'Message protobuf_unittest.TestRequired is missing required fields: ' 2498 'a,b,c') 2499 # Shouldn't raise exceptions. 2500 partial = proto.SerializePartialToString() 2501 2502 proto2 = unittest_pb2.TestRequired() 2503 self.assertFalse(proto2.HasField('a')) 2504 # proto2 ParseFromString does not check that required fields are set. 2505 proto2.ParseFromString(partial) 2506 self.assertFalse(proto2.HasField('a')) 2507 2508 proto.a = 1 2509 self._CheckRaises( 2510 message.EncodeError, 2511 proto.SerializeToString, 2512 'Message protobuf_unittest.TestRequired is missing required fields: b,c') 2513 # Shouldn't raise exceptions. 2514 partial = proto.SerializePartialToString() 2515 2516 proto.b = 2 2517 self._CheckRaises( 2518 message.EncodeError, 2519 proto.SerializeToString, 2520 'Message protobuf_unittest.TestRequired is missing required fields: c') 2521 # Shouldn't raise exceptions. 
2522 partial = proto.SerializePartialToString() 2523 2524 proto.c = 3 2525 serialized = proto.SerializeToString() 2526 # Shouldn't raise exceptions. 2527 partial = proto.SerializePartialToString() 2528 2529 proto2 = unittest_pb2.TestRequired() 2530 self.assertEqual( 2531 len(serialized), 2532 proto2.MergeFromString(serialized)) 2533 self.assertEqual(1, proto2.a) 2534 self.assertEqual(2, proto2.b) 2535 self.assertEqual(3, proto2.c) 2536 self.assertEqual( 2537 len(partial), 2538 proto2.MergeFromString(partial)) 2539 self.assertEqual(1, proto2.a) 2540 self.assertEqual(2, proto2.b) 2541 self.assertEqual(3, proto2.c) 2542 2543 def testSerializeUninitializedSubMessage(self): 2544 proto = unittest_pb2.TestRequiredForeign() 2545 2546 # Sub-message doesn't exist yet, so this succeeds. 2547 proto.SerializeToString() 2548 2549 proto.optional_message.a = 1 2550 self._CheckRaises( 2551 message.EncodeError, 2552 proto.SerializeToString, 2553 'Message protobuf_unittest.TestRequiredForeign ' 2554 'is missing required fields: ' 2555 'optional_message.b,optional_message.c') 2556 2557 proto.optional_message.b = 2 2558 proto.optional_message.c = 3 2559 proto.SerializeToString() 2560 2561 proto.repeated_message.add().a = 1 2562 proto.repeated_message.add().b = 2 2563 self._CheckRaises( 2564 message.EncodeError, 2565 proto.SerializeToString, 2566 'Message protobuf_unittest.TestRequiredForeign is missing required fields: ' 2567 'repeated_message[0].b,repeated_message[0].c,' 2568 'repeated_message[1].a,repeated_message[1].c') 2569 2570 proto.repeated_message[0].b = 2 2571 proto.repeated_message[0].c = 3 2572 proto.repeated_message[1].a = 1 2573 proto.repeated_message[1].c = 3 2574 proto.SerializeToString() 2575 2576 def testSerializeAllPackedFields(self): 2577 first_proto = unittest_pb2.TestPackedTypes() 2578 second_proto = unittest_pb2.TestPackedTypes() 2579 test_util.SetAllPackedFields(first_proto) 2580 serialized = first_proto.SerializeToString() 2581 
self.assertEqual(first_proto.ByteSize(), len(serialized)) 2582 bytes_read = second_proto.MergeFromString(serialized) 2583 self.assertEqual(second_proto.ByteSize(), bytes_read) 2584 self.assertEqual(first_proto, second_proto) 2585 2586 def testSerializeAllPackedExtensions(self): 2587 first_proto = unittest_pb2.TestPackedExtensions() 2588 second_proto = unittest_pb2.TestPackedExtensions() 2589 test_util.SetAllPackedExtensions(first_proto) 2590 serialized = first_proto.SerializeToString() 2591 bytes_read = second_proto.MergeFromString(serialized) 2592 self.assertEqual(second_proto.ByteSize(), bytes_read) 2593 self.assertEqual(first_proto, second_proto) 2594 2595 def testMergePackedFromStringWhenSomeFieldsAlreadySet(self): 2596 first_proto = unittest_pb2.TestPackedTypes() 2597 first_proto.packed_int32.extend([1, 2]) 2598 first_proto.packed_double.append(3.0) 2599 serialized = first_proto.SerializeToString() 2600 2601 second_proto = unittest_pb2.TestPackedTypes() 2602 second_proto.packed_int32.append(3) 2603 second_proto.packed_double.extend([1.0, 2.0]) 2604 second_proto.packed_sint32.append(4) 2605 2606 self.assertEqual( 2607 len(serialized), 2608 second_proto.MergeFromString(serialized)) 2609 self.assertEqual([3, 1, 2], second_proto.packed_int32) 2610 self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) 2611 self.assertEqual([4], second_proto.packed_sint32) 2612 2613 def testPackedFieldsWireFormat(self): 2614 proto = unittest_pb2.TestPackedTypes() 2615 proto.packed_int32.extend([1, 2, 150, 3]) # 1 + 1 + 2 + 1 bytes 2616 proto.packed_double.extend([1.0, 1000.0]) # 8 + 8 bytes 2617 proto.packed_float.append(2.0) # 4 bytes, will be before double 2618 serialized = proto.SerializeToString() 2619 self.assertEqual(proto.ByteSize(), len(serialized)) 2620 d = _MiniDecoder(serialized) 2621 ReadTag = d.ReadFieldNumberAndWireType 2622 self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) 2623 self.assertEqual(1+1+1+2, d.ReadInt32()) 2624 
self.assertEqual(1, d.ReadInt32()) 2625 self.assertEqual(2, d.ReadInt32()) 2626 self.assertEqual(150, d.ReadInt32()) 2627 self.assertEqual(3, d.ReadInt32()) 2628 self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) 2629 self.assertEqual(4, d.ReadInt32()) 2630 self.assertEqual(2.0, d.ReadFloat()) 2631 self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) 2632 self.assertEqual(8+8, d.ReadInt32()) 2633 self.assertEqual(1.0, d.ReadDouble()) 2634 self.assertEqual(1000.0, d.ReadDouble()) 2635 self.assertTrue(d.EndOfStream()) 2636 2637 def testParsePackedFromUnpacked(self): 2638 unpacked = unittest_pb2.TestUnpackedTypes() 2639 test_util.SetAllUnpackedFields(unpacked) 2640 packed = unittest_pb2.TestPackedTypes() 2641 serialized = unpacked.SerializeToString() 2642 self.assertEqual( 2643 len(serialized), 2644 packed.MergeFromString(serialized)) 2645 expected = unittest_pb2.TestPackedTypes() 2646 test_util.SetAllPackedFields(expected) 2647 self.assertEqual(expected, packed) 2648 2649 def testParseUnpackedFromPacked(self): 2650 packed = unittest_pb2.TestPackedTypes() 2651 test_util.SetAllPackedFields(packed) 2652 unpacked = unittest_pb2.TestUnpackedTypes() 2653 serialized = packed.SerializeToString() 2654 self.assertEqual( 2655 len(serialized), 2656 unpacked.MergeFromString(serialized)) 2657 expected = unittest_pb2.TestUnpackedTypes() 2658 test_util.SetAllUnpackedFields(expected) 2659 self.assertEqual(expected, unpacked) 2660 2661 def testFieldNumbers(self): 2662 proto = unittest_pb2.TestAllTypes() 2663 self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1) 2664 self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1) 2665 self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16) 2666 self.assertEqual( 2667 unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18) 2668 self.assertEqual( 2669 unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21) 2670 
self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31) 2671 self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46) 2672 self.assertEqual( 2673 unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48) 2674 self.assertEqual( 2675 unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51) 2676 2677 def testExtensionFieldNumbers(self): 2678 self.assertEqual(unittest_pb2.TestRequired.single.number, 1000) 2679 self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000) 2680 self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001) 2681 self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001) 2682 self.assertEqual(unittest_pb2.optional_int32_extension.number, 1) 2683 self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1) 2684 self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16) 2685 self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16) 2686 self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18) 2687 self.assertEqual( 2688 unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18) 2689 self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21) 2690 self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 2691 21) 2692 self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31) 2693 self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31) 2694 self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46) 2695 self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46) 2696 self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48) 2697 self.assertEqual( 2698 unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48) 2699 self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51) 2700 self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 2701 51) 2702 2703 def testInitKwargs(self): 2704 
proto = unittest_pb2.TestAllTypes( 2705 optional_int32=1, 2706 optional_string='foo', 2707 optional_bool=True, 2708 optional_bytes=b'bar', 2709 optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), 2710 optional_foreign_message=unittest_pb2.ForeignMessage(c=1), 2711 optional_nested_enum=unittest_pb2.TestAllTypes.FOO, 2712 optional_foreign_enum=unittest_pb2.FOREIGN_FOO, 2713 repeated_int32=[1, 2, 3]) 2714 self.assertTrue(proto.IsInitialized()) 2715 self.assertTrue(proto.HasField('optional_int32')) 2716 self.assertTrue(proto.HasField('optional_string')) 2717 self.assertTrue(proto.HasField('optional_bool')) 2718 self.assertTrue(proto.HasField('optional_bytes')) 2719 self.assertTrue(proto.HasField('optional_nested_message')) 2720 self.assertTrue(proto.HasField('optional_foreign_message')) 2721 self.assertTrue(proto.HasField('optional_nested_enum')) 2722 self.assertTrue(proto.HasField('optional_foreign_enum')) 2723 self.assertEqual(1, proto.optional_int32) 2724 self.assertEqual('foo', proto.optional_string) 2725 self.assertEqual(True, proto.optional_bool) 2726 self.assertEqual(b'bar', proto.optional_bytes) 2727 self.assertEqual(1, proto.optional_nested_message.bb) 2728 self.assertEqual(1, proto.optional_foreign_message.c) 2729 self.assertEqual(unittest_pb2.TestAllTypes.FOO, 2730 proto.optional_nested_enum) 2731 self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum) 2732 self.assertEqual([1, 2, 3], proto.repeated_int32) 2733 2734 def testInitArgsUnknownFieldName(self): 2735 def InitalizeEmptyMessageWithExtraKeywordArg(): 2736 unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown') 2737 self._CheckRaises(ValueError, 2738 InitalizeEmptyMessageWithExtraKeywordArg, 2739 'Protocol message has no "unknown" field.') 2740 2741 def testInitRequiredKwargs(self): 2742 proto = unittest_pb2.TestRequired(a=1, b=1, c=1) 2743 self.assertTrue(proto.IsInitialized()) 2744 self.assertTrue(proto.HasField('a')) 2745 
self.assertTrue(proto.HasField('b')) 2746 self.assertTrue(proto.HasField('c')) 2747 self.assertTrue(not proto.HasField('dummy2')) 2748 self.assertEqual(1, proto.a) 2749 self.assertEqual(1, proto.b) 2750 self.assertEqual(1, proto.c) 2751 2752 def testInitRequiredForeignKwargs(self): 2753 proto = unittest_pb2.TestRequiredForeign( 2754 optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1)) 2755 self.assertTrue(proto.IsInitialized()) 2756 self.assertTrue(proto.HasField('optional_message')) 2757 self.assertTrue(proto.optional_message.IsInitialized()) 2758 self.assertTrue(proto.optional_message.HasField('a')) 2759 self.assertTrue(proto.optional_message.HasField('b')) 2760 self.assertTrue(proto.optional_message.HasField('c')) 2761 self.assertTrue(not proto.optional_message.HasField('dummy2')) 2762 self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1), 2763 proto.optional_message) 2764 self.assertEqual(1, proto.optional_message.a) 2765 self.assertEqual(1, proto.optional_message.b) 2766 self.assertEqual(1, proto.optional_message.c) 2767 2768 def testInitRepeatedKwargs(self): 2769 proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3]) 2770 self.assertTrue(proto.IsInitialized()) 2771 self.assertEqual(1, proto.repeated_int32[0]) 2772 self.assertEqual(2, proto.repeated_int32[1]) 2773 self.assertEqual(3, proto.repeated_int32[2]) 2774 2775 2776class OptionsTest(basetest.TestCase): 2777 2778 def testMessageOptions(self): 2779 proto = unittest_mset_pb2.TestMessageSet() 2780 self.assertEqual(True, 2781 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 2782 proto = unittest_pb2.TestAllTypes() 2783 self.assertEqual(False, 2784 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 2785 2786 def testPackedOptions(self): 2787 proto = unittest_pb2.TestAllTypes() 2788 proto.optional_int32 = 1 2789 proto.optional_double = 3.0 2790 for field_descriptor, _ in proto.ListFields(): 2791 self.assertEqual(False, field_descriptor.GetOptions().packed) 2792 2793 proto = 
unittest_pb2.TestPackedTypes() 2794 proto.packed_int32.append(1) 2795 proto.packed_double.append(3.0) 2796 for field_descriptor, _ in proto.ListFields(): 2797 self.assertEqual(True, field_descriptor.GetOptions().packed) 2798 self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED, 2799 field_descriptor.label) 2800 2801 2802 2803class ClassAPITest(basetest.TestCase): 2804 2805 def testMakeClassWithNestedDescriptor(self): 2806 leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', 2807 containing_type=None, fields=[], 2808 nested_types=[], enum_types=[], 2809 extensions=[]) 2810 child_desc = descriptor.Descriptor('child', 'package.parent.child', '', 2811 containing_type=None, fields=[], 2812 nested_types=[leaf_desc], enum_types=[], 2813 extensions=[]) 2814 sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling', 2815 '', containing_type=None, fields=[], 2816 nested_types=[], enum_types=[], 2817 extensions=[]) 2818 parent_desc = descriptor.Descriptor('parent', 'package.parent', '', 2819 containing_type=None, fields=[], 2820 nested_types=[child_desc, sibling_desc], 2821 enum_types=[], extensions=[]) 2822 message_class = reflection.MakeClass(parent_desc) 2823 self.assertIn('child', message_class.__dict__) 2824 self.assertIn('sibling', message_class.__dict__) 2825 self.assertIn('leaf', message_class.child.__dict__) 2826 2827 def _GetSerializedFileDescriptor(self, name): 2828 """Get a serialized representation of a test FileDescriptorProto. 2829 2830 Args: 2831 name: All calls to this must use a unique message name, to avoid 2832 collisions in the cpp descriptor pool. 2833 Returns: 2834 A string containing the serialized form of a test FileDescriptorProto. 
2835 """ 2836 file_descriptor_str = ( 2837 'message_type {' 2838 ' name: "' + name + '"' 2839 ' field {' 2840 ' name: "flat"' 2841 ' number: 1' 2842 ' label: LABEL_REPEATED' 2843 ' type: TYPE_UINT32' 2844 ' }' 2845 ' field {' 2846 ' name: "bar"' 2847 ' number: 2' 2848 ' label: LABEL_OPTIONAL' 2849 ' type: TYPE_MESSAGE' 2850 ' type_name: "Bar"' 2851 ' }' 2852 ' nested_type {' 2853 ' name: "Bar"' 2854 ' field {' 2855 ' name: "baz"' 2856 ' number: 3' 2857 ' label: LABEL_OPTIONAL' 2858 ' type: TYPE_MESSAGE' 2859 ' type_name: "Baz"' 2860 ' }' 2861 ' nested_type {' 2862 ' name: "Baz"' 2863 ' enum_type {' 2864 ' name: "deep_enum"' 2865 ' value {' 2866 ' name: "VALUE_A"' 2867 ' number: 0' 2868 ' }' 2869 ' }' 2870 ' field {' 2871 ' name: "deep"' 2872 ' number: 4' 2873 ' label: LABEL_OPTIONAL' 2874 ' type: TYPE_UINT32' 2875 ' }' 2876 ' }' 2877 ' }' 2878 '}') 2879 file_descriptor = descriptor_pb2.FileDescriptorProto() 2880 text_format.Merge(file_descriptor_str, file_descriptor) 2881 return file_descriptor.SerializeToString() 2882 2883 def testParsingFlatClassWithExplicitClassDeclaration(self): 2884 """Test that the generated class can parse a flat message.""" 2885 file_descriptor = descriptor_pb2.FileDescriptorProto() 2886 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) 2887 msg_descriptor = descriptor.MakeDescriptor( 2888 file_descriptor.message_type[0]) 2889 2890 class MessageClass(message.Message): 2891 __metaclass__ = reflection.GeneratedProtocolMessageType 2892 DESCRIPTOR = msg_descriptor 2893 msg = MessageClass() 2894 msg_str = ( 2895 'flat: 0 ' 2896 'flat: 1 ' 2897 'flat: 2 ') 2898 text_format.Merge(msg_str, msg) 2899 self.assertEqual(msg.flat, [0, 1, 2]) 2900 2901 def testParsingFlatClass(self): 2902 """Test that the generated class can parse a flat message.""" 2903 file_descriptor = descriptor_pb2.FileDescriptorProto() 2904 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) 2905 msg_descriptor = descriptor.MakeDescriptor( 
2906 file_descriptor.message_type[0]) 2907 msg_class = reflection.MakeClass(msg_descriptor) 2908 msg = msg_class() 2909 msg_str = ( 2910 'flat: 0 ' 2911 'flat: 1 ' 2912 'flat: 2 ') 2913 text_format.Merge(msg_str, msg) 2914 self.assertEqual(msg.flat, [0, 1, 2]) 2915 2916 def testParsingNestedClass(self): 2917 """Test that the generated class can parse a nested message.""" 2918 file_descriptor = descriptor_pb2.FileDescriptorProto() 2919 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) 2920 msg_descriptor = descriptor.MakeDescriptor( 2921 file_descriptor.message_type[0]) 2922 msg_class = reflection.MakeClass(msg_descriptor) 2923 msg = msg_class() 2924 msg_str = ( 2925 'bar {' 2926 ' baz {' 2927 ' deep: 4' 2928 ' }' 2929 '}') 2930 text_format.Merge(msg_str, msg) 2931 self.assertEqual(msg.bar.baz.deep, 4) 2932 2933if __name__ == '__main__': 2934 basetest.main() 2935