1#! /usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4# Protocol Buffers - Google's data interchange format
5# Copyright 2008 Google Inc.  All rights reserved.
6# https://developers.google.com/protocol-buffers/
7#
8# Redistribution and use in source and binary forms, with or without
9# modification, are permitted provided that the following conditions are
10# met:
11#
12#     * Redistributions of source code must retain the above copyright
13# notice, this list of conditions and the following disclaimer.
14#     * Redistributions in binary form must reproduce the above
15# copyright notice, this list of conditions and the following disclaimer
16# in the documentation and/or other materials provided with the
17# distribution.
18#     * Neither the name of Google Inc. nor the names of its
19# contributors may be used to endorse or promote products derived from
20# this software without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34"""Unittest for reflection.py, which also indirectly tests the output of the
35pure-Python protocol compiler.
36"""
37
38import copy
39import gc
40import operator
41import six
42import struct
43
44try:
45  import unittest2 as unittest  #PY26
46except ImportError:
47  import unittest
48
49from google.protobuf import unittest_import_pb2
50from google.protobuf import unittest_mset_pb2
51from google.protobuf import unittest_pb2
52from google.protobuf import descriptor_pb2
53from google.protobuf import descriptor
54from google.protobuf import message
55from google.protobuf import reflection
56from google.protobuf import text_format
57from google.protobuf.internal import api_implementation
58from google.protobuf.internal import more_extensions_pb2
59from google.protobuf.internal import more_messages_pb2
60from google.protobuf.internal import message_set_extensions_pb2
61from google.protobuf.internal import wire_format
62from google.protobuf.internal import test_util
63from google.protobuf.internal import decoder
64
65
class _MiniDecoder(object):
  """Tiny sequential reader over a serialized protobuf byte string.

  The tests in this file once inspected serialized output with a class named
  decoder.Decoder.  That class disappeared when decoding was redesigned for
  speed, so this helper re-creates just the Read* methods those tests still
  call instead of rewriting the tests themselves.
  """

  def __init__(self, bytes):
    # The buffer being read plus a cursor that each Read* call advances.
    self._bytes = bytes
    self._pos = 0

  def ReadVarint(self):
    value, next_pos = decoder._DecodeVarint(self._bytes, self._pos)
    self._pos = next_pos
    return value

  # Every plain (non-zigzag) integer type is encoded as a bare varint.
  ReadInt32 = ReadInt64 = ReadUInt32 = ReadUInt64 = ReadVarint

  def ReadSInt64(self):
    raw = self.ReadVarint()
    return wire_format.ZigZagDecode(raw)

  # Signed 32- and 64-bit values share the same zigzag encoding.
  ReadSInt32 = ReadSInt64

  def ReadFieldNumberAndWireType(self):
    tag = self.ReadVarint()
    return wire_format.UnpackTag(tag)

  def _ReadFixed(self, fmt, width):
    """Unpacks one fixed-width value at the cursor, then advances it."""
    start = self._pos
    (value,) = struct.unpack(fmt, self._bytes[start:start + width])
    self._pos = start + width
    return value

  def ReadFloat(self):
    return self._ReadFixed("<f", 4)

  def ReadDouble(self):
    return self._ReadFixed("<d", 8)

  def EndOfStream(self):
    return len(self._bytes) == self._pos
109
110
111class ReflectionTest(unittest.TestCase):
112
113  def assertListsEqual(self, values, others):
114    self.assertEqual(len(values), len(others))
115    for i in range(len(values)):
116      self.assertEqual(values[i], others[i])
117
118  def testScalarConstructor(self):
119    # Constructor with only scalar types should succeed.
120    proto = unittest_pb2.TestAllTypes(
121        optional_int32=24,
122        optional_double=54.321,
123        optional_string='optional_string',
124        optional_float=None)
125
126    self.assertEqual(24, proto.optional_int32)
127    self.assertEqual(54.321, proto.optional_double)
128    self.assertEqual('optional_string', proto.optional_string)
129    self.assertFalse(proto.HasField("optional_float"))
130
131  def testRepeatedScalarConstructor(self):
132    # Constructor with only repeated scalar types should succeed.
133    proto = unittest_pb2.TestAllTypes(
134        repeated_int32=[1, 2, 3, 4],
135        repeated_double=[1.23, 54.321],
136        repeated_bool=[True, False, False],
137        repeated_string=["optional_string"],
138        repeated_float=None)
139
140    self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
141    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
142    self.assertEqual([True, False, False], list(proto.repeated_bool))
143    self.assertEqual(["optional_string"], list(proto.repeated_string))
144    self.assertEqual([], list(proto.repeated_float))
145
146  def testRepeatedCompositeConstructor(self):
147    # Constructor with only repeated composite types should succeed.
148    proto = unittest_pb2.TestAllTypes(
149        repeated_nested_message=[
150            unittest_pb2.TestAllTypes.NestedMessage(
151                bb=unittest_pb2.TestAllTypes.FOO),
152            unittest_pb2.TestAllTypes.NestedMessage(
153                bb=unittest_pb2.TestAllTypes.BAR)],
154        repeated_foreign_message=[
155            unittest_pb2.ForeignMessage(c=-43),
156            unittest_pb2.ForeignMessage(c=45324),
157            unittest_pb2.ForeignMessage(c=12)],
158        repeatedgroup=[
159            unittest_pb2.TestAllTypes.RepeatedGroup(),
160            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
161            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
162
163    self.assertEqual(
164        [unittest_pb2.TestAllTypes.NestedMessage(
165            bb=unittest_pb2.TestAllTypes.FOO),
166         unittest_pb2.TestAllTypes.NestedMessage(
167             bb=unittest_pb2.TestAllTypes.BAR)],
168        list(proto.repeated_nested_message))
169    self.assertEqual(
170        [unittest_pb2.ForeignMessage(c=-43),
171         unittest_pb2.ForeignMessage(c=45324),
172         unittest_pb2.ForeignMessage(c=12)],
173        list(proto.repeated_foreign_message))
174    self.assertEqual(
175        [unittest_pb2.TestAllTypes.RepeatedGroup(),
176         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
177         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
178        list(proto.repeatedgroup))
179
180  def testMixedConstructor(self):
181    # Constructor with only mixed types should succeed.
182    proto = unittest_pb2.TestAllTypes(
183        optional_int32=24,
184        optional_string='optional_string',
185        repeated_double=[1.23, 54.321],
186        repeated_bool=[True, False, False],
187        repeated_nested_message=[
188            unittest_pb2.TestAllTypes.NestedMessage(
189                bb=unittest_pb2.TestAllTypes.FOO),
190            unittest_pb2.TestAllTypes.NestedMessage(
191                bb=unittest_pb2.TestAllTypes.BAR)],
192        repeated_foreign_message=[
193            unittest_pb2.ForeignMessage(c=-43),
194            unittest_pb2.ForeignMessage(c=45324),
195            unittest_pb2.ForeignMessage(c=12)],
196        optional_nested_message=None)
197
198    self.assertEqual(24, proto.optional_int32)
199    self.assertEqual('optional_string', proto.optional_string)
200    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
201    self.assertEqual([True, False, False], list(proto.repeated_bool))
202    self.assertEqual(
203        [unittest_pb2.TestAllTypes.NestedMessage(
204            bb=unittest_pb2.TestAllTypes.FOO),
205         unittest_pb2.TestAllTypes.NestedMessage(
206             bb=unittest_pb2.TestAllTypes.BAR)],
207        list(proto.repeated_nested_message))
208    self.assertEqual(
209        [unittest_pb2.ForeignMessage(c=-43),
210         unittest_pb2.ForeignMessage(c=45324),
211         unittest_pb2.ForeignMessage(c=12)],
212        list(proto.repeated_foreign_message))
213    self.assertFalse(proto.HasField("optional_nested_message"))
214
215  def testConstructorTypeError(self):
216    self.assertRaises(
217        TypeError, unittest_pb2.TestAllTypes, optional_int32="foo")
218    self.assertRaises(
219        TypeError, unittest_pb2.TestAllTypes, optional_string=1234)
220    self.assertRaises(
221        TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234)
222    self.assertRaises(
223        TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234)
224    self.assertRaises(
225        TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"])
226    self.assertRaises(
227        TypeError, unittest_pb2.TestAllTypes, repeated_string=1234)
228    self.assertRaises(
229        TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234])
230    self.assertRaises(
231        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234)
232    self.assertRaises(
233        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234])
234
235  def testConstructorInvalidatesCachedByteSize(self):
236    message = unittest_pb2.TestAllTypes(optional_int32 = 12)
237    self.assertEqual(2, message.ByteSize())
238
239    message = unittest_pb2.TestAllTypes(
240        optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage())
241    self.assertEqual(3, message.ByteSize())
242
243    message = unittest_pb2.TestAllTypes(repeated_int32 = [12])
244    self.assertEqual(3, message.ByteSize())
245
246    message = unittest_pb2.TestAllTypes(
247        repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()])
248    self.assertEqual(3, message.ByteSize())
249
250  def testSimpleHasBits(self):
251    # Test a scalar.
252    proto = unittest_pb2.TestAllTypes()
253    self.assertTrue(not proto.HasField('optional_int32'))
254    self.assertEqual(0, proto.optional_int32)
255    # HasField() shouldn't be true if all we've done is
256    # read the default value.
257    self.assertTrue(not proto.HasField('optional_int32'))
258    proto.optional_int32 = 1
259    # Setting a value however *should* set the "has" bit.
260    self.assertTrue(proto.HasField('optional_int32'))
261    proto.ClearField('optional_int32')
262    # And clearing that value should unset the "has" bit.
263    self.assertTrue(not proto.HasField('optional_int32'))
264
265  def testHasBitsWithSinglyNestedScalar(self):
266    # Helper used to test foreign messages and groups.
267    #
268    # composite_field_name should be the name of a non-repeated
269    # composite (i.e., foreign or group) field in TestAllTypes,
270    # and scalar_field_name should be the name of an integer-valued
271    # scalar field within that composite.
272    #
273    # I never thought I'd miss C++ macros and templates so much. :(
274    # This helper is semantically just:
275    #
276    #   assert proto.composite_field.scalar_field == 0
277    #   assert not proto.composite_field.HasField('scalar_field')
278    #   assert not proto.HasField('composite_field')
279    #
280    #   proto.composite_field.scalar_field = 10
281    #   old_composite_field = proto.composite_field
282    #
283    #   assert proto.composite_field.scalar_field == 10
284    #   assert proto.composite_field.HasField('scalar_field')
285    #   assert proto.HasField('composite_field')
286    #
287    #   proto.ClearField('composite_field')
288    #
289    #   assert not proto.composite_field.HasField('scalar_field')
290    #   assert not proto.HasField('composite_field')
291    #   assert proto.composite_field.scalar_field == 0
292    #
293    #   # Now ensure that ClearField('composite_field') disconnected
294    #   # the old field object from the object tree...
295    #   assert old_composite_field is not proto.composite_field
296    #   old_composite_field.scalar_field = 20
297    #   assert not proto.composite_field.HasField('scalar_field')
298    #   assert not proto.HasField('composite_field')
299    def TestCompositeHasBits(composite_field_name, scalar_field_name):
300      proto = unittest_pb2.TestAllTypes()
301      # First, check that we can get the scalar value, and see that it's the
302      # default (0), but that proto.HasField('omposite') and
303      # proto.composite.HasField('scalar') will still return False.
304      composite_field = getattr(proto, composite_field_name)
305      original_scalar_value = getattr(composite_field, scalar_field_name)
306      self.assertEqual(0, original_scalar_value)
307      # Assert that the composite object does not "have" the scalar.
308      self.assertTrue(not composite_field.HasField(scalar_field_name))
309      # Assert that proto does not "have" the composite field.
310      self.assertTrue(not proto.HasField(composite_field_name))
311
312      # Now set the scalar within the composite field.  Ensure that the setting
313      # is reflected, and that proto.HasField('composite') and
314      # proto.composite.HasField('scalar') now both return True.
315      new_val = 20
316      setattr(composite_field, scalar_field_name, new_val)
317      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
318      # Hold on to a reference to the current composite_field object.
319      old_composite_field = composite_field
320      # Assert that the has methods now return true.
321      self.assertTrue(composite_field.HasField(scalar_field_name))
322      self.assertTrue(proto.HasField(composite_field_name))
323
324      # Now call the clear method...
325      proto.ClearField(composite_field_name)
326
327      # ...and ensure that the "has" bits are all back to False...
328      composite_field = getattr(proto, composite_field_name)
329      self.assertTrue(not composite_field.HasField(scalar_field_name))
330      self.assertTrue(not proto.HasField(composite_field_name))
331      # ...and ensure that the scalar field has returned to its default.
332      self.assertEqual(0, getattr(composite_field, scalar_field_name))
333
334      self.assertTrue(old_composite_field is not composite_field)
335      setattr(old_composite_field, scalar_field_name, new_val)
336      self.assertTrue(not composite_field.HasField(scalar_field_name))
337      self.assertTrue(not proto.HasField(composite_field_name))
338      self.assertEqual(0, getattr(composite_field, scalar_field_name))
339
340    # Test simple, single-level nesting when we set a scalar.
341    TestCompositeHasBits('optionalgroup', 'a')
342    TestCompositeHasBits('optional_nested_message', 'bb')
343    TestCompositeHasBits('optional_foreign_message', 'c')
344    TestCompositeHasBits('optional_import_message', 'd')
345
346  def testReferencesToNestedMessage(self):
347    proto = unittest_pb2.TestAllTypes()
348    nested = proto.optional_nested_message
349    del proto
350    # A previous version had a bug where this would raise an exception when
351    # hitting a now-dead weak reference.
352    nested.bb = 23
353
354  def testDisconnectingNestedMessageBeforeSettingField(self):
355    proto = unittest_pb2.TestAllTypes()
356    nested = proto.optional_nested_message
357    proto.ClearField('optional_nested_message')  # Should disconnect from parent
358    self.assertTrue(nested is not proto.optional_nested_message)
359    nested.bb = 23
360    self.assertTrue(not proto.HasField('optional_nested_message'))
361    self.assertEqual(0, proto.optional_nested_message.bb)
362
363  def testGetDefaultMessageAfterDisconnectingDefaultMessage(self):
364    proto = unittest_pb2.TestAllTypes()
365    nested = proto.optional_nested_message
366    proto.ClearField('optional_nested_message')
367    del proto
368    del nested
369    # Force a garbage collect so that the underlying CMessages are freed along
370    # with the Messages they point to. This is to make sure we're not deleting
371    # default message instances.
372    gc.collect()
373    proto = unittest_pb2.TestAllTypes()
374    nested = proto.optional_nested_message
375
376  def testDisconnectingNestedMessageAfterSettingField(self):
377    proto = unittest_pb2.TestAllTypes()
378    nested = proto.optional_nested_message
379    nested.bb = 5
380    self.assertTrue(proto.HasField('optional_nested_message'))
381    proto.ClearField('optional_nested_message')  # Should disconnect from parent
382    self.assertEqual(5, nested.bb)
383    self.assertEqual(0, proto.optional_nested_message.bb)
384    self.assertTrue(nested is not proto.optional_nested_message)
385    nested.bb = 23
386    self.assertTrue(not proto.HasField('optional_nested_message'))
387    self.assertEqual(0, proto.optional_nested_message.bb)
388
389  def testDisconnectingNestedMessageBeforeGettingField(self):
390    proto = unittest_pb2.TestAllTypes()
391    self.assertTrue(not proto.HasField('optional_nested_message'))
392    proto.ClearField('optional_nested_message')
393    self.assertTrue(not proto.HasField('optional_nested_message'))
394
395  def testDisconnectingNestedMessageAfterMerge(self):
396    # This test exercises the code path that does not use ReleaseMessage().
397    # The underlying fear is that if we use ReleaseMessage() incorrectly,
398    # we will have memory leaks.  It's hard to check that that doesn't happen,
399    # but at least we can exercise that code path to make sure it works.
400    proto1 = unittest_pb2.TestAllTypes()
401    proto2 = unittest_pb2.TestAllTypes()
402    proto2.optional_nested_message.bb = 5
403    proto1.MergeFrom(proto2)
404    self.assertTrue(proto1.HasField('optional_nested_message'))
405    proto1.ClearField('optional_nested_message')
406    self.assertTrue(not proto1.HasField('optional_nested_message'))
407
408  def testDisconnectingLazyNestedMessage(self):
409    # This test exercises releasing a nested message that is lazy. This test
410    # only exercises real code in the C++ implementation as Python does not
411    # support lazy parsing, but the current C++ implementation results in
412    # memory corruption and a crash.
413    if api_implementation.Type() != 'python':
414      return
415    proto = unittest_pb2.TestAllTypes()
416    proto.optional_lazy_message.bb = 5
417    proto.ClearField('optional_lazy_message')
418    del proto
419    gc.collect()
420
421  def testHasBitsWhenModifyingRepeatedFields(self):
422    # Test nesting when we add an element to a repeated field in a submessage.
423    proto = unittest_pb2.TestNestedMessageHasBits()
424    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
425    self.assertEqual(
426        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
427    self.assertTrue(proto.HasField('optional_nested_message'))
428
429    # Do the same test, but with a repeated composite field within the
430    # submessage.
431    proto.ClearField('optional_nested_message')
432    self.assertTrue(not proto.HasField('optional_nested_message'))
433    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
434    self.assertTrue(proto.HasField('optional_nested_message'))
435
436  def testHasBitsForManyLevelsOfNesting(self):
437    # Test nesting many levels deep.
438    recursive_proto = unittest_pb2.TestMutualRecursionA()
439    self.assertTrue(not recursive_proto.HasField('bb'))
440    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
441    self.assertTrue(not recursive_proto.HasField('bb'))
442    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
443    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
444    self.assertTrue(recursive_proto.HasField('bb'))
445    self.assertTrue(recursive_proto.bb.HasField('a'))
446    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
447    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
448    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
449    self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a'))
450    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
451
452  def testSingularListFields(self):
453    proto = unittest_pb2.TestAllTypes()
454    proto.optional_fixed32 = 1
455    proto.optional_int32 = 5
456    proto.optional_string = 'foo'
457    # Access sub-message but don't set it yet.
458    nested_message = proto.optional_nested_message
459    self.assertEqual(
460      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
461        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
462        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
463      proto.ListFields())
464
465    proto.optional_nested_message.bb = 123
466    self.assertEqual(
467      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
468        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
469        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
470        (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
471             nested_message) ],
472      proto.ListFields())
473
474  def testRepeatedListFields(self):
475    proto = unittest_pb2.TestAllTypes()
476    proto.repeated_fixed32.append(1)
477    proto.repeated_int32.append(5)
478    proto.repeated_int32.append(11)
479    proto.repeated_string.extend(['foo', 'bar'])
480    proto.repeated_string.extend([])
481    proto.repeated_string.append('baz')
482    proto.repeated_string.extend(str(x) for x in range(2))
483    proto.optional_int32 = 21
484    proto.repeated_bool  # Access but don't set anything; should not be listed.
485    self.assertEqual(
486      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 21),
487        (proto.DESCRIPTOR.fields_by_name['repeated_int32'  ], [5, 11]),
488        (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
489        (proto.DESCRIPTOR.fields_by_name['repeated_string' ],
490          ['foo', 'bar', 'baz', '0', '1']) ],
491      proto.ListFields())
492
493  def testSingularListExtensions(self):
494    proto = unittest_pb2.TestAllExtensions()
495    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
496    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 5
497    proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo'
498    self.assertEqual(
499      [ (unittest_pb2.optional_int32_extension  , 5),
500        (unittest_pb2.optional_fixed32_extension, 1),
501        (unittest_pb2.optional_string_extension , 'foo') ],
502      proto.ListFields())
503
504  def testRepeatedListExtensions(self):
505    proto = unittest_pb2.TestAllExtensions()
506    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
507    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(5)
508    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(11)
509    proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo')
510    proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar')
511    proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz')
512    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 21
513    self.assertEqual(
514      [ (unittest_pb2.optional_int32_extension  , 21),
515        (unittest_pb2.repeated_int32_extension  , [5, 11]),
516        (unittest_pb2.repeated_fixed32_extension, [1]),
517        (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ],
518      proto.ListFields())
519
520  def testListFieldsAndExtensions(self):
521    proto = unittest_pb2.TestFieldOrderings()
522    test_util.SetAllFieldsAndExtensions(proto)
523    unittest_pb2.my_extension_int
524    self.assertEqual(
525      [ (proto.DESCRIPTOR.fields_by_name['my_int'   ], 1),
526        (unittest_pb2.my_extension_int               , 23),
527        (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
528        (unittest_pb2.my_extension_string            , 'bar'),
529        (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ],
530      proto.ListFields())
531
532  def testDefaultValues(self):
533    proto = unittest_pb2.TestAllTypes()
534    self.assertEqual(0, proto.optional_int32)
535    self.assertEqual(0, proto.optional_int64)
536    self.assertEqual(0, proto.optional_uint32)
537    self.assertEqual(0, proto.optional_uint64)
538    self.assertEqual(0, proto.optional_sint32)
539    self.assertEqual(0, proto.optional_sint64)
540    self.assertEqual(0, proto.optional_fixed32)
541    self.assertEqual(0, proto.optional_fixed64)
542    self.assertEqual(0, proto.optional_sfixed32)
543    self.assertEqual(0, proto.optional_sfixed64)
544    self.assertEqual(0.0, proto.optional_float)
545    self.assertEqual(0.0, proto.optional_double)
546    self.assertEqual(False, proto.optional_bool)
547    self.assertEqual('', proto.optional_string)
548    self.assertEqual(b'', proto.optional_bytes)
549
550    self.assertEqual(41, proto.default_int32)
551    self.assertEqual(42, proto.default_int64)
552    self.assertEqual(43, proto.default_uint32)
553    self.assertEqual(44, proto.default_uint64)
554    self.assertEqual(-45, proto.default_sint32)
555    self.assertEqual(46, proto.default_sint64)
556    self.assertEqual(47, proto.default_fixed32)
557    self.assertEqual(48, proto.default_fixed64)
558    self.assertEqual(49, proto.default_sfixed32)
559    self.assertEqual(-50, proto.default_sfixed64)
560    self.assertEqual(51.5, proto.default_float)
561    self.assertEqual(52e3, proto.default_double)
562    self.assertEqual(True, proto.default_bool)
563    self.assertEqual('hello', proto.default_string)
564    self.assertEqual(b'world', proto.default_bytes)
565    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
566    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
567    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
568                     proto.default_import_enum)
569
570    proto = unittest_pb2.TestExtremeDefaultValues()
571    self.assertEqual(u'\u1234', proto.utf8_string)
572
573  def testHasFieldWithUnknownFieldName(self):
574    proto = unittest_pb2.TestAllTypes()
575    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')
576
577  def testClearFieldWithUnknownFieldName(self):
578    proto = unittest_pb2.TestAllTypes()
579    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
580
581  def testClearRemovesChildren(self):
582    # Make sure there aren't any implementation bugs that are only partially
583    # clearing the message (which can happen in the more complex C++
584    # implementation which has parallel message lists).
585    proto = unittest_pb2.TestRequiredForeign()
586    for i in range(10):
587      proto.repeated_message.add()
588    proto2 = unittest_pb2.TestRequiredForeign()
589    proto.CopyFrom(proto2)
590    self.assertRaises(IndexError, lambda: proto.repeated_message[5])
591
592  def testDisallowedAssignments(self):
593    # It's illegal to assign values directly to repeated fields
594    # or to nonrepeated composite fields.  Ensure that this fails.
595    proto = unittest_pb2.TestAllTypes()
596    # Repeated fields.
597    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
598    # Lists shouldn't work, either.
599    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
600    # Composite fields.
601    self.assertRaises(AttributeError, setattr, proto,
602                      'optional_nested_message', 23)
603    # Assignment to a repeated nested message field without specifying
604    # the index in the array of nested messages.
605    self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
606                      'bb', 34)
607    # Assignment to an attribute of a repeated field.
608    self.assertRaises(AttributeError, setattr, proto.repeated_float,
609                      'some_attribute', 34)
610    # proto.nonexistent_field = 23 should fail as well.
611    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
612
613  def testSingleScalarTypeSafety(self):
614    proto = unittest_pb2.TestAllTypes()
615    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
616    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
617    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
618    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
619
620  def testIntegerTypes(self):
621    def TestGetAndDeserialize(field_name, value, expected_type):
622      proto = unittest_pb2.TestAllTypes()
623      setattr(proto, field_name, value)
624      self.assertIsInstance(getattr(proto, field_name), expected_type)
625      proto2 = unittest_pb2.TestAllTypes()
626      proto2.ParseFromString(proto.SerializeToString())
627      self.assertIsInstance(getattr(proto2, field_name), expected_type)
628
629    TestGetAndDeserialize('optional_int32', 1, int)
630    TestGetAndDeserialize('optional_int32', 1 << 30, int)
631    TestGetAndDeserialize('optional_uint32', 1 << 30, int)
632    try:
633      integer_64 = long
634    except NameError: # Python3
635      integer_64 = int
636    if struct.calcsize('L') == 4:
637      # Python only has signed ints, so 32-bit python can't fit an uint32
638      # in an int.
639      TestGetAndDeserialize('optional_uint32', 1 << 31, long)
640    else:
641      # 64-bit python can fit uint32 inside an int
642      TestGetAndDeserialize('optional_uint32', 1 << 31, int)
643    TestGetAndDeserialize('optional_int64', 1 << 30, integer_64)
644    TestGetAndDeserialize('optional_int64', 1 << 60, integer_64)
645    TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
646    TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
647
648  def testSingleScalarBoundsChecking(self):
649    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
650      pb = unittest_pb2.TestAllTypes()
651      setattr(pb, field_name, expected_min)
652      self.assertEqual(expected_min, getattr(pb, field_name))
653      setattr(pb, field_name, expected_max)
654      self.assertEqual(expected_max, getattr(pb, field_name))
655      self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
656      self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)
657
658    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
659    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
660    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
661    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
662
663    pb = unittest_pb2.TestAllTypes()
664    pb.optional_nested_enum = 1
665    self.assertEqual(1, pb.optional_nested_enum)
666
667  def testRepeatedScalarTypeSafety(self):
668    proto = unittest_pb2.TestAllTypes()
669    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
670    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
671    self.assertRaises(TypeError, proto.repeated_string, 10)
672    self.assertRaises(TypeError, proto.repeated_bytes, 10)
673
674    proto.repeated_int32.append(10)
675    proto.repeated_int32[0] = 23
676    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
677    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
678
679    # Repeated enums tests.
680    #proto.repeated_nested_enum.append(0)
681
682  def testSingleScalarGettersAndSetters(self):
683    proto = unittest_pb2.TestAllTypes()
684    self.assertEqual(0, proto.optional_int32)
685    proto.optional_int32 = 1
686    self.assertEqual(1, proto.optional_int32)
687
688    proto.optional_uint64 = 0xffffffffffff
689    self.assertEqual(0xffffffffffff, proto.optional_uint64)
690    proto.optional_uint64 = 0xffffffffffffffff
691    self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
692    # TODO(robinson): Test all other scalar field types.
693
694  def testSingleScalarClearField(self):
695    proto = unittest_pb2.TestAllTypes()
696    # Should be allowed to clear something that's not there (a no-op).
697    proto.ClearField('optional_int32')
698    proto.optional_int32 = 1
699    self.assertTrue(proto.HasField('optional_int32'))
700    proto.ClearField('optional_int32')
701    self.assertEqual(0, proto.optional_int32)
702    self.assertTrue(not proto.HasField('optional_int32'))
703    # TODO(robinson): Test all other scalar field types.
704
705  def testEnums(self):
706    proto = unittest_pb2.TestAllTypes()
707    self.assertEqual(1, proto.FOO)
708    self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
709    self.assertEqual(2, proto.BAR)
710    self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
711    self.assertEqual(3, proto.BAZ)
712    self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
713
714  def testEnum_Name(self):
715    self.assertEqual('FOREIGN_FOO',
716                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO))
717    self.assertEqual('FOREIGN_BAR',
718                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR))
719    self.assertEqual('FOREIGN_BAZ',
720                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ))
721    self.assertRaises(ValueError,
722                      unittest_pb2.ForeignEnum.Name, 11312)
723
724    proto = unittest_pb2.TestAllTypes()
725    self.assertEqual('FOO',
726                     proto.NestedEnum.Name(proto.FOO))
727    self.assertEqual('FOO',
728                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO))
729    self.assertEqual('BAR',
730                     proto.NestedEnum.Name(proto.BAR))
731    self.assertEqual('BAR',
732                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR))
733    self.assertEqual('BAZ',
734                     proto.NestedEnum.Name(proto.BAZ))
735    self.assertEqual('BAZ',
736                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ))
737    self.assertRaises(ValueError,
738                      proto.NestedEnum.Name, 11312)
739    self.assertRaises(ValueError,
740                      unittest_pb2.TestAllTypes.NestedEnum.Name, 11312)
741
742  def testEnum_Value(self):
743    self.assertEqual(unittest_pb2.FOREIGN_FOO,
744                     unittest_pb2.ForeignEnum.Value('FOREIGN_FOO'))
745    self.assertEqual(unittest_pb2.FOREIGN_BAR,
746                     unittest_pb2.ForeignEnum.Value('FOREIGN_BAR'))
747    self.assertEqual(unittest_pb2.FOREIGN_BAZ,
748                     unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ'))
749    self.assertRaises(ValueError,
750                      unittest_pb2.ForeignEnum.Value, 'FO')
751
752    proto = unittest_pb2.TestAllTypes()
753    self.assertEqual(proto.FOO,
754                     proto.NestedEnum.Value('FOO'))
755    self.assertEqual(proto.FOO,
756                     unittest_pb2.TestAllTypes.NestedEnum.Value('FOO'))
757    self.assertEqual(proto.BAR,
758                     proto.NestedEnum.Value('BAR'))
759    self.assertEqual(proto.BAR,
760                     unittest_pb2.TestAllTypes.NestedEnum.Value('BAR'))
761    self.assertEqual(proto.BAZ,
762                     proto.NestedEnum.Value('BAZ'))
763    self.assertEqual(proto.BAZ,
764                     unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ'))
765    self.assertRaises(ValueError,
766                      proto.NestedEnum.Value, 'Foo')
767    self.assertRaises(ValueError,
768                      unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo')
769
770  def testEnum_KeysAndValues(self):
771    self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'],
772                     list(unittest_pb2.ForeignEnum.keys()))
773    self.assertEqual([4, 5, 6],
774                     list(unittest_pb2.ForeignEnum.values()))
775    self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5),
776                      ('FOREIGN_BAZ', 6)],
777                     list(unittest_pb2.ForeignEnum.items()))
778
779    proto = unittest_pb2.TestAllTypes()
780    self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], list(proto.NestedEnum.keys()))
781    self.assertEqual([1, 2, 3, -1], list(proto.NestedEnum.values()))
782    self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)],
783                     list(proto.NestedEnum.items()))
784
  def testRepeatedScalars(self):
    """Exercises the list-like API of a repeated scalar field.

    Covers truthiness, len(), append(), indexing (including negative and
    out-of-range indices and bad index types), insert(), slicing, slice
    assignment from a generator and from a list, iteration, item and slice
    deletion (including negative indices), extend(), and ClearField().
    The field is always re-read from `proto` rather than aliased.
    """
    proto = unittest_pb2.TestAllTypes()

    self.assertTrue(not proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))
    proto.repeated_int32.append(5)
    proto.repeated_int32.append(10)
    proto.repeated_int32.append(15)
    self.assertTrue(proto.repeated_int32)
    self.assertEqual(3, len(proto.repeated_int32))

    # The container compares equal to a plain list with the same elements.
    self.assertEqual([5, 10, 15], proto.repeated_int32)

    # Test single retrieval.
    self.assertEqual(5, proto.repeated_int32[0])
    self.assertEqual(15, proto.repeated_int32[-1])
    # Test out-of-bounds indices.
    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
    # Test incorrect types passed to __getitem__.
    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)

    # Test single assignment.
    proto.repeated_int32[1] = 20
    self.assertEqual([5, 20, 15], proto.repeated_int32)

    # Test insertion.
    proto.repeated_int32.insert(1, 25)
    self.assertEqual([5, 25, 20, 15], proto.repeated_int32)

    # Test slice retrieval.
    proto.repeated_int32.append(30)
    self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
    self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])

    # Test slice assignment with an iterator
    proto.repeated_int32[1:4] = (i for i in range(3))
    self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)

    # Test slice assignment.
    proto.repeated_int32[1:4] = [35, 40, 45]
    self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)

    # Test that we can use the field as an iterator.
    result = []
    for i in proto.repeated_int32:
      result.append(i)
    self.assertEqual([5, 35, 40, 45, 30], result)

    # Test single deletion.
    del proto.repeated_int32[2]
    self.assertEqual([5, 35, 45, 30], proto.repeated_int32)

    # Test slice deletion.
    del proto.repeated_int32[2:]
    self.assertEqual([5, 35], proto.repeated_int32)

    # Test extending.
    proto.repeated_int32.extend([3, 13])
    self.assertEqual([5, 35, 3, 13], proto.repeated_int32)

    # Test clearing.
    proto.ClearField('repeated_int32')
    self.assertTrue(not proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))

    proto.repeated_int32.append(1)
    self.assertEqual(1, proto.repeated_int32[-1])
    # Test assignment to a negative index.
    proto.repeated_int32[-1] = 2
    self.assertEqual(2, proto.repeated_int32[-1])

    # Test deletion at negative indices.
    proto.repeated_int32[:] = [0, 1, 2, 3]
    del proto.repeated_int32[-1]
    self.assertEqual([0, 1, 2], proto.repeated_int32)

    del proto.repeated_int32[-2]
    self.assertEqual([0, 2], proto.repeated_int32)

    # Single-item deletion out of range raises, in either direction.
    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3)
    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300)

    # Slice deletion with negative bounds removes the covered elements.
    del proto.repeated_int32[-2:-1]
    self.assertEqual([2], proto.repeated_int32)

    # A slice lying entirely past the end deletes nothing and does not raise.
    del proto.repeated_int32[100:10000]
    self.assertEqual([2], proto.repeated_int32)
874
875  def testRepeatedScalarsRemove(self):
876    proto = unittest_pb2.TestAllTypes()
877
878    self.assertTrue(not proto.repeated_int32)
879    self.assertEqual(0, len(proto.repeated_int32))
880    proto.repeated_int32.append(5)
881    proto.repeated_int32.append(10)
882    proto.repeated_int32.append(5)
883    proto.repeated_int32.append(5)
884
885    self.assertEqual(4, len(proto.repeated_int32))
886    proto.repeated_int32.remove(5)
887    self.assertEqual(3, len(proto.repeated_int32))
888    self.assertEqual(10, proto.repeated_int32[0])
889    self.assertEqual(5, proto.repeated_int32[1])
890    self.assertEqual(5, proto.repeated_int32[2])
891
892    proto.repeated_int32.remove(5)
893    self.assertEqual(2, len(proto.repeated_int32))
894    self.assertEqual(10, proto.repeated_int32[0])
895    self.assertEqual(5, proto.repeated_int32[1])
896
897    proto.repeated_int32.remove(10)
898    self.assertEqual(1, len(proto.repeated_int32))
899    self.assertEqual(5, proto.repeated_int32[0])
900
901    # Remove a non-existent element.
902    self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
903
  def testRepeatedComposites(self):
    """Exercises the list-like API of a repeated message field.

    Covers add() (with and without field keyword arguments), truthiness,
    len(), out-of-range and wrongly-typed indices, slicing, iteration,
    item and slice deletion, extend() with standalone messages, and
    ClearField().
    """
    proto = unittest_pb2.TestAllTypes()
    self.assertTrue(not proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))
    m0 = proto.repeated_nested_message.add()
    m1 = proto.repeated_nested_message.add()
    self.assertTrue(proto.repeated_nested_message)
    self.assertEqual(2, len(proto.repeated_nested_message))
    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
    self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage)

    # Test out-of-bounds indices.
    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
                      1234)
    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
                      -1234)

    # Test incorrect types passed to __getitem__.
    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
                      'foo')
    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
                      None)

    # Test slice retrieval.
    m2 = proto.repeated_nested_message.add()
    m3 = proto.repeated_nested_message.add()
    m4 = proto.repeated_nested_message.add()
    self.assertListsEqual(
        [m1, m2, m3], proto.repeated_nested_message[1:4])
    self.assertListsEqual(
        [m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
    self.assertListsEqual(
        [m0, m1], proto.repeated_nested_message[:2])
    self.assertListsEqual(
        [m2, m3, m4], proto.repeated_nested_message[2:])
    self.assertEqual(
        m0, proto.repeated_nested_message[0])
    self.assertListsEqual(
        [m0], proto.repeated_nested_message[:1])

    # Test that we can use the field as an iterator.
    result = []
    for i in proto.repeated_nested_message:
      result.append(i)
    self.assertListsEqual([m0, m1, m2, m3, m4], result)

    # Test single deletion.
    del proto.repeated_nested_message[2]
    self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message)

    # Test slice deletion.
    del proto.repeated_nested_message[2:]
    self.assertListsEqual([m0, m1], proto.repeated_nested_message)

    # Test extending.
    # The appended elements compare equal to the messages passed in.
    n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
    n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2)
    proto.repeated_nested_message.extend([n1,n2])
    self.assertEqual(4, len(proto.repeated_nested_message))
    self.assertEqual(n1, proto.repeated_nested_message[2])
    self.assertEqual(n2, proto.repeated_nested_message[3])

    # Test clearing.
    proto.ClearField('repeated_nested_message')
    self.assertTrue(not proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))

    # Test constructing an element while adding it.
    proto.repeated_nested_message.add(bb=23)
    self.assertEqual(1, len(proto.repeated_nested_message))
    self.assertEqual(23, proto.repeated_nested_message[0].bb)
975
976  def testRepeatedCompositeRemove(self):
977    proto = unittest_pb2.TestAllTypes()
978
979    self.assertEqual(0, len(proto.repeated_nested_message))
980    m0 = proto.repeated_nested_message.add()
981    # Need to set some differentiating variable so m0 != m1 != m2:
982    m0.bb = len(proto.repeated_nested_message)
983    m1 = proto.repeated_nested_message.add()
984    m1.bb = len(proto.repeated_nested_message)
985    self.assertTrue(m0 != m1)
986    m2 = proto.repeated_nested_message.add()
987    m2.bb = len(proto.repeated_nested_message)
988    self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
989
990    self.assertEqual(3, len(proto.repeated_nested_message))
991    proto.repeated_nested_message.remove(m0)
992    self.assertEqual(2, len(proto.repeated_nested_message))
993    self.assertEqual(m1, proto.repeated_nested_message[0])
994    self.assertEqual(m2, proto.repeated_nested_message[1])
995
996    # Removing m0 again or removing None should raise error
997    self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0)
998    self.assertRaises(ValueError, proto.repeated_nested_message.remove, None)
999    self.assertEqual(2, len(proto.repeated_nested_message))
1000
1001    proto.repeated_nested_message.remove(m2)
1002    self.assertEqual(1, len(proto.repeated_nested_message))
1003    self.assertEqual(m1, proto.repeated_nested_message[0])
1004
1005  def testHandWrittenReflection(self):
1006    # Hand written extensions are only supported by the pure-Python
1007    # implementation of the API.
1008    if api_implementation.Type() != 'python':
1009      return
1010
1011    FieldDescriptor = descriptor.FieldDescriptor
1012    foo_field_descriptor = FieldDescriptor(
1013        name='foo_field', full_name='MyProto.foo_field',
1014        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
1015        cpp_type=FieldDescriptor.CPPTYPE_INT64,
1016        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
1017        containing_type=None, message_type=None, enum_type=None,
1018        is_extension=False, extension_scope=None,
1019        options=descriptor_pb2.FieldOptions())
1020    mydescriptor = descriptor.Descriptor(
1021        name='MyProto', full_name='MyProto', filename='ignored',
1022        containing_type=None, nested_types=[], enum_types=[],
1023        fields=[foo_field_descriptor], extensions=[],
1024        options=descriptor_pb2.MessageOptions())
1025    class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
1026      DESCRIPTOR = mydescriptor
1027    myproto_instance = MyProtoClass()
1028    self.assertEqual(0, myproto_instance.foo_field)
1029    self.assertTrue(not myproto_instance.HasField('foo_field'))
1030    myproto_instance.foo_field = 23
1031    self.assertEqual(23, myproto_instance.foo_field)
1032    self.assertTrue(myproto_instance.HasField('foo_field'))
1033
1034  def testDescriptorProtoSupport(self):
1035    # Hand written descriptors/reflection are only supported by the pure-Python
1036    # implementation of the API.
1037    if api_implementation.Type() != 'python':
1038      return
1039
1040    def AddDescriptorField(proto, field_name, field_type):
1041      AddDescriptorField.field_index += 1
1042      new_field = proto.field.add()
1043      new_field.name = field_name
1044      new_field.type = field_type
1045      new_field.number = AddDescriptorField.field_index
1046      new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
1047
1048    AddDescriptorField.field_index = 0
1049
1050    desc_proto = descriptor_pb2.DescriptorProto()
1051    desc_proto.name = 'Car'
1052    fdp = descriptor_pb2.FieldDescriptorProto
1053    AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING)
1054    AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64)
1055    AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL)
1056    AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE)
1057    # Add a repeated field
1058    AddDescriptorField.field_index += 1
1059    new_field = desc_proto.field.add()
1060    new_field.name = 'owners'
1061    new_field.type = fdp.TYPE_STRING
1062    new_field.number = AddDescriptorField.field_index
1063    new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
1064
1065    desc = descriptor.MakeDescriptor(desc_proto)
1066    self.assertTrue('name' in desc.fields_by_name)
1067    self.assertTrue('year' in desc.fields_by_name)
1068    self.assertTrue('automatic' in desc.fields_by_name)
1069    self.assertTrue('price' in desc.fields_by_name)
1070    self.assertTrue('owners' in desc.fields_by_name)
1071
1072    class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
1073      DESCRIPTOR = desc
1074
1075    prius = CarMessage()
1076    prius.name = 'prius'
1077    prius.year = 2010
1078    prius.automatic = True
1079    prius.price = 25134.75
1080    prius.owners.extend(['bob', 'susan'])
1081
1082    serialized_prius = prius.SerializeToString()
1083    new_prius = reflection.ParseMessage(desc, serialized_prius)
1084    self.assertTrue(new_prius is not prius)
1085    self.assertEqual(prius, new_prius)
1086
1087    # these are unnecessary assuming message equality works as advertised but
1088    # explicitly check to be safe since we're mucking about in metaclass foo
1089    self.assertEqual(prius.name, new_prius.name)
1090    self.assertEqual(prius.year, new_prius.year)
1091    self.assertEqual(prius.automatic, new_prius.automatic)
1092    self.assertEqual(prius.price, new_prius.price)
1093    self.assertEqual(prius.owners, new_prius.owners)
1094
1095  def testTopLevelExtensionsForOptionalScalar(self):
1096    extendee_proto = unittest_pb2.TestAllExtensions()
1097    extension = unittest_pb2.optional_int32_extension
1098    self.assertTrue(not extendee_proto.HasExtension(extension))
1099    self.assertEqual(0, extendee_proto.Extensions[extension])
1100    # As with normal scalar fields, just doing a read doesn't actually set the
1101    # "has" bit.
1102    self.assertTrue(not extendee_proto.HasExtension(extension))
1103    # Actually set the thing.
1104    extendee_proto.Extensions[extension] = 23
1105    self.assertEqual(23, extendee_proto.Extensions[extension])
1106    self.assertTrue(extendee_proto.HasExtension(extension))
1107    # Ensure that clearing works as well.
1108    extendee_proto.ClearExtension(extension)
1109    self.assertEqual(0, extendee_proto.Extensions[extension])
1110    self.assertTrue(not extendee_proto.HasExtension(extension))
1111
1112  def testTopLevelExtensionsForRepeatedScalar(self):
1113    extendee_proto = unittest_pb2.TestAllExtensions()
1114    extension = unittest_pb2.repeated_string_extension
1115    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1116    extendee_proto.Extensions[extension].append('foo')
1117    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
1118    string_list = extendee_proto.Extensions[extension]
1119    extendee_proto.ClearExtension(extension)
1120    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1121    self.assertTrue(string_list is not extendee_proto.Extensions[extension])
1122    # Shouldn't be allowed to do Extensions[extension] = 'a'
1123    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1124                      extension, 'a')
1125
1126  def testTopLevelExtensionsForOptionalMessage(self):
1127    extendee_proto = unittest_pb2.TestAllExtensions()
1128    extension = unittest_pb2.optional_foreign_message_extension
1129    self.assertTrue(not extendee_proto.HasExtension(extension))
1130    self.assertEqual(0, extendee_proto.Extensions[extension].c)
1131    # As with normal (non-extension) fields, merely reading from the
1132    # thing shouldn't set the "has" bit.
1133    self.assertTrue(not extendee_proto.HasExtension(extension))
1134    extendee_proto.Extensions[extension].c = 23
1135    self.assertEqual(23, extendee_proto.Extensions[extension].c)
1136    self.assertTrue(extendee_proto.HasExtension(extension))
1137    # Save a reference here.
1138    foreign_message = extendee_proto.Extensions[extension]
1139    extendee_proto.ClearExtension(extension)
1140    self.assertTrue(foreign_message is not extendee_proto.Extensions[extension])
1141    # Setting a field on foreign_message now shouldn't set
1142    # any "has" bits on extendee_proto.
1143    foreign_message.c = 42
1144    self.assertEqual(42, foreign_message.c)
1145    self.assertTrue(foreign_message.HasField('c'))
1146    self.assertTrue(not extendee_proto.HasExtension(extension))
1147    # Shouldn't be allowed to do Extensions[extension] = 'a'
1148    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1149                      extension, 'a')
1150
1151  def testTopLevelExtensionsForRepeatedMessage(self):
1152    extendee_proto = unittest_pb2.TestAllExtensions()
1153    extension = unittest_pb2.repeatedgroup_extension
1154    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1155    group = extendee_proto.Extensions[extension].add()
1156    group.a = 23
1157    self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
1158    group.a = 42
1159    self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
1160    group_list = extendee_proto.Extensions[extension]
1161    extendee_proto.ClearExtension(extension)
1162    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1163    self.assertTrue(group_list is not extendee_proto.Extensions[extension])
1164    # Shouldn't be allowed to do Extensions[extension] = 'a'
1165    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1166                      extension, 'a')
1167
1168  def testNestedExtensions(self):
1169    extendee_proto = unittest_pb2.TestAllExtensions()
1170    extension = unittest_pb2.TestRequired.single
1171
1172    # We just test the non-repeated case.
1173    self.assertTrue(not extendee_proto.HasExtension(extension))
1174    required = extendee_proto.Extensions[extension]
1175    self.assertEqual(0, required.a)
1176    self.assertTrue(not extendee_proto.HasExtension(extension))
1177    required.a = 23
1178    self.assertEqual(23, extendee_proto.Extensions[extension].a)
1179    self.assertTrue(extendee_proto.HasExtension(extension))
1180    extendee_proto.ClearExtension(extension)
1181    self.assertTrue(required is not extendee_proto.Extensions[extension])
1182    self.assertTrue(not extendee_proto.HasExtension(extension))
1183
1184  def testRegisteredExtensions(self):
1185    self.assertTrue('protobuf_unittest.optional_int32_extension' in
1186                    unittest_pb2.TestAllExtensions._extensions_by_name)
1187    self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
1188    # Make sure extensions haven't been registered into types that shouldn't
1189    # have any.
1190    self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
1191
1192  # If message A directly contains message B, and
1193  # a.HasField('b') is currently False, then mutating any
1194  # extension in B should change a.HasField('b') to True
1195  # (and so on up the object tree).
  def testHasBitsForAncestorsOfExtendedMessage(self):
    """Mutating an extension of a submessage sets the parent's has-bit.

    Repeated for an optional scalar, repeated scalar, optional message and
    repeated message extension on toplevel.submessage. In each case, reading
    the extension must leave toplevel.HasField('submessage') False, and
    mutating it (assignment, append() or add()) must make it True.
    """
    # Optional scalar extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertTrue(not toplevel.HasField('submessage'))
    self.assertEqual(0, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
    self.assertTrue(not toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension] = 23
    self.assertEqual(23, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
    self.assertTrue(toplevel.HasField('submessage'))

    # Repeated scalar extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertTrue(not toplevel.HasField('submessage'))
    self.assertEqual([], toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension])
    self.assertTrue(not toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension].append(23)
    self.assertEqual([23], toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension])
    self.assertTrue(toplevel.HasField('submessage'))

    # Optional message extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertTrue(not toplevel.HasField('submessage'))
    self.assertEqual(0, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int)
    self.assertTrue(not toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
    self.assertEqual(23, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int)
    self.assertTrue(toplevel.HasField('submessage'))

    # Repeated message extension.
    # Here the mutation is add() itself; no field of the new element is set.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertTrue(not toplevel.HasField('submessage'))
    self.assertEqual(0, len(toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension]))
    self.assertTrue(not toplevel.HasField('submessage'))
    foreign = toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension].add()
    self.assertEqual(foreign, toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension][0])
    self.assertTrue(toplevel.HasField('submessage'))
1244
1245  def testDisconnectionAfterClearingEmptyMessage(self):
1246    toplevel = more_extensions_pb2.TopLevelMessage()
1247    extendee_proto = toplevel.submessage
1248    extension = more_extensions_pb2.optional_message_extension
1249    extension_proto = extendee_proto.Extensions[extension]
1250    extendee_proto.ClearExtension(extension)
1251    extension_proto.foreign_message_int = 23
1252
1253    self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
1254
1255  def testExtensionFailureModes(self):
1256    extendee_proto = unittest_pb2.TestAllExtensions()
1257
1258    # Try non-extension-handle arguments to HasExtension,
1259    # ClearExtension(), and Extensions[]...
1260    self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
1261    self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
1262    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
1263    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
1264
1265    # Try something that *is* an extension handle, just not for
1266    # this message...
1267    for unknown_handle in (more_extensions_pb2.optional_int_extension,
1268                           more_extensions_pb2.optional_message_extension,
1269                           more_extensions_pb2.repeated_int_extension,
1270                           more_extensions_pb2.repeated_message_extension):
1271      self.assertRaises(KeyError, extendee_proto.HasExtension,
1272                        unknown_handle)
1273      self.assertRaises(KeyError, extendee_proto.ClearExtension,
1274                        unknown_handle)
1275      self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
1276                        unknown_handle)
1277      self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
1278                        unknown_handle, 5)
1279
1280    # Try call HasExtension() with a valid handle, but for a
1281    # *repeated* field.  (Just as with non-extension repeated
1282    # fields, Has*() isn't supported for extension repeated fields).
1283    self.assertRaises(KeyError, extendee_proto.HasExtension,
1284                      unittest_pb2.repeated_string_extension)
1285
1286  def testStaticParseFrom(self):
1287    proto1 = unittest_pb2.TestAllTypes()
1288    test_util.SetAllFields(proto1)
1289
1290    string1 = proto1.SerializeToString()
1291    proto2 = unittest_pb2.TestAllTypes.FromString(string1)
1292
1293    # Messages should be equal.
1294    self.assertEqual(proto2, proto1)
1295
1296  def testMergeFromSingularField(self):
1297    # Test merge with just a singular field.
1298    proto1 = unittest_pb2.TestAllTypes()
1299    proto1.optional_int32 = 1
1300
1301    proto2 = unittest_pb2.TestAllTypes()
1302    # This shouldn't get overwritten.
1303    proto2.optional_string = 'value'
1304
1305    proto2.MergeFrom(proto1)
1306    self.assertEqual(1, proto2.optional_int32)
1307    self.assertEqual('value', proto2.optional_string)
1308
1309  def testMergeFromRepeatedField(self):
1310    # Test merge with just a repeated field.
1311    proto1 = unittest_pb2.TestAllTypes()
1312    proto1.repeated_int32.append(1)
1313    proto1.repeated_int32.append(2)
1314
1315    proto2 = unittest_pb2.TestAllTypes()
1316    proto2.repeated_int32.append(0)
1317    proto2.MergeFrom(proto1)
1318
1319    self.assertEqual(0, proto2.repeated_int32[0])
1320    self.assertEqual(1, proto2.repeated_int32[1])
1321    self.assertEqual(2, proto2.repeated_int32[2])
1322
1323  def testMergeFromOptionalGroup(self):
1324    # Test merge with an optional group.
1325    proto1 = unittest_pb2.TestAllTypes()
1326    proto1.optionalgroup.a = 12
1327    proto2 = unittest_pb2.TestAllTypes()
1328    proto2.MergeFrom(proto1)
1329    self.assertEqual(12, proto2.optionalgroup.a)
1330
1331  def testMergeFromRepeatedNestedMessage(self):
1332    # Test merge with a repeated nested message.
1333    proto1 = unittest_pb2.TestAllTypes()
1334    m = proto1.repeated_nested_message.add()
1335    m.bb = 123
1336    m = proto1.repeated_nested_message.add()
1337    m.bb = 321
1338
1339    proto2 = unittest_pb2.TestAllTypes()
1340    m = proto2.repeated_nested_message.add()
1341    m.bb = 999
1342    proto2.MergeFrom(proto1)
1343    self.assertEqual(999, proto2.repeated_nested_message[0].bb)
1344    self.assertEqual(123, proto2.repeated_nested_message[1].bb)
1345    self.assertEqual(321, proto2.repeated_nested_message[2].bb)
1346
1347    proto3 = unittest_pb2.TestAllTypes()
1348    proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message)
1349    self.assertEqual(999, proto3.repeated_nested_message[0].bb)
1350    self.assertEqual(123, proto3.repeated_nested_message[1].bb)
1351    self.assertEqual(321, proto3.repeated_nested_message[2].bb)
1352
1353  def testMergeFromAllFields(self):
1354    # With all fields set.
1355    proto1 = unittest_pb2.TestAllTypes()
1356    test_util.SetAllFields(proto1)
1357    proto2 = unittest_pb2.TestAllTypes()
1358    proto2.MergeFrom(proto1)
1359
1360    # Messages should be equal.
1361    self.assertEqual(proto2, proto1)
1362
1363    # Serialized string should be equal too.
1364    string1 = proto1.SerializeToString()
1365    string2 = proto2.SerializeToString()
1366    self.assertEqual(string1, string2)
1367
1368  def testMergeFromExtensionsSingular(self):
1369    proto1 = unittest_pb2.TestAllExtensions()
1370    proto1.Extensions[unittest_pb2.optional_int32_extension] = 1
1371
1372    proto2 = unittest_pb2.TestAllExtensions()
1373    proto2.MergeFrom(proto1)
1374    self.assertEqual(
1375        1, proto2.Extensions[unittest_pb2.optional_int32_extension])
1376
1377  def testMergeFromExtensionsRepeated(self):
1378    proto1 = unittest_pb2.TestAllExtensions()
1379    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
1380    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)
1381
1382    proto2 = unittest_pb2.TestAllExtensions()
1383    proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
1384    proto2.MergeFrom(proto1)
1385    self.assertEqual(
1386        3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
1387    self.assertEqual(
1388        0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
1389    self.assertEqual(
1390        1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
1391    self.assertEqual(
1392        2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])
1393
1394  def testMergeFromExtensionsNestedMessage(self):
1395    proto1 = unittest_pb2.TestAllExtensions()
1396    ext1 = proto1.Extensions[
1397        unittest_pb2.repeated_nested_message_extension]
1398    m = ext1.add()
1399    m.bb = 222
1400    m = ext1.add()
1401    m.bb = 333
1402
1403    proto2 = unittest_pb2.TestAllExtensions()
1404    ext2 = proto2.Extensions[
1405        unittest_pb2.repeated_nested_message_extension]
1406    m = ext2.add()
1407    m.bb = 111
1408
1409    proto2.MergeFrom(proto1)
1410    ext2 = proto2.Extensions[
1411        unittest_pb2.repeated_nested_message_extension]
1412    self.assertEqual(3, len(ext2))
1413    self.assertEqual(111, ext2[0].bb)
1414    self.assertEqual(222, ext2[1].bb)
1415    self.assertEqual(333, ext2[2].bb)
1416
1417  def testMergeFromBug(self):
1418    message1 = unittest_pb2.TestAllTypes()
1419    message2 = unittest_pb2.TestAllTypes()
1420
1421    # Cause optional_nested_message to be instantiated within message1, even
1422    # though it is not considered to be "present".
1423    message1.optional_nested_message
1424    self.assertFalse(message1.HasField('optional_nested_message'))
1425
1426    # Merge into message2.  This should not instantiate the field is message2.
1427    message2.MergeFrom(message1)
1428    self.assertFalse(message2.HasField('optional_nested_message'))
1429
1430  def testCopyFromSingularField(self):
1431    # Test copy with just a singular field.
1432    proto1 = unittest_pb2.TestAllTypes()
1433    proto1.optional_int32 = 1
1434    proto1.optional_string = 'important-text'
1435
1436    proto2 = unittest_pb2.TestAllTypes()
1437    proto2.optional_string = 'value'
1438
1439    proto2.CopyFrom(proto1)
1440    self.assertEqual(1, proto2.optional_int32)
1441    self.assertEqual('important-text', proto2.optional_string)
1442
1443  def testCopyFromRepeatedField(self):
1444    # Test copy with a repeated field.
1445    proto1 = unittest_pb2.TestAllTypes()
1446    proto1.repeated_int32.append(1)
1447    proto1.repeated_int32.append(2)
1448
1449    proto2 = unittest_pb2.TestAllTypes()
1450    proto2.repeated_int32.append(0)
1451    proto2.CopyFrom(proto1)
1452
1453    self.assertEqual(1, proto2.repeated_int32[0])
1454    self.assertEqual(2, proto2.repeated_int32[1])
1455
1456  def testCopyFromAllFields(self):
1457    # With all fields set.
1458    proto1 = unittest_pb2.TestAllTypes()
1459    test_util.SetAllFields(proto1)
1460    proto2 = unittest_pb2.TestAllTypes()
1461    proto2.CopyFrom(proto1)
1462
1463    # Messages should be equal.
1464    self.assertEqual(proto2, proto1)
1465
1466    # Serialized string should be equal too.
1467    string1 = proto1.SerializeToString()
1468    string2 = proto2.SerializeToString()
1469    self.assertEqual(string1, string2)
1470
1471  def testCopyFromSelf(self):
1472    proto1 = unittest_pb2.TestAllTypes()
1473    proto1.repeated_int32.append(1)
1474    proto1.optional_int32 = 2
1475    proto1.optional_string = 'important-text'
1476
1477    proto1.CopyFrom(proto1)
1478    self.assertEqual(1, proto1.repeated_int32[0])
1479    self.assertEqual(2, proto1.optional_int32)
1480    self.assertEqual('important-text', proto1.optional_string)
1481
1482  def testCopyFromBadType(self):
1483    # The python implementation doesn't raise an exception in this
1484    # case. In theory it should.
1485    if api_implementation.Type() == 'python':
1486      return
1487    proto1 = unittest_pb2.TestAllTypes()
1488    proto2 = unittest_pb2.TestAllExtensions()
1489    self.assertRaises(TypeError, proto1.CopyFrom, proto2)
1490
1491  def testDeepCopy(self):
1492    proto1 = unittest_pb2.TestAllTypes()
1493    proto1.optional_int32 = 1
1494    proto2 = copy.deepcopy(proto1)
1495    self.assertEqual(1, proto2.optional_int32)
1496
1497    proto1.repeated_int32.append(2)
1498    proto1.repeated_int32.append(3)
1499    container = copy.deepcopy(proto1.repeated_int32)
1500    self.assertEqual([2, 3], container)
1501
1502    # TODO(anuraag): Implement deepcopy for repeated composite / extension dict
1503
1504  def testClear(self):
1505    proto = unittest_pb2.TestAllTypes()
1506    # C++ implementation does not support lazy fields right now so leave it
1507    # out for now.
1508    if api_implementation.Type() == 'python':
1509      test_util.SetAllFields(proto)
1510    else:
1511      test_util.SetAllNonLazyFields(proto)
1512    # Clear the message.
1513    proto.Clear()
1514    self.assertEqual(proto.ByteSize(), 0)
1515    empty_proto = unittest_pb2.TestAllTypes()
1516    self.assertEqual(proto, empty_proto)
1517
1518    # Test if extensions which were set are cleared.
1519    proto = unittest_pb2.TestAllExtensions()
1520    test_util.SetAllExtensions(proto)
1521    # Clear the message.
1522    proto.Clear()
1523    self.assertEqual(proto.ByteSize(), 0)
1524    empty_proto = unittest_pb2.TestAllExtensions()
1525    self.assertEqual(proto, empty_proto)
1526
1527  def testDisconnectingBeforeClear(self):
1528    proto = unittest_pb2.TestAllTypes()
1529    nested = proto.optional_nested_message
1530    proto.Clear()
1531    self.assertTrue(nested is not proto.optional_nested_message)
1532    nested.bb = 23
1533    self.assertTrue(not proto.HasField('optional_nested_message'))
1534    self.assertEqual(0, proto.optional_nested_message.bb)
1535
1536    proto = unittest_pb2.TestAllTypes()
1537    nested = proto.optional_nested_message
1538    nested.bb = 5
1539    foreign = proto.optional_foreign_message
1540    foreign.c = 6
1541
1542    proto.Clear()
1543    self.assertTrue(nested is not proto.optional_nested_message)
1544    self.assertTrue(foreign is not proto.optional_foreign_message)
1545    self.assertEqual(5, nested.bb)
1546    self.assertEqual(6, foreign.c)
1547    nested.bb = 15
1548    foreign.c = 16
1549    self.assertFalse(proto.HasField('optional_nested_message'))
1550    self.assertEqual(0, proto.optional_nested_message.bb)
1551    self.assertFalse(proto.HasField('optional_foreign_message'))
1552    self.assertEqual(0, proto.optional_foreign_message.c)
1553
1554  def testOneOf(self):
1555    proto = unittest_pb2.TestAllTypes()
1556    proto.oneof_uint32 = 10
1557    proto.oneof_nested_message.bb = 11
1558    self.assertEqual(11, proto.oneof_nested_message.bb)
1559    self.assertFalse(proto.HasField('oneof_uint32'))
1560    nested = proto.oneof_nested_message
1561    proto.oneof_string = 'abc'
1562    self.assertEqual('abc', proto.oneof_string)
1563    self.assertEqual(11, nested.bb)
1564    self.assertFalse(proto.HasField('oneof_nested_message'))
1565
1566  def assertInitialized(self, proto):
1567    self.assertTrue(proto.IsInitialized())
1568    # Neither method should raise an exception.
1569    proto.SerializeToString()
1570    proto.SerializePartialToString()
1571
1572  def assertNotInitialized(self, proto):
1573    self.assertFalse(proto.IsInitialized())
1574    self.assertRaises(message.EncodeError, proto.SerializeToString)
1575    # "Partial" serialization doesn't care if message is uninitialized.
1576    proto.SerializePartialToString()
1577
  def testIsInitialized(self):
    """Checks IsInitialized() for required fields, submessages and extensions."""
    # Trivial cases - all optional fields and extensions.
    proto = unittest_pb2.TestAllTypes()
    self.assertInitialized(proto)
    proto = unittest_pb2.TestAllExtensions()
    self.assertInitialized(proto)

    # The case of uninitialized required fields: the message only becomes
    # initialized once a, b and c are all set.
    proto = unittest_pb2.TestRequired()
    self.assertNotInitialized(proto)
    proto.a = proto.b = proto.c = 2
    self.assertInitialized(proto)

    # The case of an uninitialized submessage. Leaving optional_message unset
    # is fine. Setting only `a` makes the parent uninitialized until the
    # submessage's remaining required fields are set too.
    proto = unittest_pb2.TestRequiredForeign()
    self.assertInitialized(proto)
    proto.optional_message.a = 1
    self.assertNotInitialized(proto)
    proto.optional_message.b = 0
    proto.optional_message.c = 0
    self.assertInitialized(proto)

    # Uninitialized repeated submessage: a freshly added element with no
    # required fields set makes the parent uninitialized.
    message1 = proto.repeated_message.add()
    self.assertNotInitialized(proto)
    message1.a = message1.b = message1.c = 0
    self.assertInitialized(proto)

    # Uninitialized elements of a repeated extension (TestRequired.multi):
    # every element must be fully set before the container is initialized.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.multi
    message1 = proto.Extensions[extension].add()
    message2 = proto.Extensions[extension].add()
    self.assertNotInitialized(proto)
    message1.a = 1
    message1.b = 1
    message1.c = 1
    # message2 is still missing its required fields.
    self.assertNotInitialized(proto)
    message2.a = 2
    message2.b = 2
    message2.c = 2
    self.assertInitialized(proto)

    # Uninitialized nonrepeated message in an extension.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.single
    proto.Extensions[extension].a = 1
    self.assertNotInitialized(proto)
    proto.Extensions[extension].b = 2
    proto.Extensions[extension].c = 3
    self.assertInitialized(proto)

    # Passing an errors list: IsInitialized() fills it with the names of
    # the missing required fields.
    errors = []
    proto = unittest_pb2.TestRequired()
    self.assertFalse(proto.IsInitialized(errors))
    self.assertEqual(errors, ['a', 'b', 'c'])
1635
1636  @unittest.skipIf(
1637      api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
1638      'Errors are only available from the most recent C++ implementation.')
1639  def testFileDescriptorErrors(self):
1640    file_name = 'test_file_descriptor_errors.proto'
1641    package_name = 'test_file_descriptor_errors.proto'
1642    file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
1643    file_descriptor_proto.name = file_name
1644    file_descriptor_proto.package = package_name
1645    m1 = file_descriptor_proto.message_type.add()
1646    m1.name = 'msg1'
1647    # Compiles the proto into the C++ descriptor pool
1648    descriptor.FileDescriptor(
1649        file_name,
1650        package_name,
1651        serialized_pb=file_descriptor_proto.SerializeToString())
1652    # Add a FileDescriptorProto that has duplicate symbols
1653    another_file_name = 'another_test_file_descriptor_errors.proto'
1654    file_descriptor_proto.name = another_file_name
1655    m2 = file_descriptor_proto.message_type.add()
1656    m2.name = 'msg2'
1657    with self.assertRaises(TypeError) as cm:
1658      descriptor.FileDescriptor(
1659          another_file_name,
1660          package_name,
1661          serialized_pb=file_descriptor_proto.SerializeToString())
1662      self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
1663                      getattr(cm.expected, '__name__', cm.expected))
1664      self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
1665      # Error message will say something about this definition being a
1666      # duplicate, though we don't check the message exactly to avoid a
1667      # dependency on the C++ logging code.
1668      self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
1669
1670  def testStringUTF8Encoding(self):
1671    proto = unittest_pb2.TestAllTypes()
1672
1673    # Assignment of a unicode object to a field of type 'bytes' is not allowed.
1674    self.assertRaises(TypeError,
1675                      setattr, proto, 'optional_bytes', u'unicode object')
1676
1677    # Check that the default value is of python's 'unicode' type.
1678    self.assertEqual(type(proto.optional_string), six.text_type)
1679
1680    proto.optional_string = six.text_type('Testing')
1681    self.assertEqual(proto.optional_string, str('Testing'))
1682
1683    # Assign a value of type 'str' which can be encoded in UTF-8.
1684    proto.optional_string = str('Testing')
1685    self.assertEqual(proto.optional_string, six.text_type('Testing'))
1686
1687    # Try to assign a 'bytes' object which contains non-UTF-8.
1688    self.assertRaises(ValueError,
1689                      setattr, proto, 'optional_string', b'a\x80a')
1690    # No exception: Assign already encoded UTF-8 bytes to a string field.
1691    utf8_bytes = u'Тест'.encode('utf-8')
1692    proto.optional_string = utf8_bytes
1693    # No exception: Assign the a non-ascii unicode object.
1694    proto.optional_string = u'Тест'
1695    # No exception thrown (normal str assignment containing ASCII).
1696    proto.optional_string = 'abc'
1697
  def testStringUTF8Serialization(self):
    """Round-trips a non-ASCII string extension through MessageSet encoding.

    Checks the serialized size, the raw MessageSet item (type_id and UTF-8
    payload), decoding back to text, and the behavior when the payload is
    corrupted into invalid UTF-8.
    """
    proto = message_set_extensions_pb2.TestMessageSet()
    extension_message = message_set_extensions_pb2.TestMessageSetExtension2
    extension = extension_message.message_set_extension

    test_utf8 = u'Тест'
    test_utf8_bytes = test_utf8.encode('utf-8')

    # 'Test' in another language, using UTF-8 charset.
    proto.Extensions[extension].str = test_utf8

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    # Check byte size.
    self.assertEqual(proto.ByteSize(), len(serialized))

    # Re-parse the bytes with the raw MessageSet schema, so the individual
    # item (type_id plus opaque message bytes) can be inspected directly.
    raw = unittest_mset_pb2.RawMessageSet()
    bytes_read = raw.MergeFromString(serialized)
    self.assertEqual(len(serialized), bytes_read)

    message2 = message_set_extensions_pb2.TestMessageSetExtension2()

    self.assertEqual(1, len(raw.item))
    # Check that the type_id is the same as the tag ID in the .proto file.
    self.assertEqual(raw.item[0].type_id, 98418634)

    # Check the actual bytes on the wire: the item payload ends with the
    # UTF-8 encoding of the string, and parses back completely.
    self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes))
    bytes_read = message2.MergeFromString(raw.item[0].message)
    self.assertEqual(len(raw.item[0].message), bytes_read)

    # The decoded field is text (unicode) and equals the original value.
    self.assertEqual(type(message2.str), six.text_type)
    self.assertEqual(message2.str, test_utf8)

    # Corrupt the payload by replacing the UTF-8 bytes with 0xff bytes, which
    # are not valid UTF-8. Implementations may handle this differently:
    #   * MergeFromString() may raise UnicodeDecodeError (the pure Python
    #     API does this when a string field can't be decoded).
    #   * The C++ implementation may not be able to check the data during
    #     MergeFromString(), so no exception is raised.
    # The test accepts either a decode failure or a field exposed as 'bytes'.
    badbytes = raw.item[0].message.replace(
        test_utf8_bytes, len(test_utf8_bytes) * b'\xff')

    unicode_decode_failed = False
    try:
      message2.MergeFromString(badbytes)
    except UnicodeDecodeError:
      unicode_decode_failed = True
    string_field = message2.str
    self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
1751
1752  def testBytesInTextFormat(self):
1753    proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
1754    self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n',
1755                     six.text_type(proto))
1756
1757  def testEmptyNestedMessage(self):
1758    proto = unittest_pb2.TestAllTypes()
1759    proto.optional_nested_message.MergeFrom(
1760        unittest_pb2.TestAllTypes.NestedMessage())
1761    self.assertTrue(proto.HasField('optional_nested_message'))
1762
1763    proto = unittest_pb2.TestAllTypes()
1764    proto.optional_nested_message.CopyFrom(
1765        unittest_pb2.TestAllTypes.NestedMessage())
1766    self.assertTrue(proto.HasField('optional_nested_message'))
1767
1768    proto = unittest_pb2.TestAllTypes()
1769    bytes_read = proto.optional_nested_message.MergeFromString(b'')
1770    self.assertEqual(0, bytes_read)
1771    self.assertTrue(proto.HasField('optional_nested_message'))
1772
1773    proto = unittest_pb2.TestAllTypes()
1774    proto.optional_nested_message.ParseFromString(b'')
1775    self.assertTrue(proto.HasField('optional_nested_message'))
1776
1777    serialized = proto.SerializeToString()
1778    proto2 = unittest_pb2.TestAllTypes()
1779    self.assertEqual(
1780        len(serialized),
1781        proto2.MergeFromString(serialized))
1782    self.assertTrue(proto2.HasField('optional_nested_message'))
1783
1784  def testSetInParent(self):
1785    proto = unittest_pb2.TestAllTypes()
1786    self.assertFalse(proto.HasField('optionalgroup'))
1787    proto.optionalgroup.SetInParent()
1788    self.assertTrue(proto.HasField('optionalgroup'))
1789
1790  def testPackageInitializationImport(self):
1791    """Test that we can import nested messages from their __init__.py.
1792
1793    Such setup is not trivial since at the time of processing of __init__.py one
1794    can't refer to its submodules by name in code, so expressions like
1795    google.protobuf.internal.import_test_package.inner_pb2
1796    don't work. They do work in imports, so we have assign an alias at import
1797    and then use that alias in generated code.
1798    """
1799    # We import here since it's the import that used to fail, and we want
1800    # the failure to have the right context.
1801    # pylint: disable=g-import-not-at-top
1802    from google.protobuf.internal import import_test_package
1803    # pylint: enable=g-import-not-at-top
1804    msg = import_test_package.myproto.Outer()
1805    # Just check the default value.
1806    self.assertEqual(57, msg.inner.value)
1807
1808#  Since we had so many tests for protocol buffer equality, we broke these out
1809#  into separate TestCase classes.
1810
1811
class TestAllTypesEqualityTest(unittest.TestCase):
  """Equality and hashing of empty TestAllTypes messages."""

  def setUp(self):
    self.first_proto, self.second_proto = (
        unittest_pb2.TestAllTypes(), unittest_pb2.TestAllTypes())

  def testNotHashable(self):
    # Messages must refuse to be hashed.
    with self.assertRaises(TypeError):
      hash(self.first_proto)

  def testSelfEquality(self):
    proto = self.first_proto
    self.assertEqual(proto, proto)

  def testEmptyProtosEqual(self):
    self.assertEqual(self.first_proto, self.second_proto)
1826
1827
class FullProtosEqualityTest(unittest.TestCase):
  """Equality tests that start from two fully-populated TestAllTypes."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    for proto in (self.first_proto, self.second_proto):
      test_util.SetAllFields(proto)

  def testNotHashable(self):
    with self.assertRaises(TypeError):
      hash(self.first_proto)

  def testNoneNotEqual(self):
    # A comparison with None must be unequal in both operand orders.
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    other = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, other)
    self.assertNotEqual(other, self.second_proto)

  def testAllFieldsFilledEquality(self):
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    first, second = self.first_proto, self.second_proto
    # Changing a nonrepeated scalar field breaks equality...
    first.optional_int32 += 1
    self.assertNotEqual(first, second)
    # ...and so does clearing it.
    first.ClearField('optional_int32')
    self.assertNotEqual(first, second)

  def testNonRepeatedComposite(self):
    first, second = self.first_proto, self.second_proto
    # Changing a value inside a nonrepeated composite breaks equality, and
    # restoring the value brings equality back.
    first.optional_nested_message.bb += 1
    self.assertNotEqual(first, second)
    first.optional_nested_message.bb -= 1
    self.assertEqual(first, second)
    # Clearing a field inside the nested message also breaks equality.
    first.optional_nested_message.ClearField('bb')
    self.assertNotEqual(first, second)
    first.optional_nested_message.bb = second.optional_nested_message.bb
    self.assertEqual(first, second)
    # So does removing the nested message entirely.
    first.ClearField('optional_nested_message')
    self.assertNotEqual(first, second)

  def testRepeatedScalar(self):
    first, second = self.first_proto, self.second_proto
    # Appending to a repeated scalar field breaks equality, as does clearing.
    first.repeated_int32.append(5)
    self.assertNotEqual(first, second)
    first.ClearField('repeated_int32')
    self.assertNotEqual(first, second)

  def testRepeatedComposite(self):
    first, second = self.first_proto, self.second_proto
    # Modify, then restore, a value inside a repeated composite field.
    first.repeated_nested_message[0].bb += 1
    self.assertNotEqual(first, second)
    first.repeated_nested_message[0].bb -= 1
    self.assertEqual(first, second)
    # Adding an element to one side breaks equality until the other side
    # gets a matching element.
    first.repeated_nested_message.add()
    self.assertNotEqual(first, second)
    second.repeated_nested_message.add()
    self.assertEqual(first, second)

  def testNonRepeatedScalarHasBits(self):
    # An unset field must differ from one explicitly set to its default, so
    # the comparison has to consider "has" bits as well as values.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    first, second = self.first_proto, self.second_proto
    # A cleared submessage differs from a present one whose bb was cleared.
    first.ClearField('optional_nested_message')
    second.optional_nested_message.ClearField('bb')
    self.assertNotEqual(first, second)
    # Setting bb makes first's submessage present. Clearing bb again leaves
    # it present, so both sides now match.
    first.optional_nested_message.bb = 0
    first.optional_nested_message.ClearField('bb')
    self.assertEqual(first, second)
1912
1913
class ExtensionEqualityTest(unittest.TestCase):
  """Equality of messages that carry extensions."""

  def testExtensionEquality(self):
    int32_ext = unittest_pb2.optional_int32_extension
    lhs = unittest_pb2.TestAllExtensions()
    rhs = unittest_pb2.TestAllExtensions()
    self.assertEqual(lhs, rhs)
    test_util.SetAllExtensions(lhs)
    self.assertNotEqual(lhs, rhs)
    test_util.SetAllExtensions(rhs)
    self.assertEqual(lhs, rhs)

    # Values are compared: bumping one extension breaks equality until the
    # value is restored.
    lhs.Extensions[int32_ext] += 1
    self.assertNotEqual(lhs, rhs)
    lhs.Extensions[int32_ext] -= 1
    self.assertEqual(lhs, rhs)

    # "Has" bits are compared too: an unset extension differs from one
    # explicitly set to 0.
    lhs.ClearExtension(int32_ext)
    rhs.Extensions[int32_ext] = 0
    self.assertNotEqual(lhs, rhs)
    lhs.Extensions[int32_ext] = 0
    self.assertEqual(lhs, rhs)

    # Reading an unset extension (which yields 0) must not make the message
    # differ from one where that extension was never read.
    lhs = unittest_pb2.TestAllExtensions()
    rhs = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, lhs.Extensions[int32_ext])
    self.assertEqual(lhs, rhs)
1945
1946
class MutualRecursionEqualityTest(unittest.TestCase):
  """Equality of mutually-recursive message types."""

  def testEqualityWithMutualRecursion(self):
    lhs = unittest_pb2.TestMutualRecursionA()
    rhs = unittest_pb2.TestMutualRecursionA()
    self.assertEqual(lhs, rhs)
    # A difference several levels deep is detected...
    lhs.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(lhs, rhs)
    # ...and goes away once the other side matches.
    rhs.bb.a.bb.optional_int32 = 23
    self.assertEqual(lhs, rhs)
1957
1958
1959class ByteSizeTest(unittest.TestCase):
1960
1961  def setUp(self):
1962    self.proto = unittest_pb2.TestAllTypes()
1963    self.extended_proto = more_extensions_pb2.ExtendedMessage()
1964    self.packed_proto = unittest_pb2.TestPackedTypes()
1965    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()
1966
1967  def Size(self):
1968    return self.proto.ByteSize()
1969
1970  def testEmptyMessage(self):
1971    self.assertEqual(0, self.proto.ByteSize())
1972
1973  def testSizedOnKwargs(self):
1974    # Use a separate message to ensure testing right after creation.
1975    proto = unittest_pb2.TestAllTypes()
1976    self.assertEqual(0, proto.ByteSize())
1977    proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1)
1978    # One byte for the tag, one to encode varint 1.
1979    self.assertEqual(2, proto_kwargs.ByteSize())
1980
1981  def testVarints(self):
1982    def Test(i, expected_varint_size):
1983      self.proto.Clear()
1984      self.proto.optional_int64 = i
1985      # Add one to the varint size for the tag info
1986      # for tag 1.
1987      self.assertEqual(expected_varint_size + 1, self.Size())
1988    Test(0, 1)
1989    Test(1, 1)
1990    for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)):
1991      Test((1 << i) - 1, num_bytes)
1992    Test(-1, 10)
1993    Test(-2, 10)
1994    Test(-(1 << 63), 10)
1995
1996  def testStrings(self):
1997    self.proto.optional_string = ''
1998    # Need one byte for tag info (tag #14), and one byte for length.
1999    self.assertEqual(2, self.Size())
2000
2001    self.proto.optional_string = 'abc'
2002    # Need one byte for tag info (tag #14), and one byte for length.
2003    self.assertEqual(2 + len(self.proto.optional_string), self.Size())
2004
2005    self.proto.optional_string = 'x' * 128
2006    # Need one byte for tag info (tag #14), and TWO bytes for length.
2007    self.assertEqual(3 + len(self.proto.optional_string), self.Size())
2008
2009  def testOtherNumerics(self):
2010    self.proto.optional_fixed32 = 1234
2011    # One byte for tag and 4 bytes for fixed32.
2012    self.assertEqual(5, self.Size())
2013    self.proto = unittest_pb2.TestAllTypes()
2014
2015    self.proto.optional_fixed64 = 1234
2016    # One byte for tag and 8 bytes for fixed64.
2017    self.assertEqual(9, self.Size())
2018    self.proto = unittest_pb2.TestAllTypes()
2019
2020    self.proto.optional_float = 1.234
2021    # One byte for tag and 4 bytes for float.
2022    self.assertEqual(5, self.Size())
2023    self.proto = unittest_pb2.TestAllTypes()
2024
2025    self.proto.optional_double = 1.234
2026    # One byte for tag and 8 bytes for float.
2027    self.assertEqual(9, self.Size())
2028    self.proto = unittest_pb2.TestAllTypes()
2029
2030    self.proto.optional_sint32 = 64
2031    # One byte for tag and 2 bytes for zig-zag-encoded 64.
2032    self.assertEqual(3, self.Size())
2033    self.proto = unittest_pb2.TestAllTypes()
2034
2035  def testComposites(self):
2036    # 3 bytes.
2037    self.proto.optional_nested_message.bb = (1 << 14)
2038    # Plus one byte for bb tag.
2039    # Plus 1 byte for optional_nested_message serialized size.
2040    # Plus two bytes for optional_nested_message tag.
2041    self.assertEqual(3 + 1 + 1 + 2, self.Size())
2042
2043  def testGroups(self):
2044    # 4 bytes.
2045    self.proto.optionalgroup.a = (1 << 21)
2046    # Plus two bytes for |a| tag.
2047    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
2048    self.assertEqual(4 + 2 + 2*2, self.Size())
2049
2050  def testRepeatedScalars(self):
2051    self.proto.repeated_int32.append(10)  # 1 byte.
2052    self.proto.repeated_int32.append(128)  # 2 bytes.
2053    # Also need 2 bytes for each entry for tag.
2054    self.assertEqual(1 + 2 + 2*2, self.Size())
2055
2056  def testRepeatedScalarsExtend(self):
2057    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
2058    # Also need 2 bytes for each entry for tag.
2059    self.assertEqual(1 + 2 + 2*2, self.Size())
2060
2061  def testRepeatedScalarsRemove(self):
2062    self.proto.repeated_int32.append(10)  # 1 byte.
2063    self.proto.repeated_int32.append(128)  # 2 bytes.
2064    # Also need 2 bytes for each entry for tag.
2065    self.assertEqual(1 + 2 + 2*2, self.Size())
2066    self.proto.repeated_int32.remove(128)
2067    self.assertEqual(1 + 2, self.Size())
2068
2069  def testRepeatedComposites(self):
2070    # Empty message.  2 bytes tag plus 1 byte length.
2071    foreign_message_0 = self.proto.repeated_nested_message.add()
2072    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2073    foreign_message_1 = self.proto.repeated_nested_message.add()
2074    foreign_message_1.bb = 7
2075    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
2076
2077  def testRepeatedCompositesDelete(self):
2078    # Empty message.  2 bytes tag plus 1 byte length.
2079    foreign_message_0 = self.proto.repeated_nested_message.add()
2080    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2081    foreign_message_1 = self.proto.repeated_nested_message.add()
2082    foreign_message_1.bb = 9
2083    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
2084
2085    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2086    del self.proto.repeated_nested_message[0]
2087    self.assertEqual(2 + 1 + 1 + 1, self.Size())
2088
2089    # Now add a new message.
2090    foreign_message_2 = self.proto.repeated_nested_message.add()
2091    foreign_message_2.bb = 12
2092
2093    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2094    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2095    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())
2096
2097    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2098    del self.proto.repeated_nested_message[1]
2099    self.assertEqual(2 + 1 + 1 + 1, self.Size())
2100
2101    del self.proto.repeated_nested_message[0]
2102    self.assertEqual(0, self.Size())
2103
2104  def testRepeatedGroups(self):
2105    # 2-byte START_GROUP plus 2-byte END_GROUP.
2106    group_0 = self.proto.repeatedgroup.add()
2107    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
2108    # plus 2-byte END_GROUP.
2109    group_1 = self.proto.repeatedgroup.add()
2110    group_1.a =  7
2111    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())
2112
2113  def testExtensions(self):
2114    proto = unittest_pb2.TestAllExtensions()
2115    self.assertEqual(0, proto.ByteSize())
2116    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
2117    proto.Extensions[extension] = 23
2118    # 1 byte for tag, 1 byte for value.
2119    self.assertEqual(2, proto.ByteSize())
2120
2121  def testCacheInvalidationForNonrepeatedScalar(self):
2122    # Test non-extension.
2123    self.proto.optional_int32 = 1
2124    self.assertEqual(2, self.proto.ByteSize())
2125    self.proto.optional_int32 = 128
2126    self.assertEqual(3, self.proto.ByteSize())
2127    self.proto.ClearField('optional_int32')
2128    self.assertEqual(0, self.proto.ByteSize())
2129
2130    # Test within extension.
2131    extension = more_extensions_pb2.optional_int_extension
2132    self.extended_proto.Extensions[extension] = 1
2133    self.assertEqual(2, self.extended_proto.ByteSize())
2134    self.extended_proto.Extensions[extension] = 128
2135    self.assertEqual(3, self.extended_proto.ByteSize())
2136    self.extended_proto.ClearExtension(extension)
2137    self.assertEqual(0, self.extended_proto.ByteSize())
2138
2139  def testCacheInvalidationForRepeatedScalar(self):
2140    # Test non-extension.
2141    self.proto.repeated_int32.append(1)
2142    self.assertEqual(3, self.proto.ByteSize())
2143    self.proto.repeated_int32.append(1)
2144    self.assertEqual(6, self.proto.ByteSize())
2145    self.proto.repeated_int32[1] = 128
2146    self.assertEqual(7, self.proto.ByteSize())
2147    self.proto.ClearField('repeated_int32')
2148    self.assertEqual(0, self.proto.ByteSize())
2149
2150    # Test within extension.
2151    extension = more_extensions_pb2.repeated_int_extension
2152    repeated = self.extended_proto.Extensions[extension]
2153    repeated.append(1)
2154    self.assertEqual(2, self.extended_proto.ByteSize())
2155    repeated.append(1)
2156    self.assertEqual(4, self.extended_proto.ByteSize())
2157    repeated[1] = 128
2158    self.assertEqual(5, self.extended_proto.ByteSize())
2159    self.extended_proto.ClearExtension(extension)
2160    self.assertEqual(0, self.extended_proto.ByteSize())
2161
  def testCacheInvalidationForNonrepeatedMessage(self):
    """ByteSize() follows edits made inside a singular submessage."""
    # Test non-extension.
    # Writing a field of the child must invalidate the parent's cached size.
    self.proto.optional_foreign_message.c = 1
    self.assertEqual(5, self.proto.ByteSize())
    # 128 needs one more varint byte than 1, so the parent grows by one.
    self.proto.optional_foreign_message.c = 128
    self.assertEqual(6, self.proto.ByteSize())
    # Clearing the child's only field leaves the (now empty) submessage
    # present in the parent, hence a non-zero size.
    self.proto.optional_foreign_message.ClearField('c')
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    if api_implementation.Type() == 'python':
      # This is only possible in pure-Python implementation of the API.
      # Keep a reference to the child, then clear it from the parent: later
      # writes through the stale reference must not change the parent's size.
      child = self.proto.optional_foreign_message
      self.proto.ClearField('optional_foreign_message')
      child.c = 128
      self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_message_extension
    child = self.extended_proto.Extensions[extension]
    # Merely reading the extension must not make it count towards the size.
    self.assertEqual(0, self.extended_proto.ByteSize())
    child.foreign_message_int = 1
    self.assertEqual(4, self.extended_proto.ByteSize())
    child.foreign_message_int = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())
2190
2191  def testCacheInvalidationForRepeatedMessage(self):
2192    # Test non-extension.
2193    child0 = self.proto.repeated_foreign_message.add()
2194    self.assertEqual(3, self.proto.ByteSize())
2195    self.proto.repeated_foreign_message.add()
2196    self.assertEqual(6, self.proto.ByteSize())
2197    child0.c = 1
2198    self.assertEqual(8, self.proto.ByteSize())
2199    self.proto.ClearField('repeated_foreign_message')
2200    self.assertEqual(0, self.proto.ByteSize())
2201
2202    # Test within extension.
2203    extension = more_extensions_pb2.repeated_message_extension
2204    child_list = self.extended_proto.Extensions[extension]
2205    child0 = child_list.add()
2206    self.assertEqual(2, self.extended_proto.ByteSize())
2207    child_list.add()
2208    self.assertEqual(4, self.extended_proto.ByteSize())
2209    child0.foreign_message_int = 1
2210    self.assertEqual(6, self.extended_proto.ByteSize())
2211    child0.ClearField('foreign_message_int')
2212    self.assertEqual(4, self.extended_proto.ByteSize())
2213    self.extended_proto.ClearExtension(extension)
2214    self.assertEqual(0, self.extended_proto.ByteSize())
2215
2216  def testPackedRepeatedScalars(self):
2217    self.assertEqual(0, self.packed_proto.ByteSize())
2218
2219    self.packed_proto.packed_int32.append(10)   # 1 byte.
2220    self.packed_proto.packed_int32.append(128)  # 2 bytes.
2221    # The tag is 2 bytes (the field number is 90), and the varint
2222    # storing the length is 1 byte.
2223    int_size = 1 + 2 + 3
2224    self.assertEqual(int_size, self.packed_proto.ByteSize())
2225
2226    self.packed_proto.packed_double.append(4.2)   # 8 bytes
2227    self.packed_proto.packed_double.append(3.25)  # 8 bytes
2228    # 2 more tag bytes, 1 more length byte.
2229    double_size = 8 + 8 + 3
2230    self.assertEqual(int_size+double_size, self.packed_proto.ByteSize())
2231
2232    self.packed_proto.ClearField('packed_int32')
2233    self.assertEqual(double_size, self.packed_proto.ByteSize())
2234
2235  def testPackedExtensions(self):
2236    self.assertEqual(0, self.packed_extended_proto.ByteSize())
2237    extension = self.packed_extended_proto.Extensions[
2238        unittest_pb2.packed_fixed32_extension]
2239    extension.extend([1, 2, 3, 4])   # 16 bytes
2240    # Tag is 3 bytes.
2241    self.assertEqual(19, self.packed_extended_proto.ByteSize())
2242
2243
2244# Issues to be sure to cover include:
2245#   * Handling of unrecognized tags ("uninterpreted_bytes").
2246#   * Handling of MessageSets.
2247#   * Consistent ordering of tags in the wire format,
2248#     including ordering between extensions and non-extension
2249#     fields.
2250#   * Consistent serialization of negative numbers, especially
2251#     negative int32s.
2252#   * Handling of empty submessages (with and without "has"
2253#     bits set).
2254
class SerializationTest(unittest.TestCase):
  """Serialization round-trips, wire-format layout and kwarg construction."""

  def testSerializeEmtpyMessage(self):
    """An empty message serializes to ByteSize() bytes and round-trips."""
    # The "Emtpy" typo in the method name is kept so that existing test
    # selectors referring to this test by name keep working.
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllFields(self):
    """A message with every field set round-trips through its bytes."""
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllExtensions(self):
    """A message with every extension set round-trips through its bytes."""
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    # ByteSize() must agree with the real serialized length, exactly as
    # checked for regular fields in testSerializeAllFields.
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeWithOptionalGroup(self):
    """A message whose only set field is a group round-trips."""
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    first_proto.optionalgroup.a = 242
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeNegativeValues(self):
    """Negative values of every signed integer encoding round-trip."""
    first_proto = unittest_pb2.TestAllTypes()

    first_proto.optional_int32 = -1
    first_proto.optional_int64 = -(2 << 40)
    first_proto.optional_sint32 = -3
    first_proto.optional_sint64 = -(4 << 40)
    first_proto.optional_sfixed32 = -5
    first_proto.optional_sfixed64 = -(6 << 40)

    second_proto = unittest_pb2.TestAllTypes.FromString(
        first_proto.SerializeToString())

    self.assertEqual(first_proto, second_proto)

  @unittest.skipIf(api_implementation.Type() != 'python',
                   'Only applicable for the Python implementation of the API.')
  def testParseTruncated(self):
    """Known-field and unknown-field parsing agree at every truncation point.

    Relies on the pure-Python implementation's private
    _InternalParse(buffer, pos, end) entry point, hence the skip elsewhere.
    """
    first_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()

    for truncation_point in range(len(serialized) + 1):
      try:
        second_proto = unittest_pb2.TestAllTypes()
        unknown_fields = unittest_pb2.TestEmptyMessage()
        pos = second_proto._InternalParse(serialized, 0, truncation_point)
        # If we didn't raise an error then we read exactly the amount expected.
        self.assertEqual(truncation_point, pos)

        # Parsing to unknown fields should not throw if parsing to known fields
        # did not.
        try:
          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
          self.assertEqual(truncation_point, pos2)
        except message.DecodeError:
          self.fail('Parsing unknown fields failed when parsing known fields '
                    'did not.')
      except message.DecodeError:
        # Parsing unknown fields should also fail.
        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
                          serialized, 0, truncation_point)

  def testCanonicalSerializationOrder(self):
    """Fields and extensions are emitted in field-number order."""
    proto = more_messages_pb2.OutOfOrderFields()
    # These are also their tag numbers.  Even though we're setting these in
    # reverse-tag order AND they're listed in reverse tag-order in the .proto
    # file, they should nonetheless be serialized in tag order.
    proto.optional_sint32 = 5
    proto.Extensions[more_messages_pb2.optional_uint64] = 4
    proto.optional_uint32 = 3
    proto.Extensions[more_messages_pb2.optional_int64] = 2
    proto.optional_int32 = 1
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(2, d.ReadInt64())
    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(3, d.ReadUInt32())
    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(4, d.ReadUInt64())
    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(5, d.ReadSInt32())
    # Nothing beyond the five fields may have been written.
    self.assertTrue(d.EndOfStream())

  def testCanonicalSerializationOrderSameAsCpp(self):
    """Field ordering matches the expectations shared with the C++ test."""
    # Copy of the same test we use for C++.
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    serialized = proto.SerializeToString()
    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)

  def testMergeFromStringWhenFieldsAlreadySet(self):
    """MergeFromString appends, overwrites and recurses as appropriate."""
    first_proto = unittest_pb2.TestAllTypes()
    first_proto.repeated_string.append('foobar')
    first_proto.optional_int32 = 23
    first_proto.optional_nested_message.bb = 42
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestAllTypes()
    second_proto.repeated_string.append('baz')
    second_proto.optional_int32 = 100
    second_proto.optional_nested_message.bb = 999

    bytes_parsed = second_proto.MergeFromString(serialized)
    self.assertEqual(len(serialized), bytes_parsed)

    # Ensure that we append to repeated fields.
    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
    # Ensure that we overwrite nonrepeated scalars.
    self.assertEqual(23, second_proto.optional_int32)
    # Ensure that we recursively call MergeFromString() on
    # submessages.
    self.assertEqual(42, second_proto.optional_nested_message.bb)

  def testMessageSetWireFormat(self):
    """MessageSet extensions are written as (type_id, message) items."""
    proto = message_set_extensions_pb2.TestMessageSet()
    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
    extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
    extension1 = extension_message1.message_set_extension
    extension2 = extension_message2.message_set_extension
    extension3 = message_set_extensions_pb2.message_set_extension3
    proto.Extensions[extension1].i = 123
    proto.Extensions[extension2].str = 'foo'
    proto.Extensions[extension3].text = 'bar'

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    raw = unittest_mset_pb2.RawMessageSet()
    self.assertEqual(False,
                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
    self.assertEqual(
        len(serialized),
        raw.MergeFromString(serialized))
    self.assertEqual(3, len(raw.item))

    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    self.assertEqual(
        len(raw.item[0].message),
        message1.MergeFromString(raw.item[0].message))
    self.assertEqual(123, message1.i)

    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
    self.assertEqual(
        len(raw.item[1].message),
        message2.MergeFromString(raw.item[1].message))
    self.assertEqual('foo', message2.str)

    message3 = message_set_extensions_pb2.TestMessageSetExtension3()
    self.assertEqual(
        len(raw.item[2].message),
        message3.MergeFromString(raw.item[2].message))
    self.assertEqual('bar', message3.text)

    # Deserialize using the MessageSet wire format.
    proto2 = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(123, proto2.Extensions[extension1].i)
    self.assertEqual('foo', proto2.Extensions[extension2].str)
    self.assertEqual('bar', proto2.Extensions[extension3].text)

    # Check byte size.
    self.assertEqual(proto2.ByteSize(), len(serialized))
    self.assertEqual(proto.ByteSize(), len(serialized))

  def testMessageSetWireFormatUnknownExtension(self):
    """Unknown MessageSet items do not prevent parsing the known one."""
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an item whose type_id belongs to TestMessageSetExtension1 (checked
    # by the final assertion below).
    item = raw.item.add()
    item.type_id = 98418603
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    # Add a second, unknown extension.
    item = raw.item.add()
    item.type_id = 98418604
    unknown1 = message_set_extensions_pb2.TestMessageSetExtension1()
    unknown1.i = 12346
    item.message = unknown1.SerializeToString()

    # Add another unknown extension.
    item = raw.item.add()
    item.type_id = 98418605
    unknown2 = message_set_extensions_pb2.TestMessageSetExtension2()
    unknown2.str = 'foo'
    item.message = unknown2.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(
        len(serialized),
        proto.MergeFromString(serialized))

    # Check that the message parsed well.
    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
    extension1 = extension_message1.message_set_extension
    self.assertEqual(12345, proto.Extensions[extension1].i)

  def testUnknownFields(self):
    """A message with no fields parses data written by a populated one."""
    proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(proto)

    serialized = proto.SerializeToString()

    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()

    # Parsing this message should succeed.
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))

    # Now test with a int64 field set.
    proto = unittest_pb2.TestAllTypes()
    proto.optional_int64 = 0x0fffffffffffffff
    serialized = proto.SerializeToString()
    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()
    # Parsing this message should succeed.
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))

  def _CheckRaises(self, exc_class, callable_obj, exception):
    """Check that callable_obj() raises exc_class with message `exception`.

    Args:
      exc_class: the exception type that must be raised.
      callable_obj: zero-argument callable to invoke.
      exception: the exact expected str() of the raised exception.

    Raises:
      self.failureException: if callable_obj() returns without raising, or
        (via assertEqual) if the message differs.
    """
    try:
      callable_obj()
    except exc_class as ex:
      # Check if the exception message is the right one.
      self.assertEqual(exception, str(ex))
      return
    else:
      raise self.failureException('%s not raised' % str(exc_class))

  def testSerializeUninitialized(self):
    """SerializeToString enforces required fields; the Partial form does not."""
    proto = unittest_pb2.TestRequired()
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: '
        'a,b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertFalse(proto2.HasField('a'))
    # proto2 ParseFromString does not check that required fields are set.
    proto2.ParseFromString(partial)
    self.assertFalse(proto2.HasField('a'))

    proto.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.c = 3
    serialized = proto.SerializeToString()
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)
    self.assertEqual(
        len(partial),
        proto2.MergeFromString(partial))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)

  def testSerializeUninitializedSubMessage(self):
    """Missing required fields are reported with their submessage path."""
    proto = unittest_pb2.TestRequiredForeign()

    # Sub-message doesn't exist yet, so this succeeds.
    proto.SerializeToString()

    proto.optional_message.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign '
        'is missing required fields: '
        'optional_message.b,optional_message.c')

    proto.optional_message.b = 2
    proto.optional_message.c = 3
    proto.SerializeToString()

    proto.repeated_message.add().a = 1
    proto.repeated_message.add().b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
        'repeated_message[0].b,repeated_message[0].c,'
        'repeated_message[1].a,repeated_message[1].c')

    proto.repeated_message[0].b = 2
    proto.repeated_message[0].c = 3
    proto.repeated_message[1].a = 1
    proto.repeated_message[1].c = 3
    proto.SerializeToString()

  def testSerializeAllPackedFields(self):
    """A message with every packed field set round-trips."""
    first_proto = unittest_pb2.TestPackedTypes()
    second_proto = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllPackedExtensions(self):
    """A message with every packed extension set round-trips."""
    first_proto = unittest_pb2.TestPackedExtensions()
    second_proto = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    # Same ByteSize()/length agreement as testSerializeAllPackedFields.
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
    """Merging packed data appends to existing packed values."""
    first_proto = unittest_pb2.TestPackedTypes()
    first_proto.packed_int32.extend([1, 2])
    first_proto.packed_double.append(3.0)
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestPackedTypes()
    second_proto.packed_int32.append(3)
    second_proto.packed_double.extend([1.0, 2.0])
    second_proto.packed_sint32.append(4)

    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual([3, 1, 2], second_proto.packed_int32)
    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
    self.assertEqual([4], second_proto.packed_sint32)

  def testPackedFieldsWireFormat(self):
    """Packed fields are written as one length-delimited record per field."""
    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
    proto.packed_float.append(2.0)             # 4 bytes, will be before double
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(1+1+1+2, d.ReadInt32())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual(2, d.ReadInt32())
    self.assertEqual(150, d.ReadInt32())
    self.assertEqual(3, d.ReadInt32())
    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(4, d.ReadInt32())
    self.assertEqual(2.0, d.ReadFloat())
    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(8+8, d.ReadInt32())
    self.assertEqual(1.0, d.ReadDouble())
    self.assertEqual(1000.0, d.ReadDouble())
    self.assertTrue(d.EndOfStream())

  def testParsePackedFromUnpacked(self):
    """Packed fields accept data serialized in unpacked form."""
    unpacked = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(unpacked)
    packed = unittest_pb2.TestPackedTypes()
    serialized = unpacked.SerializeToString()
    self.assertEqual(
        len(serialized),
        packed.MergeFromString(serialized))
    expected = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(expected)
    self.assertEqual(expected, packed)

  def testParseUnpackedFromPacked(self):
    """Unpacked fields accept data serialized in packed form."""
    packed = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(packed)
    unpacked = unittest_pb2.TestUnpackedTypes()
    serialized = packed.SerializeToString()
    self.assertEqual(
        len(serialized),
        unpacked.MergeFromString(serialized))
    expected = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(expected)
    self.assertEqual(expected, unpacked)

  def testFieldNumbers(self):
    """Generated *_FIELD_NUMBER constants match the .proto field numbers."""
    self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
    self.assertEqual(
      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
    self.assertEqual(
      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
    self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
    self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
    self.assertEqual(
      unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
    self.assertEqual(
      unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)

  def testExtensionFieldNumbers(self):
    """Extension .number attributes and *_FIELD_NUMBER constants agree."""
    self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
    self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
    self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
    self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
    self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
    self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
    self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
    self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
    self.assertEqual(
      unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
    self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
    self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
      21)
    self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
    self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
    self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
    self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
    self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
    self.assertEqual(
      unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
    self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
    self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
      51)

  def testInitKwargs(self):
    """Constructor keyword arguments set scalar, message and repeated fields."""
    proto = unittest_pb2.TestAllTypes(
        optional_int32=1,
        optional_string='foo',
        optional_bool=True,
        optional_bytes=b'bar',
        optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
        optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
        optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
        optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
        repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('optional_int32'))
    self.assertTrue(proto.HasField('optional_string'))
    self.assertTrue(proto.HasField('optional_bool'))
    self.assertTrue(proto.HasField('optional_bytes'))
    self.assertTrue(proto.HasField('optional_nested_message'))
    self.assertTrue(proto.HasField('optional_foreign_message'))
    self.assertTrue(proto.HasField('optional_nested_enum'))
    self.assertTrue(proto.HasField('optional_foreign_enum'))
    self.assertEqual(1, proto.optional_int32)
    self.assertEqual('foo', proto.optional_string)
    self.assertEqual(True, proto.optional_bool)
    self.assertEqual(b'bar', proto.optional_bytes)
    self.assertEqual(1, proto.optional_nested_message.bb)
    self.assertEqual(1, proto.optional_foreign_message.c)
    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
                     proto.optional_nested_enum)
    self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
    self.assertEqual([1, 2, 3], proto.repeated_int32)

  def testInitArgsUnknownFieldName(self):
    """An unknown constructor keyword raises ValueError naming the field."""
    def InitializeEmptyMessageWithExtraKeywordArg():
      unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
    self._CheckRaises(
        ValueError,
        InitializeEmptyMessageWithExtraKeywordArg,
        'Protocol message TestEmptyMessage has no "unknown" field.')

  def testInitRequiredKwargs(self):
    """Setting all required fields via kwargs yields an initialized message."""
    proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('a'))
    self.assertTrue(proto.HasField('b'))
    self.assertTrue(proto.HasField('c'))
    self.assertFalse(proto.HasField('dummy2'))
    self.assertEqual(1, proto.a)
    self.assertEqual(1, proto.b)
    self.assertEqual(1, proto.c)

  def testInitRequiredForeignKwargs(self):
    """A fully initialized submessage passed as a kwarg is copied in."""
    proto = unittest_pb2.TestRequiredForeign(
        optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('optional_message'))
    self.assertTrue(proto.optional_message.IsInitialized())
    self.assertTrue(proto.optional_message.HasField('a'))
    self.assertTrue(proto.optional_message.HasField('b'))
    self.assertTrue(proto.optional_message.HasField('c'))
    self.assertFalse(proto.optional_message.HasField('dummy2'))
    self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
                     proto.optional_message)
    self.assertEqual(1, proto.optional_message.a)
    self.assertEqual(1, proto.optional_message.b)
    self.assertEqual(1, proto.optional_message.c)

  def testInitRepeatedKwargs(self):
    """A list passed for a repeated field populates it in order."""
    proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())
    self.assertEqual(1, proto.repeated_int32[0])
    self.assertEqual(2, proto.repeated_int32[1])
    self.assertEqual(3, proto.repeated_int32[2])
2814
2815
class OptionsTest(unittest.TestCase):
  """Tests for message- and field-level descriptor options."""

  def testMessageOptions(self):
    """message_set_wire_format is set only on the MessageSet test message."""
    proto = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(True,
                     proto.DESCRIPTOR.GetOptions().message_set_wire_format)
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(False,
                     proto.DESCRIPTOR.GetOptions().message_set_wire_format)

  def testPackedOptions(self):
    """Only TestPackedTypes' repeated fields report the packed option."""
    proto = unittest_pb2.TestAllTypes()
    proto.optional_int32 = 1
    proto.optional_double = 3.0
    listed = proto.ListFields()
    # Guard against the loop below passing vacuously on an empty listing.
    self.assertEqual(2, len(listed))
    for field_descriptor, _ in listed:
      self.assertEqual(False, field_descriptor.GetOptions().packed)

    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.append(1)
    proto.packed_double.append(3.0)
    listed = proto.ListFields()
    self.assertEqual(2, len(listed))
    for field_descriptor, _ in listed:
      self.assertEqual(True, field_descriptor.GetOptions().packed)
      self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED,
                       field_descriptor.label)
2840
2841
2842
class ClassAPITest(unittest.TestCase):
  """Tests for building message classes at runtime from descriptors."""

  @unittest.skipIf(
      api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
      'C++ implementation requires a call to MakeDescriptor()')
  def testMakeClassWithNestedDescriptor(self):
    """MakeClass() exposes nested descriptors as nested class attributes."""
    leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '',
                                      containing_type=None, fields=[],
                                      nested_types=[], enum_types=[],
                                      extensions=[])
    child_desc = descriptor.Descriptor('child', 'package.parent.child', '',
                                       containing_type=None, fields=[],
                                       nested_types=[leaf_desc], enum_types=[],
                                       extensions=[])
    sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling',
                                         '', containing_type=None, fields=[],
                                         nested_types=[], enum_types=[],
                                         extensions=[])
    parent_desc = descriptor.Descriptor('parent', 'package.parent', '',
                                        containing_type=None, fields=[],
                                        nested_types=[child_desc, sibling_desc],
                                        enum_types=[], extensions=[])
    message_class = reflection.MakeClass(parent_desc)
    self.assertIn('child', message_class.__dict__)
    self.assertIn('sibling', message_class.__dict__)
    self.assertIn('leaf', message_class.child.__dict__)

  def _GetSerializedFileDescriptor(self, name):
    """Get a serialized representation of a test FileDescriptorProto.

    The file holds one message named `name` with a repeated uint32 field
    `flat`, and a nested message `Bar` that in turn nests `Baz` (which has a
    uint32 field `deep` and an enum `deep_enum`).

    Args:
      name: All calls to this must use a unique message name, to avoid
          collisions in the cpp descriptor pool.
    Returns:
      The serialized FileDescriptorProto, as produced by SerializeToString()
      (bytes under Python 3).
    """
    file_descriptor_str = (
        'message_type {'
        '  name: "' + name + '"'
        '  field {'
        '    name: "flat"'
        '    number: 1'
        '    label: LABEL_REPEATED'
        '    type: TYPE_UINT32'
        '  }'
        '  field {'
        '    name: "bar"'
        '    number: 2'
        '    label: LABEL_OPTIONAL'
        '    type: TYPE_MESSAGE'
        '    type_name: "Bar"'
        '  }'
        '  nested_type {'
        '    name: "Bar"'
        '    field {'
        '      name: "baz"'
        '      number: 3'
        '      label: LABEL_OPTIONAL'
        '      type: TYPE_MESSAGE'
        '      type_name: "Baz"'
        '    }'
        '    nested_type {'
        '      name: "Baz"'
        '      enum_type {'
        '        name: "deep_enum"'
        '        value {'
        '          name: "VALUE_A"'
        '          number: 0'
        '        }'
        '      }'
        '      field {'
        '        name: "deep"'
        '        number: 4'
        '        label: LABEL_OPTIONAL'
        '        type: TYPE_UINT32'
        '      }'
        '    }'
        '  }'
        '}')
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    text_format.Merge(file_descriptor_str, file_descriptor)
    return file_descriptor.SerializeToString()

  def _GetMessageDescriptor(self, name):
    """Build a Descriptor for the test message named `name`.

    Parses the output of _GetSerializedFileDescriptor() back into a
    FileDescriptorProto and passes its first (only) message type to
    descriptor.MakeDescriptor().

    Args:
      name: Unique message name; see _GetSerializedFileDescriptor().
    Returns:
      The Descriptor returned by descriptor.MakeDescriptor().
    """
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor(name))
    return descriptor.MakeDescriptor(file_descriptor.message_type[0])

  # TODO(xiaofeng): This test fails with the cpp implementation in the call
  # of six.with_metaclass(). The other two callsites of with_metaclass
  # in this file are both excluded from cpp test, so it might be expected
  # to fail. Need someone more familiar with the python code to take a
  # look at this.
  # Skipping (rather than returning early) makes non-python runs report the
  # test as skipped instead of silently counting it as passed.
  @unittest.skipIf(
      api_implementation.Type() != 'python',
      'six.with_metaclass() class declaration requires the python '
      'implementation')
  def testParsingFlatClassWithExplicitClassDeclaration(self):
    """Test that the generated class can parse a flat message."""
    msg_descriptor = self._GetMessageDescriptor('A')

    class MessageClass(six.with_metaclass(
        reflection.GeneratedProtocolMessageType, message.Message)):
      DESCRIPTOR = msg_descriptor
    msg = MessageClass()
    msg_str = (
        'flat: 0 '
        'flat: 1 '
        'flat: 2 ')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.flat, [0, 1, 2])

  def testParsingFlatClass(self):
    """Test that the generated class can parse a flat message."""
    msg_descriptor = self._GetMessageDescriptor('B')
    msg_class = reflection.MakeClass(msg_descriptor)
    msg = msg_class()
    msg_str = (
        'flat: 0 '
        'flat: 1 '
        'flat: 2 ')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.flat, [0, 1, 2])

  def testParsingNestedClass(self):
    """Test that the generated class can parse a nested message."""
    msg_descriptor = self._GetMessageDescriptor('C')
    msg_class = reflection.MakeClass(msg_descriptor)
    msg = msg_class()
    msg_str = (
        'bar {'
        '  baz {'
        '    deep: 4'
        '  }'
        '}')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.bar.baz.deep, 4)
2981
# Run every TestCase defined in this module when executed as a script.
if __name__ == '__main__':
  unittest.main()
2984