• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#! /usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4# Protocol Buffers - Google's data interchange format
5# Copyright 2008 Google Inc.  All rights reserved.
6# https://developers.google.com/protocol-buffers/
7#
8# Redistribution and use in source and binary forms, with or without
9# modification, are permitted provided that the following conditions are
10# met:
11#
12#     * Redistributions of source code must retain the above copyright
13# notice, this list of conditions and the following disclaimer.
14#     * Redistributions in binary form must reproduce the above
15# copyright notice, this list of conditions and the following disclaimer
16# in the documentation and/or other materials provided with the
17# distribution.
18#     * Neither the name of Google Inc. nor the names of its
19# contributors may be used to endorse or promote products derived from
20# this software without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34"""Unittest for reflection.py, which also indirectly tests the output of the
35pure-Python protocol compiler.
36"""
37
38import copy
39import gc
40import operator
41import six
42import struct
43
44try:
45  import unittest2 as unittest  #PY26
46except ImportError:
47  import unittest
48
49from google.protobuf import unittest_import_pb2
50from google.protobuf import unittest_mset_pb2
51from google.protobuf import unittest_pb2
52from google.protobuf import descriptor_pb2
53from google.protobuf import descriptor
54from google.protobuf import message
55from google.protobuf import reflection
56from google.protobuf import text_format
57from google.protobuf.internal import api_implementation
58from google.protobuf.internal import more_extensions_pb2
59from google.protobuf.internal import more_messages_pb2
60from google.protobuf.internal import message_set_extensions_pb2
61from google.protobuf.internal import wire_format
62from google.protobuf.internal import test_util
63from google.protobuf.internal import decoder
64
65
class _MiniDecoder(object):
  """Sequential reader that pulls wire-format values out of a byte string.

  A class named decoder.Decoder used to exist, but it was removed in a
  redesign that made decoding much faster overall.  A few tests in this file
  relied on it to verify the serialized form of a message, so this class
  provides only the handful of methods those tests call, sparing us from
  rewriting them.
  """

  def __init__(self, bytes):
    # The whole serialized buffer plus a cursor marking the next unread byte.
    self._bytes = bytes
    self._pos = 0

  def ReadVarint(self):
    # _DecodeVarint hands back (value, position just past the varint).
    value, next_pos = decoder._DecodeVarint(self._bytes, self._pos)
    self._pos = next_pos
    return value

  # These readers are plain aliases of ReadVarint.
  ReadInt32 = ReadInt64 = ReadUInt32 = ReadUInt64 = ReadVarint

  def ReadSInt64(self):
    raw = self.ReadVarint()
    return wire_format.ZigZagDecode(raw)

  ReadSInt32 = ReadSInt64

  def ReadFieldNumberAndWireType(self):
    tag = self.ReadVarint()
    return wire_format.UnpackTag(tag)

  def _ReadFixed(self, fmt, size):
    # Unpacks one struct-formatted value of `size` bytes at the cursor and
    # advances the cursor only after the unpack succeeded.
    start = self._pos
    (value,) = struct.unpack(fmt, self._bytes[start:start + size])
    self._pos = start + size
    return value

  def ReadFloat(self):
    # Little-endian 4-byte float.
    return self._ReadFixed("<f", 4)

  def ReadDouble(self):
    # Little-endian 8-byte double.
    return self._ReadFixed("<d", 8)

  def EndOfStream(self):
    # True once the cursor sits exactly at the end of the buffer.
    return len(self._bytes) == self._pos
109
110
111class ReflectionTest(unittest.TestCase):
112
113  def assertListsEqual(self, values, others):
114    self.assertEqual(len(values), len(others))
115    for i in range(len(values)):
116      self.assertEqual(values[i], others[i])
117
118  def testScalarConstructor(self):
119    # Constructor with only scalar types should succeed.
120    proto = unittest_pb2.TestAllTypes(
121        optional_int32=24,
122        optional_double=54.321,
123        optional_string='optional_string',
124        optional_float=None)
125
126    self.assertEqual(24, proto.optional_int32)
127    self.assertEqual(54.321, proto.optional_double)
128    self.assertEqual('optional_string', proto.optional_string)
129    self.assertFalse(proto.HasField("optional_float"))
130
131  def testRepeatedScalarConstructor(self):
132    # Constructor with only repeated scalar types should succeed.
133    proto = unittest_pb2.TestAllTypes(
134        repeated_int32=[1, 2, 3, 4],
135        repeated_double=[1.23, 54.321],
136        repeated_bool=[True, False, False],
137        repeated_string=["optional_string"],
138        repeated_float=None)
139
140    self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
141    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
142    self.assertEqual([True, False, False], list(proto.repeated_bool))
143    self.assertEqual(["optional_string"], list(proto.repeated_string))
144    self.assertEqual([], list(proto.repeated_float))
145
146  def testRepeatedCompositeConstructor(self):
147    # Constructor with only repeated composite types should succeed.
148    proto = unittest_pb2.TestAllTypes(
149        repeated_nested_message=[
150            unittest_pb2.TestAllTypes.NestedMessage(
151                bb=unittest_pb2.TestAllTypes.FOO),
152            unittest_pb2.TestAllTypes.NestedMessage(
153                bb=unittest_pb2.TestAllTypes.BAR)],
154        repeated_foreign_message=[
155            unittest_pb2.ForeignMessage(c=-43),
156            unittest_pb2.ForeignMessage(c=45324),
157            unittest_pb2.ForeignMessage(c=12)],
158        repeatedgroup=[
159            unittest_pb2.TestAllTypes.RepeatedGroup(),
160            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
161            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
162
163    self.assertEqual(
164        [unittest_pb2.TestAllTypes.NestedMessage(
165            bb=unittest_pb2.TestAllTypes.FOO),
166         unittest_pb2.TestAllTypes.NestedMessage(
167             bb=unittest_pb2.TestAllTypes.BAR)],
168        list(proto.repeated_nested_message))
169    self.assertEqual(
170        [unittest_pb2.ForeignMessage(c=-43),
171         unittest_pb2.ForeignMessage(c=45324),
172         unittest_pb2.ForeignMessage(c=12)],
173        list(proto.repeated_foreign_message))
174    self.assertEqual(
175        [unittest_pb2.TestAllTypes.RepeatedGroup(),
176         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
177         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
178        list(proto.repeatedgroup))
179
180  def testMixedConstructor(self):
181    # Constructor with only mixed types should succeed.
182    proto = unittest_pb2.TestAllTypes(
183        optional_int32=24,
184        optional_string='optional_string',
185        repeated_double=[1.23, 54.321],
186        repeated_bool=[True, False, False],
187        repeated_nested_message=[
188            unittest_pb2.TestAllTypes.NestedMessage(
189                bb=unittest_pb2.TestAllTypes.FOO),
190            unittest_pb2.TestAllTypes.NestedMessage(
191                bb=unittest_pb2.TestAllTypes.BAR)],
192        repeated_foreign_message=[
193            unittest_pb2.ForeignMessage(c=-43),
194            unittest_pb2.ForeignMessage(c=45324),
195            unittest_pb2.ForeignMessage(c=12)],
196        optional_nested_message=None)
197
198    self.assertEqual(24, proto.optional_int32)
199    self.assertEqual('optional_string', proto.optional_string)
200    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
201    self.assertEqual([True, False, False], list(proto.repeated_bool))
202    self.assertEqual(
203        [unittest_pb2.TestAllTypes.NestedMessage(
204            bb=unittest_pb2.TestAllTypes.FOO),
205         unittest_pb2.TestAllTypes.NestedMessage(
206             bb=unittest_pb2.TestAllTypes.BAR)],
207        list(proto.repeated_nested_message))
208    self.assertEqual(
209        [unittest_pb2.ForeignMessage(c=-43),
210         unittest_pb2.ForeignMessage(c=45324),
211         unittest_pb2.ForeignMessage(c=12)],
212        list(proto.repeated_foreign_message))
213    self.assertFalse(proto.HasField("optional_nested_message"))
214
215  def testConstructorTypeError(self):
216    self.assertRaises(
217        TypeError, unittest_pb2.TestAllTypes, optional_int32="foo")
218    self.assertRaises(
219        TypeError, unittest_pb2.TestAllTypes, optional_string=1234)
220    self.assertRaises(
221        TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234)
222    self.assertRaises(
223        TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234)
224    self.assertRaises(
225        TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"])
226    self.assertRaises(
227        TypeError, unittest_pb2.TestAllTypes, repeated_string=1234)
228    self.assertRaises(
229        TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234])
230    self.assertRaises(
231        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234)
232    self.assertRaises(
233        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234])
234
235  def testConstructorInvalidatesCachedByteSize(self):
236    message = unittest_pb2.TestAllTypes(optional_int32 = 12)
237    self.assertEqual(2, message.ByteSize())
238
239    message = unittest_pb2.TestAllTypes(
240        optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage())
241    self.assertEqual(3, message.ByteSize())
242
243    message = unittest_pb2.TestAllTypes(repeated_int32 = [12])
244    self.assertEqual(3, message.ByteSize())
245
246    message = unittest_pb2.TestAllTypes(
247        repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()])
248    self.assertEqual(3, message.ByteSize())
249
250  def testSimpleHasBits(self):
251    # Test a scalar.
252    proto = unittest_pb2.TestAllTypes()
253    self.assertTrue(not proto.HasField('optional_int32'))
254    self.assertEqual(0, proto.optional_int32)
255    # HasField() shouldn't be true if all we've done is
256    # read the default value.
257    self.assertTrue(not proto.HasField('optional_int32'))
258    proto.optional_int32 = 1
259    # Setting a value however *should* set the "has" bit.
260    self.assertTrue(proto.HasField('optional_int32'))
261    proto.ClearField('optional_int32')
262    # And clearing that value should unset the "has" bit.
263    self.assertTrue(not proto.HasField('optional_int32'))
264
265  def testHasBitsWithSinglyNestedScalar(self):
266    # Helper used to test foreign messages and groups.
267    #
268    # composite_field_name should be the name of a non-repeated
269    # composite (i.e., foreign or group) field in TestAllTypes,
270    # and scalar_field_name should be the name of an integer-valued
271    # scalar field within that composite.
272    #
273    # I never thought I'd miss C++ macros and templates so much. :(
274    # This helper is semantically just:
275    #
276    #   assert proto.composite_field.scalar_field == 0
277    #   assert not proto.composite_field.HasField('scalar_field')
278    #   assert not proto.HasField('composite_field')
279    #
280    #   proto.composite_field.scalar_field = 10
281    #   old_composite_field = proto.composite_field
282    #
283    #   assert proto.composite_field.scalar_field == 10
284    #   assert proto.composite_field.HasField('scalar_field')
285    #   assert proto.HasField('composite_field')
286    #
287    #   proto.ClearField('composite_field')
288    #
289    #   assert not proto.composite_field.HasField('scalar_field')
290    #   assert not proto.HasField('composite_field')
291    #   assert proto.composite_field.scalar_field == 0
292    #
293    #   # Now ensure that ClearField('composite_field') disconnected
294    #   # the old field object from the object tree...
295    #   assert old_composite_field is not proto.composite_field
296    #   old_composite_field.scalar_field = 20
297    #   assert not proto.composite_field.HasField('scalar_field')
298    #   assert not proto.HasField('composite_field')
299    def TestCompositeHasBits(composite_field_name, scalar_field_name):
300      proto = unittest_pb2.TestAllTypes()
301      # First, check that we can get the scalar value, and see that it's the
302      # default (0), but that proto.HasField('omposite') and
303      # proto.composite.HasField('scalar') will still return False.
304      composite_field = getattr(proto, composite_field_name)
305      original_scalar_value = getattr(composite_field, scalar_field_name)
306      self.assertEqual(0, original_scalar_value)
307      # Assert that the composite object does not "have" the scalar.
308      self.assertTrue(not composite_field.HasField(scalar_field_name))
309      # Assert that proto does not "have" the composite field.
310      self.assertTrue(not proto.HasField(composite_field_name))
311
312      # Now set the scalar within the composite field.  Ensure that the setting
313      # is reflected, and that proto.HasField('composite') and
314      # proto.composite.HasField('scalar') now both return True.
315      new_val = 20
316      setattr(composite_field, scalar_field_name, new_val)
317      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
318      # Hold on to a reference to the current composite_field object.
319      old_composite_field = composite_field
320      # Assert that the has methods now return true.
321      self.assertTrue(composite_field.HasField(scalar_field_name))
322      self.assertTrue(proto.HasField(composite_field_name))
323
324      # Now call the clear method...
325      proto.ClearField(composite_field_name)
326
327      # ...and ensure that the "has" bits are all back to False...
328      composite_field = getattr(proto, composite_field_name)
329      self.assertTrue(not composite_field.HasField(scalar_field_name))
330      self.assertTrue(not proto.HasField(composite_field_name))
331      # ...and ensure that the scalar field has returned to its default.
332      self.assertEqual(0, getattr(composite_field, scalar_field_name))
333
334      self.assertTrue(old_composite_field is not composite_field)
335      setattr(old_composite_field, scalar_field_name, new_val)
336      self.assertTrue(not composite_field.HasField(scalar_field_name))
337      self.assertTrue(not proto.HasField(composite_field_name))
338      self.assertEqual(0, getattr(composite_field, scalar_field_name))
339
340    # Test simple, single-level nesting when we set a scalar.
341    TestCompositeHasBits('optionalgroup', 'a')
342    TestCompositeHasBits('optional_nested_message', 'bb')
343    TestCompositeHasBits('optional_foreign_message', 'c')
344    TestCompositeHasBits('optional_import_message', 'd')
345
346  def testReferencesToNestedMessage(self):
347    proto = unittest_pb2.TestAllTypes()
348    nested = proto.optional_nested_message
349    del proto
350    # A previous version had a bug where this would raise an exception when
351    # hitting a now-dead weak reference.
352    nested.bb = 23
353
354  def testDisconnectingNestedMessageBeforeSettingField(self):
355    proto = unittest_pb2.TestAllTypes()
356    nested = proto.optional_nested_message
357    proto.ClearField('optional_nested_message')  # Should disconnect from parent
358    self.assertTrue(nested is not proto.optional_nested_message)
359    nested.bb = 23
360    self.assertTrue(not proto.HasField('optional_nested_message'))
361    self.assertEqual(0, proto.optional_nested_message.bb)
362
363  def testGetDefaultMessageAfterDisconnectingDefaultMessage(self):
364    proto = unittest_pb2.TestAllTypes()
365    nested = proto.optional_nested_message
366    proto.ClearField('optional_nested_message')
367    del proto
368    del nested
369    # Force a garbage collect so that the underlying CMessages are freed along
370    # with the Messages they point to. This is to make sure we're not deleting
371    # default message instances.
372    gc.collect()
373    proto = unittest_pb2.TestAllTypes()
374    nested = proto.optional_nested_message
375
376  def testDisconnectingNestedMessageAfterSettingField(self):
377    proto = unittest_pb2.TestAllTypes()
378    nested = proto.optional_nested_message
379    nested.bb = 5
380    self.assertTrue(proto.HasField('optional_nested_message'))
381    proto.ClearField('optional_nested_message')  # Should disconnect from parent
382    self.assertEqual(5, nested.bb)
383    self.assertEqual(0, proto.optional_nested_message.bb)
384    self.assertTrue(nested is not proto.optional_nested_message)
385    nested.bb = 23
386    self.assertTrue(not proto.HasField('optional_nested_message'))
387    self.assertEqual(0, proto.optional_nested_message.bb)
388
389  def testDisconnectingNestedMessageBeforeGettingField(self):
390    proto = unittest_pb2.TestAllTypes()
391    self.assertTrue(not proto.HasField('optional_nested_message'))
392    proto.ClearField('optional_nested_message')
393    self.assertTrue(not proto.HasField('optional_nested_message'))
394
395  def testDisconnectingNestedMessageAfterMerge(self):
396    # This test exercises the code path that does not use ReleaseMessage().
397    # The underlying fear is that if we use ReleaseMessage() incorrectly,
398    # we will have memory leaks.  It's hard to check that that doesn't happen,
399    # but at least we can exercise that code path to make sure it works.
400    proto1 = unittest_pb2.TestAllTypes()
401    proto2 = unittest_pb2.TestAllTypes()
402    proto2.optional_nested_message.bb = 5
403    proto1.MergeFrom(proto2)
404    self.assertTrue(proto1.HasField('optional_nested_message'))
405    proto1.ClearField('optional_nested_message')
406    self.assertTrue(not proto1.HasField('optional_nested_message'))
407
408  def testDisconnectingLazyNestedMessage(self):
409    # This test exercises releasing a nested message that is lazy. This test
410    # only exercises real code in the C++ implementation as Python does not
411    # support lazy parsing, but the current C++ implementation results in
412    # memory corruption and a crash.
413    if api_implementation.Type() != 'python':
414      return
415    proto = unittest_pb2.TestAllTypes()
416    proto.optional_lazy_message.bb = 5
417    proto.ClearField('optional_lazy_message')
418    del proto
419    gc.collect()
420
421  def testHasBitsWhenModifyingRepeatedFields(self):
422    # Test nesting when we add an element to a repeated field in a submessage.
423    proto = unittest_pb2.TestNestedMessageHasBits()
424    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
425    self.assertEqual(
426        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
427    self.assertTrue(proto.HasField('optional_nested_message'))
428
429    # Do the same test, but with a repeated composite field within the
430    # submessage.
431    proto.ClearField('optional_nested_message')
432    self.assertTrue(not proto.HasField('optional_nested_message'))
433    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
434    self.assertTrue(proto.HasField('optional_nested_message'))
435
436  def testHasBitsForManyLevelsOfNesting(self):
437    # Test nesting many levels deep.
438    recursive_proto = unittest_pb2.TestMutualRecursionA()
439    self.assertTrue(not recursive_proto.HasField('bb'))
440    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
441    self.assertTrue(not recursive_proto.HasField('bb'))
442    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
443    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
444    self.assertTrue(recursive_proto.HasField('bb'))
445    self.assertTrue(recursive_proto.bb.HasField('a'))
446    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
447    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
448    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
449    self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a'))
450    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
451
452  def testSingularListFields(self):
453    proto = unittest_pb2.TestAllTypes()
454    proto.optional_fixed32 = 1
455    proto.optional_int32 = 5
456    proto.optional_string = 'foo'
457    # Access sub-message but don't set it yet.
458    nested_message = proto.optional_nested_message
459    self.assertEqual(
460      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
461        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
462        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
463      proto.ListFields())
464
465    proto.optional_nested_message.bb = 123
466    self.assertEqual(
467      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
468        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
469        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
470        (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
471             nested_message) ],
472      proto.ListFields())
473
474  def testRepeatedListFields(self):
475    proto = unittest_pb2.TestAllTypes()
476    proto.repeated_fixed32.append(1)
477    proto.repeated_int32.append(5)
478    proto.repeated_int32.append(11)
479    proto.repeated_string.extend(['foo', 'bar'])
480    proto.repeated_string.extend([])
481    proto.repeated_string.append('baz')
482    proto.repeated_string.extend(str(x) for x in range(2))
483    proto.optional_int32 = 21
484    proto.repeated_bool  # Access but don't set anything; should not be listed.
485    self.assertEqual(
486      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 21),
487        (proto.DESCRIPTOR.fields_by_name['repeated_int32'  ], [5, 11]),
488        (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
489        (proto.DESCRIPTOR.fields_by_name['repeated_string' ],
490          ['foo', 'bar', 'baz', '0', '1']) ],
491      proto.ListFields())
492
493  def testSingularListExtensions(self):
494    proto = unittest_pb2.TestAllExtensions()
495    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
496    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 5
497    proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo'
498    self.assertEqual(
499      [ (unittest_pb2.optional_int32_extension  , 5),
500        (unittest_pb2.optional_fixed32_extension, 1),
501        (unittest_pb2.optional_string_extension , 'foo') ],
502      proto.ListFields())
503
504  def testRepeatedListExtensions(self):
505    proto = unittest_pb2.TestAllExtensions()
506    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
507    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(5)
508    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(11)
509    proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo')
510    proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar')
511    proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz')
512    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 21
513    self.assertEqual(
514      [ (unittest_pb2.optional_int32_extension  , 21),
515        (unittest_pb2.repeated_int32_extension  , [5, 11]),
516        (unittest_pb2.repeated_fixed32_extension, [1]),
517        (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ],
518      proto.ListFields())
519
520  def testListFieldsAndExtensions(self):
521    proto = unittest_pb2.TestFieldOrderings()
522    test_util.SetAllFieldsAndExtensions(proto)
523    unittest_pb2.my_extension_int
524    self.assertEqual(
525      [ (proto.DESCRIPTOR.fields_by_name['my_int'   ], 1),
526        (unittest_pb2.my_extension_int               , 23),
527        (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
528        (unittest_pb2.my_extension_string            , 'bar'),
529        (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ],
530      proto.ListFields())
531
532  def testDefaultValues(self):
533    proto = unittest_pb2.TestAllTypes()
534    self.assertEqual(0, proto.optional_int32)
535    self.assertEqual(0, proto.optional_int64)
536    self.assertEqual(0, proto.optional_uint32)
537    self.assertEqual(0, proto.optional_uint64)
538    self.assertEqual(0, proto.optional_sint32)
539    self.assertEqual(0, proto.optional_sint64)
540    self.assertEqual(0, proto.optional_fixed32)
541    self.assertEqual(0, proto.optional_fixed64)
542    self.assertEqual(0, proto.optional_sfixed32)
543    self.assertEqual(0, proto.optional_sfixed64)
544    self.assertEqual(0.0, proto.optional_float)
545    self.assertEqual(0.0, proto.optional_double)
546    self.assertEqual(False, proto.optional_bool)
547    self.assertEqual('', proto.optional_string)
548    self.assertEqual(b'', proto.optional_bytes)
549
550    self.assertEqual(41, proto.default_int32)
551    self.assertEqual(42, proto.default_int64)
552    self.assertEqual(43, proto.default_uint32)
553    self.assertEqual(44, proto.default_uint64)
554    self.assertEqual(-45, proto.default_sint32)
555    self.assertEqual(46, proto.default_sint64)
556    self.assertEqual(47, proto.default_fixed32)
557    self.assertEqual(48, proto.default_fixed64)
558    self.assertEqual(49, proto.default_sfixed32)
559    self.assertEqual(-50, proto.default_sfixed64)
560    self.assertEqual(51.5, proto.default_float)
561    self.assertEqual(52e3, proto.default_double)
562    self.assertEqual(True, proto.default_bool)
563    self.assertEqual('hello', proto.default_string)
564    self.assertEqual(b'world', proto.default_bytes)
565    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
566    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
567    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
568                     proto.default_import_enum)
569
570    proto = unittest_pb2.TestExtremeDefaultValues()
571    self.assertEqual(u'\u1234', proto.utf8_string)
572
573  def testHasFieldWithUnknownFieldName(self):
574    proto = unittest_pb2.TestAllTypes()
575    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')
576
577  def testClearFieldWithUnknownFieldName(self):
578    proto = unittest_pb2.TestAllTypes()
579    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
580
581  def testClearRemovesChildren(self):
582    # Make sure there aren't any implementation bugs that are only partially
583    # clearing the message (which can happen in the more complex C++
584    # implementation which has parallel message lists).
585    proto = unittest_pb2.TestRequiredForeign()
586    for i in range(10):
587      proto.repeated_message.add()
588    proto2 = unittest_pb2.TestRequiredForeign()
589    proto.CopyFrom(proto2)
590    self.assertRaises(IndexError, lambda: proto.repeated_message[5])
591
592  def testDisallowedAssignments(self):
593    # It's illegal to assign values directly to repeated fields
594    # or to nonrepeated composite fields.  Ensure that this fails.
595    proto = unittest_pb2.TestAllTypes()
596    # Repeated fields.
597    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
598    # Lists shouldn't work, either.
599    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
600    # Composite fields.
601    self.assertRaises(AttributeError, setattr, proto,
602                      'optional_nested_message', 23)
603    # Assignment to a repeated nested message field without specifying
604    # the index in the array of nested messages.
605    self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
606                      'bb', 34)
607    # Assignment to an attribute of a repeated field.
608    self.assertRaises(AttributeError, setattr, proto.repeated_float,
609                      'some_attribute', 34)
610    # proto.nonexistent_field = 23 should fail as well.
611    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
612
613  def testSingleScalarTypeSafety(self):
614    proto = unittest_pb2.TestAllTypes()
615    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
616    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
617    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
618    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
619
620  def testIntegerTypes(self):
621    def TestGetAndDeserialize(field_name, value, expected_type):
622      proto = unittest_pb2.TestAllTypes()
623      setattr(proto, field_name, value)
624      self.assertIsInstance(getattr(proto, field_name), expected_type)
625      proto2 = unittest_pb2.TestAllTypes()
626      proto2.ParseFromString(proto.SerializeToString())
627      self.assertIsInstance(getattr(proto2, field_name), expected_type)
628
629    TestGetAndDeserialize('optional_int32', 1, int)
630    TestGetAndDeserialize('optional_int32', 1 << 30, int)
631    TestGetAndDeserialize('optional_uint32', 1 << 30, int)
632    try:
633      integer_64 = long
634    except NameError: # Python3
635      integer_64 = int
636    if struct.calcsize('L') == 4:
637      # Python only has signed ints, so 32-bit python can't fit an uint32
638      # in an int.
639      TestGetAndDeserialize('optional_uint32', 1 << 31, long)
640    else:
641      # 64-bit python can fit uint32 inside an int
642      TestGetAndDeserialize('optional_uint32', 1 << 31, int)
643    TestGetAndDeserialize('optional_int64', 1 << 30, integer_64)
644    TestGetAndDeserialize('optional_int64', 1 << 60, integer_64)
645    TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
646    TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
647
648  def testSingleScalarBoundsChecking(self):
649    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
650      pb = unittest_pb2.TestAllTypes()
651      setattr(pb, field_name, expected_min)
652      self.assertEqual(expected_min, getattr(pb, field_name))
653      setattr(pb, field_name, expected_max)
654      self.assertEqual(expected_max, getattr(pb, field_name))
655      self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
656      self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)
657
658    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
659    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
660    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
661    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
662
663    pb = unittest_pb2.TestAllTypes()
664    pb.optional_nested_enum = 1
665    self.assertEqual(1, pb.optional_nested_enum)
666
667  def testRepeatedScalarTypeSafety(self):
668    proto = unittest_pb2.TestAllTypes()
669    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
670    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
671    self.assertRaises(TypeError, proto.repeated_string, 10)
672    self.assertRaises(TypeError, proto.repeated_bytes, 10)
673
674    proto.repeated_int32.append(10)
675    proto.repeated_int32[0] = 23
676    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
677    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
678
679    # Repeated enums tests.
680    #proto.repeated_nested_enum.append(0)
681
682  def testSingleScalarGettersAndSetters(self):
683    proto = unittest_pb2.TestAllTypes()
684    self.assertEqual(0, proto.optional_int32)
685    proto.optional_int32 = 1
686    self.assertEqual(1, proto.optional_int32)
687
688    proto.optional_uint64 = 0xffffffffffff
689    self.assertEqual(0xffffffffffff, proto.optional_uint64)
690    proto.optional_uint64 = 0xffffffffffffffff
691    self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
692    # TODO(robinson): Test all other scalar field types.
693
694  def testSingleScalarClearField(self):
695    proto = unittest_pb2.TestAllTypes()
696    # Should be allowed to clear something that's not there (a no-op).
697    proto.ClearField('optional_int32')
698    proto.optional_int32 = 1
699    self.assertTrue(proto.HasField('optional_int32'))
700    proto.ClearField('optional_int32')
701    self.assertEqual(0, proto.optional_int32)
702    self.assertTrue(not proto.HasField('optional_int32'))
703    # TODO(robinson): Test all other scalar field types.
704
705  def testEnums(self):
706    proto = unittest_pb2.TestAllTypes()
707    self.assertEqual(1, proto.FOO)
708    self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
709    self.assertEqual(2, proto.BAR)
710    self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
711    self.assertEqual(3, proto.BAZ)
712    self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
713
714  def testEnum_Name(self):
715    self.assertEqual('FOREIGN_FOO',
716                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO))
717    self.assertEqual('FOREIGN_BAR',
718                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR))
719    self.assertEqual('FOREIGN_BAZ',
720                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ))
721    self.assertRaises(ValueError,
722                      unittest_pb2.ForeignEnum.Name, 11312)
723
724    proto = unittest_pb2.TestAllTypes()
725    self.assertEqual('FOO',
726                     proto.NestedEnum.Name(proto.FOO))
727    self.assertEqual('FOO',
728                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO))
729    self.assertEqual('BAR',
730                     proto.NestedEnum.Name(proto.BAR))
731    self.assertEqual('BAR',
732                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR))
733    self.assertEqual('BAZ',
734                     proto.NestedEnum.Name(proto.BAZ))
735    self.assertEqual('BAZ',
736                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ))
737    self.assertRaises(ValueError,
738                      proto.NestedEnum.Name, 11312)
739    self.assertRaises(ValueError,
740                      unittest_pb2.TestAllTypes.NestedEnum.Name, 11312)
741
742  def testEnum_Value(self):
743    self.assertEqual(unittest_pb2.FOREIGN_FOO,
744                     unittest_pb2.ForeignEnum.Value('FOREIGN_FOO'))
745    self.assertEqual(unittest_pb2.FOREIGN_BAR,
746                     unittest_pb2.ForeignEnum.Value('FOREIGN_BAR'))
747    self.assertEqual(unittest_pb2.FOREIGN_BAZ,
748                     unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ'))
749    self.assertRaises(ValueError,
750                      unittest_pb2.ForeignEnum.Value, 'FO')
751
752    proto = unittest_pb2.TestAllTypes()
753    self.assertEqual(proto.FOO,
754                     proto.NestedEnum.Value('FOO'))
755    self.assertEqual(proto.FOO,
756                     unittest_pb2.TestAllTypes.NestedEnum.Value('FOO'))
757    self.assertEqual(proto.BAR,
758                     proto.NestedEnum.Value('BAR'))
759    self.assertEqual(proto.BAR,
760                     unittest_pb2.TestAllTypes.NestedEnum.Value('BAR'))
761    self.assertEqual(proto.BAZ,
762                     proto.NestedEnum.Value('BAZ'))
763    self.assertEqual(proto.BAZ,
764                     unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ'))
765    self.assertRaises(ValueError,
766                      proto.NestedEnum.Value, 'Foo')
767    self.assertRaises(ValueError,
768                      unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo')
769
770  def testEnum_KeysAndValues(self):
771    self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'],
772                     list(unittest_pb2.ForeignEnum.keys()))
773    self.assertEqual([4, 5, 6],
774                     list(unittest_pb2.ForeignEnum.values()))
775    self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5),
776                      ('FOREIGN_BAZ', 6)],
777                     list(unittest_pb2.ForeignEnum.items()))
778
779    proto = unittest_pb2.TestAllTypes()
780    self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], list(proto.NestedEnum.keys()))
781    self.assertEqual([1, 2, 3, -1], list(proto.NestedEnum.values()))
782    self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)],
783                     list(proto.NestedEnum.items()))
784
785  def testRepeatedScalars(self):
786    proto = unittest_pb2.TestAllTypes()
787
788    self.assertTrue(not proto.repeated_int32)
789    self.assertEqual(0, len(proto.repeated_int32))
790    proto.repeated_int32.append(5)
791    proto.repeated_int32.append(10)
792    proto.repeated_int32.append(15)
793    self.assertTrue(proto.repeated_int32)
794    self.assertEqual(3, len(proto.repeated_int32))
795
796    self.assertEqual([5, 10, 15], proto.repeated_int32)
797
798    # Test single retrieval.
799    self.assertEqual(5, proto.repeated_int32[0])
800    self.assertEqual(15, proto.repeated_int32[-1])
801    # Test out-of-bounds indices.
802    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
803    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
804    # Test incorrect types passed to __getitem__.
805    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
806    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)
807
808    # Test single assignment.
809    proto.repeated_int32[1] = 20
810    self.assertEqual([5, 20, 15], proto.repeated_int32)
811
812    # Test insertion.
813    proto.repeated_int32.insert(1, 25)
814    self.assertEqual([5, 25, 20, 15], proto.repeated_int32)
815
816    # Test slice retrieval.
817    proto.repeated_int32.append(30)
818    self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
819    self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
820
821    # Test slice assignment with an iterator
822    proto.repeated_int32[1:4] = (i for i in range(3))
823    self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
824
825    # Test slice assignment.
826    proto.repeated_int32[1:4] = [35, 40, 45]
827    self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)
828
829    # Test that we can use the field as an iterator.
830    result = []
831    for i in proto.repeated_int32:
832      result.append(i)
833    self.assertEqual([5, 35, 40, 45, 30], result)
834
835    # Test single deletion.
836    del proto.repeated_int32[2]
837    self.assertEqual([5, 35, 45, 30], proto.repeated_int32)
838
839    # Test slice deletion.
840    del proto.repeated_int32[2:]
841    self.assertEqual([5, 35], proto.repeated_int32)
842
843    # Test extending.
844    proto.repeated_int32.extend([3, 13])
845    self.assertEqual([5, 35, 3, 13], proto.repeated_int32)
846
847    # Test clearing.
848    proto.ClearField('repeated_int32')
849    self.assertTrue(not proto.repeated_int32)
850    self.assertEqual(0, len(proto.repeated_int32))
851
852    proto.repeated_int32.append(1)
853    self.assertEqual(1, proto.repeated_int32[-1])
854    # Test assignment to a negative index.
855    proto.repeated_int32[-1] = 2
856    self.assertEqual(2, proto.repeated_int32[-1])
857
858    # Test deletion at negative indices.
859    proto.repeated_int32[:] = [0, 1, 2, 3]
860    del proto.repeated_int32[-1]
861    self.assertEqual([0, 1, 2], proto.repeated_int32)
862
863    del proto.repeated_int32[-2]
864    self.assertEqual([0, 2], proto.repeated_int32)
865
866    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3)
867    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300)
868
869    del proto.repeated_int32[-2:-1]
870    self.assertEqual([2], proto.repeated_int32)
871
872    del proto.repeated_int32[100:10000]
873    self.assertEqual([2], proto.repeated_int32)
874
875  def testRepeatedScalarsRemove(self):
876    proto = unittest_pb2.TestAllTypes()
877
878    self.assertTrue(not proto.repeated_int32)
879    self.assertEqual(0, len(proto.repeated_int32))
880    proto.repeated_int32.append(5)
881    proto.repeated_int32.append(10)
882    proto.repeated_int32.append(5)
883    proto.repeated_int32.append(5)
884
885    self.assertEqual(4, len(proto.repeated_int32))
886    proto.repeated_int32.remove(5)
887    self.assertEqual(3, len(proto.repeated_int32))
888    self.assertEqual(10, proto.repeated_int32[0])
889    self.assertEqual(5, proto.repeated_int32[1])
890    self.assertEqual(5, proto.repeated_int32[2])
891
892    proto.repeated_int32.remove(5)
893    self.assertEqual(2, len(proto.repeated_int32))
894    self.assertEqual(10, proto.repeated_int32[0])
895    self.assertEqual(5, proto.repeated_int32[1])
896
897    proto.repeated_int32.remove(10)
898    self.assertEqual(1, len(proto.repeated_int32))
899    self.assertEqual(5, proto.repeated_int32[0])
900
901    # Remove a non-existent element.
902    self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
903
904  def testRepeatedComposites(self):
905    proto = unittest_pb2.TestAllTypes()
906    self.assertTrue(not proto.repeated_nested_message)
907    self.assertEqual(0, len(proto.repeated_nested_message))
908    m0 = proto.repeated_nested_message.add()
909    m1 = proto.repeated_nested_message.add()
910    self.assertTrue(proto.repeated_nested_message)
911    self.assertEqual(2, len(proto.repeated_nested_message))
912    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
913    self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage)
914
915    # Test out-of-bounds indices.
916    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
917                      1234)
918    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
919                      -1234)
920
921    # Test incorrect types passed to __getitem__.
922    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
923                      'foo')
924    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
925                      None)
926
927    # Test slice retrieval.
928    m2 = proto.repeated_nested_message.add()
929    m3 = proto.repeated_nested_message.add()
930    m4 = proto.repeated_nested_message.add()
931    self.assertListsEqual(
932        [m1, m2, m3], proto.repeated_nested_message[1:4])
933    self.assertListsEqual(
934        [m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
935    self.assertListsEqual(
936        [m0, m1], proto.repeated_nested_message[:2])
937    self.assertListsEqual(
938        [m2, m3, m4], proto.repeated_nested_message[2:])
939    self.assertEqual(
940        m0, proto.repeated_nested_message[0])
941    self.assertListsEqual(
942        [m0], proto.repeated_nested_message[:1])
943
944    # Test that we can use the field as an iterator.
945    result = []
946    for i in proto.repeated_nested_message:
947      result.append(i)
948    self.assertListsEqual([m0, m1, m2, m3, m4], result)
949
950    # Test single deletion.
951    del proto.repeated_nested_message[2]
952    self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message)
953
954    # Test slice deletion.
955    del proto.repeated_nested_message[2:]
956    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
957
958    # Test extending.
959    n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
960    n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2)
961    proto.repeated_nested_message.extend([n1,n2])
962    self.assertEqual(4, len(proto.repeated_nested_message))
963    self.assertEqual(n1, proto.repeated_nested_message[2])
964    self.assertEqual(n2, proto.repeated_nested_message[3])
965
966    # Test clearing.
967    proto.ClearField('repeated_nested_message')
968    self.assertTrue(not proto.repeated_nested_message)
969    self.assertEqual(0, len(proto.repeated_nested_message))
970
971    # Test constructing an element while adding it.
972    proto.repeated_nested_message.add(bb=23)
973    self.assertEqual(1, len(proto.repeated_nested_message))
974    self.assertEqual(23, proto.repeated_nested_message[0].bb)
975    self.assertRaises(TypeError, proto.repeated_nested_message.add, 23)
976
977  def testRepeatedCompositeRemove(self):
978    proto = unittest_pb2.TestAllTypes()
979
980    self.assertEqual(0, len(proto.repeated_nested_message))
981    m0 = proto.repeated_nested_message.add()
982    # Need to set some differentiating variable so m0 != m1 != m2:
983    m0.bb = len(proto.repeated_nested_message)
984    m1 = proto.repeated_nested_message.add()
985    m1.bb = len(proto.repeated_nested_message)
986    self.assertTrue(m0 != m1)
987    m2 = proto.repeated_nested_message.add()
988    m2.bb = len(proto.repeated_nested_message)
989    self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
990
991    self.assertEqual(3, len(proto.repeated_nested_message))
992    proto.repeated_nested_message.remove(m0)
993    self.assertEqual(2, len(proto.repeated_nested_message))
994    self.assertEqual(m1, proto.repeated_nested_message[0])
995    self.assertEqual(m2, proto.repeated_nested_message[1])
996
997    # Removing m0 again or removing None should raise error
998    self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0)
999    self.assertRaises(ValueError, proto.repeated_nested_message.remove, None)
1000    self.assertEqual(2, len(proto.repeated_nested_message))
1001
1002    proto.repeated_nested_message.remove(m2)
1003    self.assertEqual(1, len(proto.repeated_nested_message))
1004    self.assertEqual(m1, proto.repeated_nested_message[0])
1005
1006  def testHandWrittenReflection(self):
1007    # Hand written extensions are only supported by the pure-Python
1008    # implementation of the API.
1009    if api_implementation.Type() != 'python':
1010      return
1011
1012    FieldDescriptor = descriptor.FieldDescriptor
1013    foo_field_descriptor = FieldDescriptor(
1014        name='foo_field', full_name='MyProto.foo_field',
1015        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
1016        cpp_type=FieldDescriptor.CPPTYPE_INT64,
1017        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
1018        containing_type=None, message_type=None, enum_type=None,
1019        is_extension=False, extension_scope=None,
1020        options=descriptor_pb2.FieldOptions())
1021    mydescriptor = descriptor.Descriptor(
1022        name='MyProto', full_name='MyProto', filename='ignored',
1023        containing_type=None, nested_types=[], enum_types=[],
1024        fields=[foo_field_descriptor], extensions=[],
1025        options=descriptor_pb2.MessageOptions())
1026    class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
1027      DESCRIPTOR = mydescriptor
1028    myproto_instance = MyProtoClass()
1029    self.assertEqual(0, myproto_instance.foo_field)
1030    self.assertTrue(not myproto_instance.HasField('foo_field'))
1031    myproto_instance.foo_field = 23
1032    self.assertEqual(23, myproto_instance.foo_field)
1033    self.assertTrue(myproto_instance.HasField('foo_field'))
1034
1035  def testDescriptorProtoSupport(self):
1036    # Hand written descriptors/reflection are only supported by the pure-Python
1037    # implementation of the API.
1038    if api_implementation.Type() != 'python':
1039      return
1040
1041    def AddDescriptorField(proto, field_name, field_type):
1042      AddDescriptorField.field_index += 1
1043      new_field = proto.field.add()
1044      new_field.name = field_name
1045      new_field.type = field_type
1046      new_field.number = AddDescriptorField.field_index
1047      new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
1048
1049    AddDescriptorField.field_index = 0
1050
1051    desc_proto = descriptor_pb2.DescriptorProto()
1052    desc_proto.name = 'Car'
1053    fdp = descriptor_pb2.FieldDescriptorProto
1054    AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING)
1055    AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64)
1056    AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL)
1057    AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE)
1058    # Add a repeated field
1059    AddDescriptorField.field_index += 1
1060    new_field = desc_proto.field.add()
1061    new_field.name = 'owners'
1062    new_field.type = fdp.TYPE_STRING
1063    new_field.number = AddDescriptorField.field_index
1064    new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
1065
1066    desc = descriptor.MakeDescriptor(desc_proto)
1067    self.assertTrue('name' in desc.fields_by_name)
1068    self.assertTrue('year' in desc.fields_by_name)
1069    self.assertTrue('automatic' in desc.fields_by_name)
1070    self.assertTrue('price' in desc.fields_by_name)
1071    self.assertTrue('owners' in desc.fields_by_name)
1072
1073    class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
1074      DESCRIPTOR = desc
1075
1076    prius = CarMessage()
1077    prius.name = 'prius'
1078    prius.year = 2010
1079    prius.automatic = True
1080    prius.price = 25134.75
1081    prius.owners.extend(['bob', 'susan'])
1082
1083    serialized_prius = prius.SerializeToString()
1084    new_prius = reflection.ParseMessage(desc, serialized_prius)
1085    self.assertTrue(new_prius is not prius)
1086    self.assertEqual(prius, new_prius)
1087
1088    # these are unnecessary assuming message equality works as advertised but
1089    # explicitly check to be safe since we're mucking about in metaclass foo
1090    self.assertEqual(prius.name, new_prius.name)
1091    self.assertEqual(prius.year, new_prius.year)
1092    self.assertEqual(prius.automatic, new_prius.automatic)
1093    self.assertEqual(prius.price, new_prius.price)
1094    self.assertEqual(prius.owners, new_prius.owners)
1095
1096  def testTopLevelExtensionsForOptionalScalar(self):
1097    extendee_proto = unittest_pb2.TestAllExtensions()
1098    extension = unittest_pb2.optional_int32_extension
1099    self.assertTrue(not extendee_proto.HasExtension(extension))
1100    self.assertEqual(0, extendee_proto.Extensions[extension])
1101    # As with normal scalar fields, just doing a read doesn't actually set the
1102    # "has" bit.
1103    self.assertTrue(not extendee_proto.HasExtension(extension))
1104    # Actually set the thing.
1105    extendee_proto.Extensions[extension] = 23
1106    self.assertEqual(23, extendee_proto.Extensions[extension])
1107    self.assertTrue(extendee_proto.HasExtension(extension))
1108    # Ensure that clearing works as well.
1109    extendee_proto.ClearExtension(extension)
1110    self.assertEqual(0, extendee_proto.Extensions[extension])
1111    self.assertTrue(not extendee_proto.HasExtension(extension))
1112
1113  def testTopLevelExtensionsForRepeatedScalar(self):
1114    extendee_proto = unittest_pb2.TestAllExtensions()
1115    extension = unittest_pb2.repeated_string_extension
1116    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1117    extendee_proto.Extensions[extension].append('foo')
1118    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
1119    string_list = extendee_proto.Extensions[extension]
1120    extendee_proto.ClearExtension(extension)
1121    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1122    self.assertTrue(string_list is not extendee_proto.Extensions[extension])
1123    # Shouldn't be allowed to do Extensions[extension] = 'a'
1124    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1125                      extension, 'a')
1126
1127  def testTopLevelExtensionsForOptionalMessage(self):
1128    extendee_proto = unittest_pb2.TestAllExtensions()
1129    extension = unittest_pb2.optional_foreign_message_extension
1130    self.assertTrue(not extendee_proto.HasExtension(extension))
1131    self.assertEqual(0, extendee_proto.Extensions[extension].c)
1132    # As with normal (non-extension) fields, merely reading from the
1133    # thing shouldn't set the "has" bit.
1134    self.assertTrue(not extendee_proto.HasExtension(extension))
1135    extendee_proto.Extensions[extension].c = 23
1136    self.assertEqual(23, extendee_proto.Extensions[extension].c)
1137    self.assertTrue(extendee_proto.HasExtension(extension))
1138    # Save a reference here.
1139    foreign_message = extendee_proto.Extensions[extension]
1140    extendee_proto.ClearExtension(extension)
1141    self.assertTrue(foreign_message is not extendee_proto.Extensions[extension])
1142    # Setting a field on foreign_message now shouldn't set
1143    # any "has" bits on extendee_proto.
1144    foreign_message.c = 42
1145    self.assertEqual(42, foreign_message.c)
1146    self.assertTrue(foreign_message.HasField('c'))
1147    self.assertTrue(not extendee_proto.HasExtension(extension))
1148    # Shouldn't be allowed to do Extensions[extension] = 'a'
1149    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1150                      extension, 'a')
1151
1152  def testTopLevelExtensionsForRepeatedMessage(self):
1153    extendee_proto = unittest_pb2.TestAllExtensions()
1154    extension = unittest_pb2.repeatedgroup_extension
1155    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1156    group = extendee_proto.Extensions[extension].add()
1157    group.a = 23
1158    self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
1159    group.a = 42
1160    self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
1161    group_list = extendee_proto.Extensions[extension]
1162    extendee_proto.ClearExtension(extension)
1163    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1164    self.assertTrue(group_list is not extendee_proto.Extensions[extension])
1165    # Shouldn't be allowed to do Extensions[extension] = 'a'
1166    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1167                      extension, 'a')
1168
1169  def testNestedExtensions(self):
1170    extendee_proto = unittest_pb2.TestAllExtensions()
1171    extension = unittest_pb2.TestRequired.single
1172
1173    # We just test the non-repeated case.
1174    self.assertTrue(not extendee_proto.HasExtension(extension))
1175    required = extendee_proto.Extensions[extension]
1176    self.assertEqual(0, required.a)
1177    self.assertTrue(not extendee_proto.HasExtension(extension))
1178    required.a = 23
1179    self.assertEqual(23, extendee_proto.Extensions[extension].a)
1180    self.assertTrue(extendee_proto.HasExtension(extension))
1181    extendee_proto.ClearExtension(extension)
1182    self.assertTrue(required is not extendee_proto.Extensions[extension])
1183    self.assertTrue(not extendee_proto.HasExtension(extension))
1184
1185  def testRegisteredExtensions(self):
1186    self.assertTrue('protobuf_unittest.optional_int32_extension' in
1187                    unittest_pb2.TestAllExtensions._extensions_by_name)
1188    self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
1189    # Make sure extensions haven't been registered into types that shouldn't
1190    # have any.
1191    self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
1192
1193  # If message A directly contains message B, and
1194  # a.HasField('b') is currently False, then mutating any
1195  # extension in B should change a.HasField('b') to True
1196  # (and so on up the object tree).
1197  def testHasBitsForAncestorsOfExtendedMessage(self):
1198    # Optional scalar extension.
1199    toplevel = more_extensions_pb2.TopLevelMessage()
1200    self.assertTrue(not toplevel.HasField('submessage'))
1201    self.assertEqual(0, toplevel.submessage.Extensions[
1202        more_extensions_pb2.optional_int_extension])
1203    self.assertTrue(not toplevel.HasField('submessage'))
1204    toplevel.submessage.Extensions[
1205        more_extensions_pb2.optional_int_extension] = 23
1206    self.assertEqual(23, toplevel.submessage.Extensions[
1207        more_extensions_pb2.optional_int_extension])
1208    self.assertTrue(toplevel.HasField('submessage'))
1209
1210    # Repeated scalar extension.
1211    toplevel = more_extensions_pb2.TopLevelMessage()
1212    self.assertTrue(not toplevel.HasField('submessage'))
1213    self.assertEqual([], toplevel.submessage.Extensions[
1214        more_extensions_pb2.repeated_int_extension])
1215    self.assertTrue(not toplevel.HasField('submessage'))
1216    toplevel.submessage.Extensions[
1217        more_extensions_pb2.repeated_int_extension].append(23)
1218    self.assertEqual([23], toplevel.submessage.Extensions[
1219        more_extensions_pb2.repeated_int_extension])
1220    self.assertTrue(toplevel.HasField('submessage'))
1221
1222    # Optional message extension.
1223    toplevel = more_extensions_pb2.TopLevelMessage()
1224    self.assertTrue(not toplevel.HasField('submessage'))
1225    self.assertEqual(0, toplevel.submessage.Extensions[
1226        more_extensions_pb2.optional_message_extension].foreign_message_int)
1227    self.assertTrue(not toplevel.HasField('submessage'))
1228    toplevel.submessage.Extensions[
1229        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
1230    self.assertEqual(23, toplevel.submessage.Extensions[
1231        more_extensions_pb2.optional_message_extension].foreign_message_int)
1232    self.assertTrue(toplevel.HasField('submessage'))
1233
1234    # Repeated message extension.
1235    toplevel = more_extensions_pb2.TopLevelMessage()
1236    self.assertTrue(not toplevel.HasField('submessage'))
1237    self.assertEqual(0, len(toplevel.submessage.Extensions[
1238        more_extensions_pb2.repeated_message_extension]))
1239    self.assertTrue(not toplevel.HasField('submessage'))
1240    foreign = toplevel.submessage.Extensions[
1241        more_extensions_pb2.repeated_message_extension].add()
1242    self.assertEqual(foreign, toplevel.submessage.Extensions[
1243        more_extensions_pb2.repeated_message_extension][0])
1244    self.assertTrue(toplevel.HasField('submessage'))
1245
1246  def testDisconnectionAfterClearingEmptyMessage(self):
1247    toplevel = more_extensions_pb2.TopLevelMessage()
1248    extendee_proto = toplevel.submessage
1249    extension = more_extensions_pb2.optional_message_extension
1250    extension_proto = extendee_proto.Extensions[extension]
1251    extendee_proto.ClearExtension(extension)
1252    extension_proto.foreign_message_int = 23
1253
1254    self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
1255
1256  def testExtensionFailureModes(self):
1257    extendee_proto = unittest_pb2.TestAllExtensions()
1258
1259    # Try non-extension-handle arguments to HasExtension,
1260    # ClearExtension(), and Extensions[]...
1261    self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
1262    self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
1263    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
1264    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
1265
1266    # Try something that *is* an extension handle, just not for
1267    # this message...
1268    for unknown_handle in (more_extensions_pb2.optional_int_extension,
1269                           more_extensions_pb2.optional_message_extension,
1270                           more_extensions_pb2.repeated_int_extension,
1271                           more_extensions_pb2.repeated_message_extension):
1272      self.assertRaises(KeyError, extendee_proto.HasExtension,
1273                        unknown_handle)
1274      self.assertRaises(KeyError, extendee_proto.ClearExtension,
1275                        unknown_handle)
1276      self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
1277                        unknown_handle)
1278      self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
1279                        unknown_handle, 5)
1280
1281    # Try call HasExtension() with a valid handle, but for a
1282    # *repeated* field.  (Just as with non-extension repeated
1283    # fields, Has*() isn't supported for extension repeated fields).
1284    self.assertRaises(KeyError, extendee_proto.HasExtension,
1285                      unittest_pb2.repeated_string_extension)
1286
1287  def testStaticParseFrom(self):
1288    proto1 = unittest_pb2.TestAllTypes()
1289    test_util.SetAllFields(proto1)
1290
1291    string1 = proto1.SerializeToString()
1292    proto2 = unittest_pb2.TestAllTypes.FromString(string1)
1293
1294    # Messages should be equal.
1295    self.assertEqual(proto2, proto1)
1296
1297  def testMergeFromSingularField(self):
1298    # Test merge with just a singular field.
1299    proto1 = unittest_pb2.TestAllTypes()
1300    proto1.optional_int32 = 1
1301
1302    proto2 = unittest_pb2.TestAllTypes()
1303    # This shouldn't get overwritten.
1304    proto2.optional_string = 'value'
1305
1306    proto2.MergeFrom(proto1)
1307    self.assertEqual(1, proto2.optional_int32)
1308    self.assertEqual('value', proto2.optional_string)
1309
1310  def testMergeFromRepeatedField(self):
1311    # Test merge with just a repeated field.
1312    proto1 = unittest_pb2.TestAllTypes()
1313    proto1.repeated_int32.append(1)
1314    proto1.repeated_int32.append(2)
1315
1316    proto2 = unittest_pb2.TestAllTypes()
1317    proto2.repeated_int32.append(0)
1318    proto2.MergeFrom(proto1)
1319
1320    self.assertEqual(0, proto2.repeated_int32[0])
1321    self.assertEqual(1, proto2.repeated_int32[1])
1322    self.assertEqual(2, proto2.repeated_int32[2])
1323
1324  def testMergeFromOptionalGroup(self):
1325    # Test merge with an optional group.
1326    proto1 = unittest_pb2.TestAllTypes()
1327    proto1.optionalgroup.a = 12
1328    proto2 = unittest_pb2.TestAllTypes()
1329    proto2.MergeFrom(proto1)
1330    self.assertEqual(12, proto2.optionalgroup.a)
1331
1332  def testMergeFromRepeatedNestedMessage(self):
1333    # Test merge with a repeated nested message.
1334    proto1 = unittest_pb2.TestAllTypes()
1335    m = proto1.repeated_nested_message.add()
1336    m.bb = 123
1337    m = proto1.repeated_nested_message.add()
1338    m.bb = 321
1339
1340    proto2 = unittest_pb2.TestAllTypes()
1341    m = proto2.repeated_nested_message.add()
1342    m.bb = 999
1343    proto2.MergeFrom(proto1)
1344    self.assertEqual(999, proto2.repeated_nested_message[0].bb)
1345    self.assertEqual(123, proto2.repeated_nested_message[1].bb)
1346    self.assertEqual(321, proto2.repeated_nested_message[2].bb)
1347
1348    proto3 = unittest_pb2.TestAllTypes()
1349    proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message)
1350    self.assertEqual(999, proto3.repeated_nested_message[0].bb)
1351    self.assertEqual(123, proto3.repeated_nested_message[1].bb)
1352    self.assertEqual(321, proto3.repeated_nested_message[2].bb)
1353
1354  def testMergeFromAllFields(self):
1355    # With all fields set.
1356    proto1 = unittest_pb2.TestAllTypes()
1357    test_util.SetAllFields(proto1)
1358    proto2 = unittest_pb2.TestAllTypes()
1359    proto2.MergeFrom(proto1)
1360
1361    # Messages should be equal.
1362    self.assertEqual(proto2, proto1)
1363
1364    # Serialized string should be equal too.
1365    string1 = proto1.SerializeToString()
1366    string2 = proto2.SerializeToString()
1367    self.assertEqual(string1, string2)
1368
1369  def testMergeFromExtensionsSingular(self):
1370    proto1 = unittest_pb2.TestAllExtensions()
1371    proto1.Extensions[unittest_pb2.optional_int32_extension] = 1
1372
1373    proto2 = unittest_pb2.TestAllExtensions()
1374    proto2.MergeFrom(proto1)
1375    self.assertEqual(
1376        1, proto2.Extensions[unittest_pb2.optional_int32_extension])
1377
1378  def testMergeFromExtensionsRepeated(self):
1379    proto1 = unittest_pb2.TestAllExtensions()
1380    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
1381    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)
1382
1383    proto2 = unittest_pb2.TestAllExtensions()
1384    proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
1385    proto2.MergeFrom(proto1)
1386    self.assertEqual(
1387        3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
1388    self.assertEqual(
1389        0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
1390    self.assertEqual(
1391        1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
1392    self.assertEqual(
1393        2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])
1394
1395  def testMergeFromExtensionsNestedMessage(self):
1396    proto1 = unittest_pb2.TestAllExtensions()
1397    ext1 = proto1.Extensions[
1398        unittest_pb2.repeated_nested_message_extension]
1399    m = ext1.add()
1400    m.bb = 222
1401    m = ext1.add()
1402    m.bb = 333
1403
1404    proto2 = unittest_pb2.TestAllExtensions()
1405    ext2 = proto2.Extensions[
1406        unittest_pb2.repeated_nested_message_extension]
1407    m = ext2.add()
1408    m.bb = 111
1409
1410    proto2.MergeFrom(proto1)
1411    ext2 = proto2.Extensions[
1412        unittest_pb2.repeated_nested_message_extension]
1413    self.assertEqual(3, len(ext2))
1414    self.assertEqual(111, ext2[0].bb)
1415    self.assertEqual(222, ext2[1].bb)
1416    self.assertEqual(333, ext2[2].bb)
1417
  def testMergeFromBug(self):
    """Merging a touched-but-unset submessage must not mark it as present."""
    message1 = unittest_pb2.TestAllTypes()
    message2 = unittest_pb2.TestAllTypes()

    # Cause optional_nested_message to be instantiated within message1, even
    # though it is not considered to be "present".  The bare attribute access
    # below is intentional: only the read is needed.
    message1.optional_nested_message
    self.assertFalse(message1.HasField('optional_nested_message'))

    # Merge into message2.  This should not instantiate the field in message2.
    message2.MergeFrom(message1)
    self.assertFalse(message2.HasField('optional_nested_message'))
1430
1431  def testCopyFromSingularField(self):
1432    # Test copy with just a singular field.
1433    proto1 = unittest_pb2.TestAllTypes()
1434    proto1.optional_int32 = 1
1435    proto1.optional_string = 'important-text'
1436
1437    proto2 = unittest_pb2.TestAllTypes()
1438    proto2.optional_string = 'value'
1439
1440    proto2.CopyFrom(proto1)
1441    self.assertEqual(1, proto2.optional_int32)
1442    self.assertEqual('important-text', proto2.optional_string)
1443
1444  def testCopyFromRepeatedField(self):
1445    # Test copy with a repeated field.
1446    proto1 = unittest_pb2.TestAllTypes()
1447    proto1.repeated_int32.append(1)
1448    proto1.repeated_int32.append(2)
1449
1450    proto2 = unittest_pb2.TestAllTypes()
1451    proto2.repeated_int32.append(0)
1452    proto2.CopyFrom(proto1)
1453
1454    self.assertEqual(1, proto2.repeated_int32[0])
1455    self.assertEqual(2, proto2.repeated_int32[1])
1456
1457  def testCopyFromAllFields(self):
1458    # With all fields set.
1459    proto1 = unittest_pb2.TestAllTypes()
1460    test_util.SetAllFields(proto1)
1461    proto2 = unittest_pb2.TestAllTypes()
1462    proto2.CopyFrom(proto1)
1463
1464    # Messages should be equal.
1465    self.assertEqual(proto2, proto1)
1466
1467    # Serialized string should be equal too.
1468    string1 = proto1.SerializeToString()
1469    string2 = proto2.SerializeToString()
1470    self.assertEqual(string1, string2)
1471
1472  def testCopyFromSelf(self):
1473    proto1 = unittest_pb2.TestAllTypes()
1474    proto1.repeated_int32.append(1)
1475    proto1.optional_int32 = 2
1476    proto1.optional_string = 'important-text'
1477
1478    proto1.CopyFrom(proto1)
1479    self.assertEqual(1, proto1.repeated_int32[0])
1480    self.assertEqual(2, proto1.optional_int32)
1481    self.assertEqual('important-text', proto1.optional_string)
1482
1483  def testCopyFromBadType(self):
1484    # The python implementation doesn't raise an exception in this
1485    # case. In theory it should.
1486    if api_implementation.Type() == 'python':
1487      return
1488    proto1 = unittest_pb2.TestAllTypes()
1489    proto2 = unittest_pb2.TestAllExtensions()
1490    self.assertRaises(TypeError, proto1.CopyFrom, proto2)
1491
1492  def testDeepCopy(self):
1493    proto1 = unittest_pb2.TestAllTypes()
1494    proto1.optional_int32 = 1
1495    proto2 = copy.deepcopy(proto1)
1496    self.assertEqual(1, proto2.optional_int32)
1497
1498    proto1.repeated_int32.append(2)
1499    proto1.repeated_int32.append(3)
1500    container = copy.deepcopy(proto1.repeated_int32)
1501    self.assertEqual([2, 3], container)
1502
1503    # TODO(anuraag): Implement deepcopy for repeated composite / extension dict
1504
1505  def testClear(self):
1506    proto = unittest_pb2.TestAllTypes()
1507    # C++ implementation does not support lazy fields right now so leave it
1508    # out for now.
1509    if api_implementation.Type() == 'python':
1510      test_util.SetAllFields(proto)
1511    else:
1512      test_util.SetAllNonLazyFields(proto)
1513    # Clear the message.
1514    proto.Clear()
1515    self.assertEqual(proto.ByteSize(), 0)
1516    empty_proto = unittest_pb2.TestAllTypes()
1517    self.assertEqual(proto, empty_proto)
1518
1519    # Test if extensions which were set are cleared.
1520    proto = unittest_pb2.TestAllExtensions()
1521    test_util.SetAllExtensions(proto)
1522    # Clear the message.
1523    proto.Clear()
1524    self.assertEqual(proto.ByteSize(), 0)
1525    empty_proto = unittest_pb2.TestAllExtensions()
1526    self.assertEqual(proto, empty_proto)
1527
1528  def testDisconnectingBeforeClear(self):
1529    proto = unittest_pb2.TestAllTypes()
1530    nested = proto.optional_nested_message
1531    proto.Clear()
1532    self.assertTrue(nested is not proto.optional_nested_message)
1533    nested.bb = 23
1534    self.assertTrue(not proto.HasField('optional_nested_message'))
1535    self.assertEqual(0, proto.optional_nested_message.bb)
1536
1537    proto = unittest_pb2.TestAllTypes()
1538    nested = proto.optional_nested_message
1539    nested.bb = 5
1540    foreign = proto.optional_foreign_message
1541    foreign.c = 6
1542
1543    proto.Clear()
1544    self.assertTrue(nested is not proto.optional_nested_message)
1545    self.assertTrue(foreign is not proto.optional_foreign_message)
1546    self.assertEqual(5, nested.bb)
1547    self.assertEqual(6, foreign.c)
1548    nested.bb = 15
1549    foreign.c = 16
1550    self.assertFalse(proto.HasField('optional_nested_message'))
1551    self.assertEqual(0, proto.optional_nested_message.bb)
1552    self.assertFalse(proto.HasField('optional_foreign_message'))
1553    self.assertEqual(0, proto.optional_foreign_message.c)
1554
1555  def testOneOf(self):
1556    proto = unittest_pb2.TestAllTypes()
1557    proto.oneof_uint32 = 10
1558    proto.oneof_nested_message.bb = 11
1559    self.assertEqual(11, proto.oneof_nested_message.bb)
1560    self.assertFalse(proto.HasField('oneof_uint32'))
1561    nested = proto.oneof_nested_message
1562    proto.oneof_string = 'abc'
1563    self.assertEqual('abc', proto.oneof_string)
1564    self.assertEqual(11, nested.bb)
1565    self.assertFalse(proto.HasField('oneof_nested_message'))
1566
1567  def assertInitialized(self, proto):
1568    self.assertTrue(proto.IsInitialized())
1569    # Neither method should raise an exception.
1570    proto.SerializeToString()
1571    proto.SerializePartialToString()
1572
1573  def assertNotInitialized(self, proto):
1574    self.assertFalse(proto.IsInitialized())
1575    self.assertRaises(message.EncodeError, proto.SerializeToString)
1576    # "Partial" serialization doesn't care if message is uninitialized.
1577    proto.SerializePartialToString()
1578
  def testIsInitialized(self):
    """Checks IsInitialized() across required fields, submessages, extensions.

    Each section mutates a message step by step and uses the
    assertInitialized / assertNotInitialized helpers after every step, so the
    statement order within a section matters.
    """
    # Trivial cases - all optional fields and extensions.
    proto = unittest_pb2.TestAllTypes()
    self.assertInitialized(proto)
    proto = unittest_pb2.TestAllExtensions()
    self.assertInitialized(proto)

    # The case of uninitialized required fields.
    proto = unittest_pb2.TestRequired()
    self.assertNotInitialized(proto)
    proto.a = proto.b = proto.c = 2
    self.assertInitialized(proto)

    # The case of uninitialized submessage.
    proto = unittest_pb2.TestRequiredForeign()
    self.assertInitialized(proto)
    proto.optional_message.a = 1
    self.assertNotInitialized(proto)
    proto.optional_message.b = 0
    proto.optional_message.c = 0
    self.assertInitialized(proto)

    # Uninitialized repeated submessage.  (Continues with the same proto as the
    # previous section.)
    message1 = proto.repeated_message.add()
    self.assertNotInitialized(proto)
    message1.a = message1.b = message1.c = 0
    self.assertInitialized(proto)

    # Uninitialized repeated group in an extension.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.multi
    message1 = proto.Extensions[extension].add()
    message2 = proto.Extensions[extension].add()
    self.assertNotInitialized(proto)
    message1.a = 1
    message1.b = 1
    message1.c = 1
    # Still uninitialized: message2 has not been filled in yet.
    self.assertNotInitialized(proto)
    message2.a = 2
    message2.b = 2
    message2.c = 2
    self.assertInitialized(proto)

    # Uninitialized nonrepeated message in an extension.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.single
    proto.Extensions[extension].a = 1
    self.assertNotInitialized(proto)
    proto.Extensions[extension].b = 2
    proto.Extensions[extension].c = 3
    self.assertInitialized(proto)

    # Try passing an errors list.
    errors = []
    proto = unittest_pb2.TestRequired()
    self.assertFalse(proto.IsInitialized(errors))
    self.assertEqual(errors, ['a', 'b', 'c'])
1636
1637  @unittest.skipIf(
1638      api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
1639      'Errors are only available from the most recent C++ implementation.')
1640  def testFileDescriptorErrors(self):
1641    file_name = 'test_file_descriptor_errors.proto'
1642    package_name = 'test_file_descriptor_errors.proto'
1643    file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
1644    file_descriptor_proto.name = file_name
1645    file_descriptor_proto.package = package_name
1646    m1 = file_descriptor_proto.message_type.add()
1647    m1.name = 'msg1'
1648    # Compiles the proto into the C++ descriptor pool
1649    descriptor.FileDescriptor(
1650        file_name,
1651        package_name,
1652        serialized_pb=file_descriptor_proto.SerializeToString())
1653    # Add a FileDescriptorProto that has duplicate symbols
1654    another_file_name = 'another_test_file_descriptor_errors.proto'
1655    file_descriptor_proto.name = another_file_name
1656    m2 = file_descriptor_proto.message_type.add()
1657    m2.name = 'msg2'
1658    with self.assertRaises(TypeError) as cm:
1659      descriptor.FileDescriptor(
1660          another_file_name,
1661          package_name,
1662          serialized_pb=file_descriptor_proto.SerializeToString())
1663      self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
1664                      getattr(cm.expected, '__name__', cm.expected))
1665      self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
1666      # Error message will say something about this definition being a
1667      # duplicate, though we don't check the message exactly to avoid a
1668      # dependency on the C++ logging code.
1669      self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
1670
  def testStringUTF8Encoding(self):
    """Checks which str/bytes/unicode values string and bytes fields accept."""
    proto = unittest_pb2.TestAllTypes()

    # Assignment of a unicode object to a field of type 'bytes' is not allowed.
    self.assertRaises(TypeError,
                      setattr, proto, 'optional_bytes', u'unicode object')

    # Check that the default value is of python's 'unicode' type.
    self.assertEqual(type(proto.optional_string), six.text_type)

    proto.optional_string = six.text_type('Testing')
    self.assertEqual(proto.optional_string, str('Testing'))

    # Assign a value of type 'str' which can be encoded in UTF-8.
    proto.optional_string = str('Testing')
    self.assertEqual(proto.optional_string, six.text_type('Testing'))

    # Try to assign a 'bytes' object which contains non-UTF-8.
    self.assertRaises(ValueError,
                      setattr, proto, 'optional_string', b'a\x80a')
    # No exception: Assign already encoded UTF-8 bytes to a string field.
    utf8_bytes = u'Тест'.encode('utf-8')
    proto.optional_string = utf8_bytes
    # No exception: Assign a non-ASCII unicode object.
    proto.optional_string = u'Тест'
    # No exception thrown (normal str assignment containing ASCII).
    proto.optional_string = 'abc'
1698
  def testStringUTF8Serialization(self):
    """Round-trips a non-ASCII string through the MessageSet wire format.

    Also feeds the parser bytes that are not valid UTF-8 and accepts either
    outcome: a UnicodeDecodeError, or a field value of type bytes.
    """
    proto = message_set_extensions_pb2.TestMessageSet()
    extension_message = message_set_extensions_pb2.TestMessageSetExtension2
    extension = extension_message.message_set_extension

    test_utf8 = u'Тест'
    test_utf8_bytes = test_utf8.encode('utf-8')

    # 'Test' in another language, using UTF-8 charset.
    proto.Extensions[extension].str = test_utf8

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    # Check byte size.
    self.assertEqual(proto.ByteSize(), len(serialized))

    # Re-parse the bytes as a RawMessageSet to inspect the individual items.
    raw = unittest_mset_pb2.RawMessageSet()
    bytes_read = raw.MergeFromString(serialized)
    self.assertEqual(len(serialized), bytes_read)

    message2 = message_set_extensions_pb2.TestMessageSetExtension2()

    self.assertEqual(1, len(raw.item))
    # Check that the type_id is the same as the tag ID in the .proto file.
    self.assertEqual(raw.item[0].type_id, 98418634)

    # Check the actual bytes on the wire.
    self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes))
    bytes_read = message2.MergeFromString(raw.item[0].message)
    self.assertEqual(len(raw.item[0].message), bytes_read)

    self.assertEqual(type(message2.str), six.text_type)
    self.assertEqual(message2.str, test_utf8)

    # The pure Python API throws an exception on MergeFromString(),
    # if any of the string fields of the message can't be UTF-8 decoded.
    # The C++ implementation of the API has no way to check that on
    # MergeFromString and thus has no way to throw the exception.
    #
    # The pure Python API always returns objects of type 'unicode' (UTF-8
    # encoded), or 'bytes' (in 7 bit ASCII).
    # Replace the valid UTF-8 payload with the same number of 0xff bytes so
    # the overall message length is unchanged.
    badbytes = raw.item[0].message.replace(
        test_utf8_bytes, len(test_utf8_bytes) * b'\xff')

    unicode_decode_failed = False
    try:
      message2.MergeFromString(badbytes)
    except UnicodeDecodeError:
      unicode_decode_failed = True
    string_field = message2.str
    # Either the parse raised, or the field came back as raw bytes.
    self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
1752
1753  def testBytesInTextFormat(self):
1754    proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
1755    self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n',
1756                     six.text_type(proto))
1757
1758  def testEmptyNestedMessage(self):
1759    proto = unittest_pb2.TestAllTypes()
1760    proto.optional_nested_message.MergeFrom(
1761        unittest_pb2.TestAllTypes.NestedMessage())
1762    self.assertTrue(proto.HasField('optional_nested_message'))
1763
1764    proto = unittest_pb2.TestAllTypes()
1765    proto.optional_nested_message.CopyFrom(
1766        unittest_pb2.TestAllTypes.NestedMessage())
1767    self.assertTrue(proto.HasField('optional_nested_message'))
1768
1769    proto = unittest_pb2.TestAllTypes()
1770    bytes_read = proto.optional_nested_message.MergeFromString(b'')
1771    self.assertEqual(0, bytes_read)
1772    self.assertTrue(proto.HasField('optional_nested_message'))
1773
1774    proto = unittest_pb2.TestAllTypes()
1775    proto.optional_nested_message.ParseFromString(b'')
1776    self.assertTrue(proto.HasField('optional_nested_message'))
1777
1778    serialized = proto.SerializeToString()
1779    proto2 = unittest_pb2.TestAllTypes()
1780    self.assertEqual(
1781        len(serialized),
1782        proto2.MergeFromString(serialized))
1783    self.assertTrue(proto2.HasField('optional_nested_message'))
1784
1785  def testSetInParent(self):
1786    proto = unittest_pb2.TestAllTypes()
1787    self.assertFalse(proto.HasField('optionalgroup'))
1788    proto.optionalgroup.SetInParent()
1789    self.assertTrue(proto.HasField('optionalgroup'))
1790
  def testPackageInitializationImport(self):
    """Test that we can import nested messages from their __init__.py.

    Such a setup is not trivial since, at the time __init__.py is processed,
    one can't refer to its submodules by name in code, so expressions like
    google.protobuf.internal.import_test_package.inner_pb2
    don't work. They do work in imports, so we have to assign an alias at
    import and then use that alias in generated code.
    """
    # We import here since it's the import that used to fail, and we want
    # the failure to have the right context.
    # pylint: disable=g-import-not-at-top
    from google.protobuf.internal import import_test_package
    # pylint: enable=g-import-not-at-top
    msg = import_test_package.myproto.Outer()
    # Just check the default value.
    self.assertEqual(57, msg.inner.value)
1808
1809#  Since we had so many tests for protocol buffer equality, we broke these out
1810#  into separate TestCase classes.
1811
1812
class TestAllTypesEqualityTest(unittest.TestCase):
  """Equality and hashing checks on freshly-created TestAllTypes messages."""

  def setUp(self):
    """Creates two independent, empty TestAllTypes messages."""
    self.first_proto, self.second_proto = (
        unittest_pb2.TestAllTypes(), unittest_pb2.TestAllTypes())

  def testNotHashable(self):
    """hash() on a message raises TypeError."""
    with self.assertRaises(TypeError):
      hash(self.first_proto)

  def testSelfEquality(self):
    """A message compares equal to itself."""
    same = self.first_proto
    self.assertEqual(same, same)

  def testEmptyProtosEqual(self):
    """Two empty messages of the same type compare equal."""
    left, right = self.first_proto, self.second_proto
    self.assertEqual(left, right)
1827
1828
class FullProtosEqualityTest(unittest.TestCase):

  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    """Builds two TestAllTypes messages, both filled by SetAllFields()."""
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.first_proto)
    test_util.SetAllFields(self.second_proto)

  def testNotHashable(self):
    """hash() on a populated message raises TypeError."""
    self.assertRaises(TypeError, hash, self.first_proto)

  def testNoneNotEqual(self):
    """A message is not equal to None, with None on either side."""
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    """A message is not equal to a message of a different type."""
    third_proto = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, third_proto)
    self.assertNotEqual(third_proto, self.second_proto)

  def testAllFieldsFilledEquality(self):
    """Two identically-populated messages compare equal."""
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    """Changing or clearing a singular scalar field breaks equality."""
    # Nonrepeated scalar field change should cause inequality.
    self.first_proto.optional_int32 += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    # ...as should clearing a field.
    self.first_proto.ClearField('optional_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedComposite(self):
    """Changes inside, or removal of, a singular submessage affect equality."""
    # Change a nonrepeated composite field.
    self.first_proto.optional_nested_message.bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Clear a field in the nested message.
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb = (
        self.second_proto.optional_nested_message.bb)
    self.assertEqual(self.first_proto, self.second_proto)
    # Remove the nested message entirely.
    self.first_proto.ClearField('optional_nested_message')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedScalar(self):
    """Appending to, or clearing, a repeated scalar field breaks equality."""
    # Change a repeated scalar field.
    self.first_proto.repeated_int32.append(5)
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.ClearField('repeated_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedComposite(self):
    """Element values and element count of a repeated submessage both matter."""
    # Change value within a repeated composite field.
    self.first_proto.repeated_nested_message[0].bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.repeated_nested_message[0].bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Add a value to a repeated composite field.
    self.first_proto.repeated_nested_message.add()
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.second_proto.repeated_nested_message.add()
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalarHasBits(self):
    """A cleared scalar differs from one explicitly set to 0."""
    # Ensure that we test "has" bits as well as value for
    # nonrepeated scalar field.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    """A cleared submessage differs from one whose only field was cleared."""
    # Ensure that we test "has" bits as well as value for
    # nonrepeated composite field.
    self.first_proto.ClearField('optional_nested_message')
    self.second_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    # Setting and then clearing bb on first_proto is asserted to make the two
    # messages equal again -- presumably because the write marks the
    # submessage itself as present; confirm against the implementation.
    self.first_proto.optional_nested_message.bb = 0
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertEqual(self.first_proto, self.second_proto)
1913
1914
class ExtensionEqualityTest(unittest.TestCase):
  """Equality checks for messages whose state lives in extensions."""

  def testExtensionEquality(self):
    """Equality tracks extension values and presence, not cached defaults."""
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(first_proto, second_proto)
    test_util.SetAllExtensions(first_proto)
    self.assertNotEqual(first_proto, second_proto)
    test_util.SetAllExtensions(second_proto)
    self.assertEqual(first_proto, second_proto)

    # Ensure that we check value equality.
    first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
    self.assertEqual(first_proto, second_proto)

    # Ensure that we also look at "has" bits.
    first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
    second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertEqual(first_proto, second_proto)

    # Ensure that differences in cached values
    # don't matter if "has" bits are both false.
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    # Reading the extension on first_proto only must not affect equality.
    self.assertEqual(
        0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
    self.assertEqual(first_proto, second_proto)
1946
1947
class MutualRecursionEqualityTest(unittest.TestCase):
  """Equality checks on mutually recursive message types."""

  def testEqualityWithMutualRecursion(self):
    """Equality follows a value set several levels deep in the recursion."""
    left = unittest_pb2.TestMutualRecursionA()
    right = unittest_pb2.TestMutualRecursionA()
    self.assertEqual(left, right)

    # Set the deep field on one side only, then on the other.
    left.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(left, right)
    right.bb.a.bb.optional_int32 = 23
    self.assertEqual(left, right)
1958
1959
1960class ByteSizeTest(unittest.TestCase):
1961
1962  def setUp(self):
1963    self.proto = unittest_pb2.TestAllTypes()
1964    self.extended_proto = more_extensions_pb2.ExtendedMessage()
1965    self.packed_proto = unittest_pb2.TestPackedTypes()
1966    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()
1967
1968  def Size(self):
1969    return self.proto.ByteSize()
1970
1971  def testEmptyMessage(self):
1972    self.assertEqual(0, self.proto.ByteSize())
1973
1974  def testSizedOnKwargs(self):
1975    # Use a separate message to ensure testing right after creation.
1976    proto = unittest_pb2.TestAllTypes()
1977    self.assertEqual(0, proto.ByteSize())
1978    proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1)
1979    # One byte for the tag, one to encode varint 1.
1980    self.assertEqual(2, proto_kwargs.ByteSize())
1981
1982  def testVarints(self):
1983    def Test(i, expected_varint_size):
1984      self.proto.Clear()
1985      self.proto.optional_int64 = i
1986      # Add one to the varint size for the tag info
1987      # for tag 1.
1988      self.assertEqual(expected_varint_size + 1, self.Size())
1989    Test(0, 1)
1990    Test(1, 1)
1991    for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)):
1992      Test((1 << i) - 1, num_bytes)
1993    Test(-1, 10)
1994    Test(-2, 10)
1995    Test(-(1 << 63), 10)
1996
1997  def testStrings(self):
1998    self.proto.optional_string = ''
1999    # Need one byte for tag info (tag #14), and one byte for length.
2000    self.assertEqual(2, self.Size())
2001
2002    self.proto.optional_string = 'abc'
2003    # Need one byte for tag info (tag #14), and one byte for length.
2004    self.assertEqual(2 + len(self.proto.optional_string), self.Size())
2005
2006    self.proto.optional_string = 'x' * 128
2007    # Need one byte for tag info (tag #14), and TWO bytes for length.
2008    self.assertEqual(3 + len(self.proto.optional_string), self.Size())
2009
2010  def testOtherNumerics(self):
2011    self.proto.optional_fixed32 = 1234
2012    # One byte for tag and 4 bytes for fixed32.
2013    self.assertEqual(5, self.Size())
2014    self.proto = unittest_pb2.TestAllTypes()
2015
2016    self.proto.optional_fixed64 = 1234
2017    # One byte for tag and 8 bytes for fixed64.
2018    self.assertEqual(9, self.Size())
2019    self.proto = unittest_pb2.TestAllTypes()
2020
2021    self.proto.optional_float = 1.234
2022    # One byte for tag and 4 bytes for float.
2023    self.assertEqual(5, self.Size())
2024    self.proto = unittest_pb2.TestAllTypes()
2025
2026    self.proto.optional_double = 1.234
2027    # One byte for tag and 8 bytes for float.
2028    self.assertEqual(9, self.Size())
2029    self.proto = unittest_pb2.TestAllTypes()
2030
2031    self.proto.optional_sint32 = 64
2032    # One byte for tag and 2 bytes for zig-zag-encoded 64.
2033    self.assertEqual(3, self.Size())
2034    self.proto = unittest_pb2.TestAllTypes()
2035
2036  def testComposites(self):
2037    # 3 bytes.
2038    self.proto.optional_nested_message.bb = (1 << 14)
2039    # Plus one byte for bb tag.
2040    # Plus 1 byte for optional_nested_message serialized size.
2041    # Plus two bytes for optional_nested_message tag.
2042    self.assertEqual(3 + 1 + 1 + 2, self.Size())
2043
2044  def testGroups(self):
2045    # 4 bytes.
2046    self.proto.optionalgroup.a = (1 << 21)
2047    # Plus two bytes for |a| tag.
2048    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
2049    self.assertEqual(4 + 2 + 2*2, self.Size())
2050
2051  def testRepeatedScalars(self):
2052    self.proto.repeated_int32.append(10)  # 1 byte.
2053    self.proto.repeated_int32.append(128)  # 2 bytes.
2054    # Also need 2 bytes for each entry for tag.
2055    self.assertEqual(1 + 2 + 2*2, self.Size())
2056
2057  def testRepeatedScalarsExtend(self):
2058    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
2059    # Also need 2 bytes for each entry for tag.
2060    self.assertEqual(1 + 2 + 2*2, self.Size())
2061
2062  def testRepeatedScalarsRemove(self):
2063    self.proto.repeated_int32.append(10)  # 1 byte.
2064    self.proto.repeated_int32.append(128)  # 2 bytes.
2065    # Also need 2 bytes for each entry for tag.
2066    self.assertEqual(1 + 2 + 2*2, self.Size())
2067    self.proto.repeated_int32.remove(128)
2068    self.assertEqual(1 + 2, self.Size())
2069
2070  def testRepeatedComposites(self):
2071    # Empty message.  2 bytes tag plus 1 byte length.
2072    foreign_message_0 = self.proto.repeated_nested_message.add()
2073    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2074    foreign_message_1 = self.proto.repeated_nested_message.add()
2075    foreign_message_1.bb = 7
2076    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
2077
2078  def testRepeatedCompositesDelete(self):
2079    # Empty message.  2 bytes tag plus 1 byte length.
2080    foreign_message_0 = self.proto.repeated_nested_message.add()
2081    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2082    foreign_message_1 = self.proto.repeated_nested_message.add()
2083    foreign_message_1.bb = 9
2084    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
2085
2086    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2087    del self.proto.repeated_nested_message[0]
2088    self.assertEqual(2 + 1 + 1 + 1, self.Size())
2089
2090    # Now add a new message.
2091    foreign_message_2 = self.proto.repeated_nested_message.add()
2092    foreign_message_2.bb = 12
2093
2094    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2095    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2096    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())
2097
2098    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2099    del self.proto.repeated_nested_message[1]
2100    self.assertEqual(2 + 1 + 1 + 1, self.Size())
2101
2102    del self.proto.repeated_nested_message[0]
2103    self.assertEqual(0, self.Size())
2104
2105  def testRepeatedGroups(self):
2106    # 2-byte START_GROUP plus 2-byte END_GROUP.
2107    group_0 = self.proto.repeatedgroup.add()
2108    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
2109    # plus 2-byte END_GROUP.
2110    group_1 = self.proto.repeatedgroup.add()
2111    group_1.a =  7
2112    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())
2113
2114  def testExtensions(self):
2115    proto = unittest_pb2.TestAllExtensions()
2116    self.assertEqual(0, proto.ByteSize())
2117    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
2118    proto.Extensions[extension] = 23
2119    # 1 byte for tag, 1 byte for value.
2120    self.assertEqual(2, proto.ByteSize())
2121
def testCacheInvalidationForNonrepeatedScalar(self):
  # Each write to, and each clear of, a singular int32 must be visible in the
  # next ByteSize() call, for a regular field and for an extension alike.
  sizes_by_value = ((1, 2), (128, 3))

  # Regular (non-extension) field.
  plain = self.proto
  for value, expected_size in sizes_by_value:
    plain.optional_int32 = value
    self.assertEqual(expected_size, plain.ByteSize())
  plain.ClearField('optional_int32')
  self.assertEqual(0, plain.ByteSize())

  # Same sequence through an extension.
  int_extension = more_extensions_pb2.optional_int_extension
  extended = self.extended_proto
  for value, expected_size in sizes_by_value:
    extended.Extensions[int_extension] = value
    self.assertEqual(expected_size, extended.ByteSize())
  extended.ClearExtension(int_extension)
  self.assertEqual(0, extended.ByteSize())
2139
def testCacheInvalidationForRepeatedScalar(self):
  # Appending to, overwriting an element of, and clearing a repeated int32
  # must each be reflected by the next ByteSize() call.

  # Regular (non-extension) field.
  plain = self.proto
  plain_values = plain.repeated_int32
  plain_values.append(1)
  self.assertEqual(3, plain.ByteSize())
  plain_values.append(1)
  self.assertEqual(6, plain.ByteSize())
  plain_values[1] = 128
  self.assertEqual(7, plain.ByteSize())
  plain.ClearField('repeated_int32')
  self.assertEqual(0, plain.ByteSize())

  # Same sequence through an extension.
  extended = self.extended_proto
  int_list_extension = more_extensions_pb2.repeated_int_extension
  extension_values = extended.Extensions[int_list_extension]
  extension_values.append(1)
  self.assertEqual(2, extended.ByteSize())
  extension_values.append(1)
  self.assertEqual(4, extended.ByteSize())
  extension_values[1] = 128
  self.assertEqual(5, extended.ByteSize())
  extended.ClearExtension(int_list_extension)
  self.assertEqual(0, extended.ByteSize())
2162
def testCacheInvalidationForNonrepeatedMessage(self):
  # Changes made inside a singular submessage must invalidate the byte size
  # cached on its parent.

  # Regular (non-extension) field.
  parent = self.proto
  parent.optional_foreign_message.c = 1
  self.assertEqual(5, parent.ByteSize())
  parent.optional_foreign_message.c = 128
  self.assertEqual(6, parent.ByteSize())
  # Clearing |c| leaves the (now empty) submessage present in the parent.
  parent.optional_foreign_message.ClearField('c')
  self.assertEqual(3, parent.ByteSize())
  parent.ClearField('optional_foreign_message')
  self.assertEqual(0, parent.ByteSize())

  if api_implementation.Type() == 'python':
    # This is only possible in pure-Python implementation of the API.
    # A child fetched before the field is cleared must no longer affect the
    # parent's size once the field has been cleared.
    detached_child = parent.optional_foreign_message
    parent.ClearField('optional_foreign_message')
    detached_child.c = 128
    self.assertEqual(0, parent.ByteSize())

  # Same idea through a message-typed extension.
  extended = self.extended_proto
  message_extension = more_extensions_pb2.optional_message_extension
  extension_child = extended.Extensions[message_extension]
  # Merely reading the extension must not change the size.
  self.assertEqual(0, extended.ByteSize())
  extension_child.foreign_message_int = 1
  self.assertEqual(4, extended.ByteSize())
  extension_child.foreign_message_int = 128
  self.assertEqual(5, extended.ByteSize())
  extended.ClearExtension(message_extension)
  self.assertEqual(0, extended.ByteSize())
2191
def testCacheInvalidationForRepeatedMessage(self):
  # Adding elements to a repeated message field, and mutating an element
  # afterwards, must invalidate the byte size cached on the parent.

  # Regular (non-extension) field.
  parent = self.proto
  first_child = parent.repeated_foreign_message.add()
  self.assertEqual(3, parent.ByteSize())
  parent.repeated_foreign_message.add()
  self.assertEqual(6, parent.ByteSize())
  first_child.c = 1
  self.assertEqual(8, parent.ByteSize())
  parent.ClearField('repeated_foreign_message')
  self.assertEqual(0, parent.ByteSize())

  # Same sequence through a repeated message extension, additionally
  # clearing the element's field again.
  extended = self.extended_proto
  message_list_extension = more_extensions_pb2.repeated_message_extension
  children = extended.Extensions[message_list_extension]
  first_child = children.add()
  self.assertEqual(2, extended.ByteSize())
  children.add()
  self.assertEqual(4, extended.ByteSize())
  first_child.foreign_message_int = 1
  self.assertEqual(6, extended.ByteSize())
  first_child.ClearField('foreign_message_int')
  self.assertEqual(4, extended.ByteSize())
  extended.ClearExtension(message_list_extension)
  self.assertEqual(0, extended.ByteSize())
2216
def testPackedRepeatedScalars(self):
  # Byte size of packed repeated scalars: payload bytes plus, per field, a
  # tag and a length prefix.
  packed = self.packed_proto
  self.assertEqual(0, packed.ByteSize())

  # The tag is 2 bytes (the field number is 90), and the varint storing the
  # length is 1 byte.
  framing_size = 2 + 1

  packed.packed_int32.append(10)   # 1 byte.
  packed.packed_int32.append(128)  # 2 bytes.
  int32_total = (1 + 2) + framing_size
  self.assertEqual(int32_total, packed.ByteSize())

  packed.packed_double.append(4.2)   # 8 bytes.
  packed.packed_double.append(3.25)  # 8 bytes.
  # 2 more tag bytes, 1 more length byte.
  double_total = (8 + 8) + framing_size
  self.assertEqual(int32_total + double_total, packed.ByteSize())

  # Dropping the int32 field leaves only the doubles' contribution.
  packed.ClearField('packed_int32')
  self.assertEqual(double_total, packed.ByteSize())
2235
def testPackedExtensions(self):
  # Byte size of a packed repeated fixed32 extension holding four values.
  packed = self.packed_extended_proto
  self.assertEqual(0, packed.ByteSize())
  fixed32_values = packed.Extensions[unittest_pb2.packed_fixed32_extension]
  fixed32_values.extend([1, 2, 3, 4])   # 16 bytes
  payload_size = 16
  # Plus 3 bytes of framing overhead for the extension field.
  framing_size = 3
  self.assertEqual(payload_size + framing_size, packed.ByteSize())
2243
2244
2245# Issues to be sure to cover include:
2246#   * Handling of unrecognized tags ("uninterpreted_bytes").
2247#   * Handling of MessageSets.
2248#   * Consistent ordering of tags in the wire format,
2249#     including ordering between extensions and non-extension
2250#     fields.
2251#   * Consistent serialization of negative numbers, especially
2252#     negative int32s.
2253#   * Handling of empty submessages (with and without "has"
2254#     bits set).
2255
class SerializationTest(unittest.TestCase):
  """Serialization, parsing and constructor tests for generated messages."""

  def testSerializeEmtpyMessage(self):
    # An empty message serializes to ByteSize() bytes and parses back into an
    # equal message.
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllFields(self):
    # Round trip of a TestAllTypes message with every field set.
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllExtensions(self):
    # Round trip of a message with every extension set.
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeWithOptionalGroup(self):
    # Round trip of a message whose only content is an optional group.
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    first_proto.optionalgroup.a = 242
    serialized = first_proto.SerializeToString()
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeNegativeValues(self):
    # Negative values must survive a round trip for every signed integer
    # field type.
    first_proto = unittest_pb2.TestAllTypes()

    first_proto.optional_int32 = -1
    first_proto.optional_int64 = -(2 << 40)
    first_proto.optional_sint32 = -3
    first_proto.optional_sint64 = -(4 << 40)
    first_proto.optional_sfixed32 = -5
    first_proto.optional_sfixed64 = -(6 << 40)

    second_proto = unittest_pb2.TestAllTypes.FromString(
        first_proto.SerializeToString())

    self.assertEqual(first_proto, second_proto)

  def testParseTruncated(self):
    # For every possible truncation point, parsing into the real message type
    # and parsing into an empty message (all fields unknown) must agree on
    # whether the input is acceptable.

    # This test is only applicable for the Python implementation of the API.
    # Report it as skipped rather than letting it pass silently.
    if api_implementation.Type() != 'python':
      self.skipTest('Only applicable for the Python implementation of the '
                    'API.')

    first_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()

    for truncation_point in range(len(serialized) + 1):
      second_proto = unittest_pb2.TestAllTypes()
      unknown_fields = unittest_pb2.TestEmptyMessage()
      # Only the parse of the known-fields message is allowed to decide
      # which branch we take, so it is the only call inside this try.
      try:
        pos = second_proto._InternalParse(serialized, 0, truncation_point)
      except message.DecodeError:
        # Parsing unknown fields should also fail.
        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
                          serialized, 0, truncation_point)
        continue

      # If we didn't raise an error then we read exactly the amount expected.
      self.assertEqual(truncation_point, pos)

      # Parsing to unknown fields should not throw if parsing to known fields
      # did not.
      try:
        pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
      except message.DecodeError:
        self.fail('Parsing unknown fields failed when parsing known fields '
                  'did not.')
      self.assertEqual(truncation_point, pos2)

  def testCanonicalSerializationOrder(self):
    proto = more_messages_pb2.OutOfOrderFields()
    # These are also their tag numbers.  Even though we're setting these in
    # reverse-tag order AND they're listed in reverse tag-order in the .proto
    # file, they should nonetheless be serialized in tag order.
    proto.optional_sint32 = 5
    proto.Extensions[more_messages_pb2.optional_uint64] = 4
    proto.optional_uint32 = 3
    proto.Extensions[more_messages_pb2.optional_int64] = 2
    proto.optional_int32 = 1
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(2, d.ReadInt64())
    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(3, d.ReadUInt32())
    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(4, d.ReadUInt64())
    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(5, d.ReadSInt32())

  def testCanonicalSerializationOrderSameAsCpp(self):
    # Copy of the same test we use for C++.
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    serialized = proto.SerializeToString()
    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)

  def testMergeFromStringWhenFieldsAlreadySet(self):
    # Merging serialized data into a message that already has values set.
    first_proto = unittest_pb2.TestAllTypes()
    first_proto.repeated_string.append('foobar')
    first_proto.optional_int32 = 23
    first_proto.optional_nested_message.bb = 42
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestAllTypes()
    second_proto.repeated_string.append('baz')
    second_proto.optional_int32 = 100
    second_proto.optional_nested_message.bb = 999

    bytes_parsed = second_proto.MergeFromString(serialized)
    self.assertEqual(len(serialized), bytes_parsed)

    # Ensure that we append to repeated fields.
    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
    # Ensure that we overwrite nonrepeated scalars.
    self.assertEqual(23, second_proto.optional_int32)
    # Ensure that we recursively call MergeFromString() on
    # submessages.
    self.assertEqual(42, second_proto.optional_nested_message.bb)

  def testMessageSetWireFormat(self):
    # A MessageSet with three extensions must be readable both as a
    # RawMessageSet and as the original MessageSet type.
    proto = message_set_extensions_pb2.TestMessageSet()
    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
    extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
    extension1 = extension_message1.message_set_extension
    extension2 = extension_message2.message_set_extension
    extension3 = message_set_extensions_pb2.message_set_extension3
    proto.Extensions[extension1].i = 123
    proto.Extensions[extension2].str = 'foo'
    proto.Extensions[extension3].text = 'bar'

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    raw = unittest_mset_pb2.RawMessageSet()
    self.assertEqual(False,
                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
    self.assertEqual(
        len(serialized),
        raw.MergeFromString(serialized))
    self.assertEqual(3, len(raw.item))

    # Each raw item's payload is itself a serialized extension message.
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    self.assertEqual(
        len(raw.item[0].message),
        message1.MergeFromString(raw.item[0].message))
    self.assertEqual(123, message1.i)

    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
    self.assertEqual(
        len(raw.item[1].message),
        message2.MergeFromString(raw.item[1].message))
    self.assertEqual('foo', message2.str)

    message3 = message_set_extensions_pb2.TestMessageSetExtension3()
    self.assertEqual(
        len(raw.item[2].message),
        message3.MergeFromString(raw.item[2].message))
    self.assertEqual('bar', message3.text)

    # Deserialize using the MessageSet wire format.
    proto2 = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(123, proto2.Extensions[extension1].i)
    self.assertEqual('foo', proto2.Extensions[extension2].str)
    self.assertEqual('bar', proto2.Extensions[extension3].text)

    # Check byte size.
    self.assertEqual(proto2.ByteSize(), len(serialized))
    self.assertEqual(proto.ByteSize(), len(serialized))

  def testMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an item.
    item = raw.item.add()
    item.type_id = 98418603
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    # Add a second, unknown extension.
    item = raw.item.add()
    item.type_id = 98418604
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12346
    item.message = message1.SerializeToString()

    # Add another unknown extension.
    item = raw.item.add()
    item.type_id = 98418605
    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
    message2.str = 'foo'
    item.message = message2.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(
        len(serialized),
        proto.MergeFromString(serialized))

    # Check that the message parsed well.
    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
    extension1 = extension_message1.message_set_extension
    self.assertEqual(12345, proto.Extensions[extension1].i)

  def testUnknownFields(self):
    proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(proto)

    serialized = proto.SerializeToString()

    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()

    # Parsing this message should succeed.
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))

    # Now test with a int64 field set.
    proto = unittest_pb2.TestAllTypes()
    proto.optional_int64 = 0x0fffffffffffffff
    serialized = proto.SerializeToString()
    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()
    # Parsing this message should succeed.
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))

  def _CheckRaises(self, exc_class, callable_obj, exception):
    """Checks that callable_obj raises exc_class with the expected message.

    Args:
      exc_class: Exception type that callable_obj is expected to raise.
      callable_obj: Zero-argument callable to invoke.
      exception: Expected result of applying str() to the raised exception.

    Raises:
      self.failureException: If callable_obj returns without raising
          exc_class, or if the exception message differs from the expected
          one.
    """
    try:
      callable_obj()
    except exc_class as ex:
      # Check if the exception message is the right one.
      self.assertEqual(exception, str(ex))
    else:
      raise self.failureException('%s not raised' % str(exc_class))

  def testSerializeUninitialized(self):
    # SerializeToString() must name the missing required fields, while
    # SerializePartialToString() must always succeed.
    proto = unittest_pb2.TestRequired()
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: '
        'a,b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertFalse(proto2.HasField('a'))
    # proto2 ParseFromString does not check that required fields are set.
    proto2.ParseFromString(partial)
    self.assertFalse(proto2.HasField('a'))

    proto.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: b,c')
    # Shouldn't raise exceptions.
    proto.SerializePartialToString()

    proto.b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: c')
    # Shouldn't raise exceptions.
    proto.SerializePartialToString()

    proto.c = 3
    serialized = proto.SerializeToString()
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    # With all required fields set, both serializations parse to the same
    # values.
    proto2 = unittest_pb2.TestRequired()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)
    self.assertEqual(
        len(partial),
        proto2.MergeFromString(partial))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)

  def testSerializeUninitializedSubMessage(self):
    # Missing required fields inside submessages are reported with their
    # full path, including the index for repeated submessages.
    proto = unittest_pb2.TestRequiredForeign()

    # Sub-message doesn't exist yet, so this succeeds.
    proto.SerializeToString()

    proto.optional_message.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign '
        'is missing required fields: '
        'optional_message.b,optional_message.c')

    proto.optional_message.b = 2
    proto.optional_message.c = 3
    proto.SerializeToString()

    proto.repeated_message.add().a = 1
    proto.repeated_message.add().b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
        'repeated_message[0].b,repeated_message[0].c,'
        'repeated_message[1].a,repeated_message[1].c')

    proto.repeated_message[0].b = 2
    proto.repeated_message[0].c = 3
    proto.repeated_message[1].a = 1
    proto.repeated_message[1].c = 3
    proto.SerializeToString()

  def testSerializeAllPackedFields(self):
    # Round trip of a message with every packed field set.
    first_proto = unittest_pb2.TestPackedTypes()
    second_proto = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllPackedExtensions(self):
    # Round trip of a message with every packed extension set.
    first_proto = unittest_pb2.TestPackedExtensions()
    second_proto = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
    # Merged packed values are appended after the values already present;
    # fields absent from the input are left untouched.
    first_proto = unittest_pb2.TestPackedTypes()
    first_proto.packed_int32.extend([1, 2])
    first_proto.packed_double.append(3.0)
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestPackedTypes()
    second_proto.packed_int32.append(3)
    second_proto.packed_double.extend([1.0, 2.0])
    second_proto.packed_sint32.append(4)

    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual([3, 1, 2], second_proto.packed_int32)
    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
    self.assertEqual([4], second_proto.packed_sint32)

  def testPackedFieldsWireFormat(self):
    # Decodes the serialized bytes by hand: each packed field is one
    # length-delimited record, and records appear in field-number order.
    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
    proto.packed_float.append(2.0)             # 4 bytes, will be before double
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(1+1+1+2, d.ReadInt32())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual(2, d.ReadInt32())
    self.assertEqual(150, d.ReadInt32())
    self.assertEqual(3, d.ReadInt32())
    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(4, d.ReadInt32())
    self.assertEqual(2.0, d.ReadFloat())
    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(8+8, d.ReadInt32())
    self.assertEqual(1.0, d.ReadDouble())
    self.assertEqual(1000.0, d.ReadDouble())
    self.assertTrue(d.EndOfStream())

  def testParsePackedFromUnpacked(self):
    # Data written by the unpacked message type must parse into the packed
    # message type.
    unpacked = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(unpacked)
    packed = unittest_pb2.TestPackedTypes()
    serialized = unpacked.SerializeToString()
    self.assertEqual(
        len(serialized),
        packed.MergeFromString(serialized))
    expected = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(expected)
    self.assertEqual(expected, packed)

  def testParseUnpackedFromPacked(self):
    # Data written by the packed message type must parse into the unpacked
    # message type.
    packed = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(packed)
    unpacked = unittest_pb2.TestUnpackedTypes()
    serialized = packed.SerializeToString()
    self.assertEqual(
        len(serialized),
        unpacked.MergeFromString(serialized))
    expected = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(expected)
    self.assertEqual(expected, unpacked)

  def testFieldNumbers(self):
    # Generated classes expose a <NAME>_FIELD_NUMBER constant per field.
    self.assertEqual(1, unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER)
    self.assertEqual(1, unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER)
    self.assertEqual(16, unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER)
    self.assertEqual(
        18, unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER)
    self.assertEqual(
        21, unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER)
    self.assertEqual(31, unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER)
    self.assertEqual(46, unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER)
    self.assertEqual(
        48, unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER)
    self.assertEqual(
        51, unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER)

  def testExtensionFieldNumbers(self):
    # Each extension's descriptor number must match its generated
    # <NAME>_FIELD_NUMBER constant.
    self.assertEqual(1000, unittest_pb2.TestRequired.single.number)
    self.assertEqual(1000, unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER)
    self.assertEqual(1001, unittest_pb2.TestRequired.multi.number)
    self.assertEqual(1001, unittest_pb2.TestRequired.MULTI_FIELD_NUMBER)
    self.assertEqual(1, unittest_pb2.optional_int32_extension.number)
    self.assertEqual(1, unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER)
    self.assertEqual(16, unittest_pb2.optionalgroup_extension.number)
    self.assertEqual(16, unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER)
    self.assertEqual(18, unittest_pb2.optional_nested_message_extension.number)
    self.assertEqual(
        18, unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER)
    self.assertEqual(21, unittest_pb2.optional_nested_enum_extension.number)
    self.assertEqual(
        21, unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER)
    self.assertEqual(31, unittest_pb2.repeated_int32_extension.number)
    self.assertEqual(31, unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER)
    self.assertEqual(46, unittest_pb2.repeatedgroup_extension.number)
    self.assertEqual(46, unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER)
    self.assertEqual(48, unittest_pb2.repeated_nested_message_extension.number)
    self.assertEqual(
        48, unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER)
    self.assertEqual(51, unittest_pb2.repeated_nested_enum_extension.number)
    self.assertEqual(
        51, unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER)

  def testInitKwargs(self):
    # Fields passed as constructor keyword arguments are set on the message.
    proto = unittest_pb2.TestAllTypes(
        optional_int32=1,
        optional_string='foo',
        optional_bool=True,
        optional_bytes=b'bar',
        optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
        optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
        optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
        optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
        repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('optional_int32'))
    self.assertTrue(proto.HasField('optional_string'))
    self.assertTrue(proto.HasField('optional_bool'))
    self.assertTrue(proto.HasField('optional_bytes'))
    self.assertTrue(proto.HasField('optional_nested_message'))
    self.assertTrue(proto.HasField('optional_foreign_message'))
    self.assertTrue(proto.HasField('optional_nested_enum'))
    self.assertTrue(proto.HasField('optional_foreign_enum'))
    self.assertEqual(1, proto.optional_int32)
    self.assertEqual('foo', proto.optional_string)
    self.assertEqual(True, proto.optional_bool)
    self.assertEqual(b'bar', proto.optional_bytes)
    self.assertEqual(1, proto.optional_nested_message.bb)
    self.assertEqual(1, proto.optional_foreign_message.c)
    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
                     proto.optional_nested_enum)
    self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
    self.assertEqual([1, 2, 3], proto.repeated_int32)

  def testInitArgsUnknownFieldName(self):
    # An unknown keyword argument is rejected with a ValueError that names
    # both the message type and the offending field.
    def InitalizeEmptyMessageWithExtraKeywordArg():
      unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
    self._CheckRaises(
        ValueError,
        InitalizeEmptyMessageWithExtraKeywordArg,
        'Protocol message TestEmptyMessage has no "unknown" field.')

  def testInitRequiredKwargs(self):
    # Required fields can be supplied through constructor keyword arguments;
    # fields that were not passed stay unset.
    proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('a'))
    self.assertTrue(proto.HasField('b'))
    self.assertTrue(proto.HasField('c'))
    self.assertFalse(proto.HasField('dummy2'))
    self.assertEqual(1, proto.a)
    self.assertEqual(1, proto.b)
    self.assertEqual(1, proto.c)

  def testInitRequiredForeignKwargs(self):
    # A fully initialized submessage passed as a keyword argument makes the
    # parent initialized as well.
    proto = unittest_pb2.TestRequiredForeign(
        optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('optional_message'))
    self.assertTrue(proto.optional_message.IsInitialized())
    self.assertTrue(proto.optional_message.HasField('a'))
    self.assertTrue(proto.optional_message.HasField('b'))
    self.assertTrue(proto.optional_message.HasField('c'))
    self.assertFalse(proto.optional_message.HasField('dummy2'))
    self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
                     proto.optional_message)
    self.assertEqual(1, proto.optional_message.a)
    self.assertEqual(1, proto.optional_message.b)
    self.assertEqual(1, proto.optional_message.c)

  def testInitRepeatedKwargs(self):
    # A list passed for a repeated field populates it in order.
    proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())
    self.assertEqual(1, proto.repeated_int32[0])
    self.assertEqual(2, proto.repeated_int32[1])
    self.assertEqual(3, proto.repeated_int32[2])
2815
2816
class OptionsTest(unittest.TestCase):
  """Checks that message and field options are readable from descriptors."""

  def testMessageOptions(self):
    # message_set_wire_format is set on TestMessageSet and unset on an
    # ordinary message.
    message_set = message_set_extensions_pb2.TestMessageSet()
    message_set_options = message_set.DESCRIPTOR.GetOptions()
    self.assertEqual(True, message_set_options.message_set_wire_format)
    ordinary = unittest_pb2.TestAllTypes()
    ordinary_options = ordinary.DESCRIPTOR.GetOptions()
    self.assertEqual(False, ordinary_options.message_set_wire_format)

  def testPackedOptions(self):
    # Fields of TestAllTypes are not packed.
    unpacked = unittest_pb2.TestAllTypes()
    unpacked.optional_int32 = 1
    unpacked.optional_double = 3.0
    for field, unused_value in unpacked.ListFields():
      field_options = field.GetOptions()
      self.assertEqual(False, field_options.packed)

    # Fields of TestPackedTypes are packed, and they are repeated.
    packed = unittest_pb2.TestPackedTypes()
    packed.packed_int32.append(1)
    packed.packed_double.append(3.0)
    for field, unused_value in packed.ListFields():
      field_options = field.GetOptions()
      self.assertEqual(True, field_options.packed)
      self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED,
                       field.label)
2841
2842
2843
2844class ClassAPITest(unittest.TestCase):
2845
@unittest.skipIf(
    api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
    'C++ implementation requires a call to MakeDescriptor()')
def testMakeClassWithNestedDescriptor(self):
  # MakeClass() must expose each nested descriptor as an attribute of the
  # generated class, at every level of nesting.
  def make_descriptor(name, full_name, nested_types):
    # Every descriptor in this test is field-less; only the nesting differs.
    return descriptor.Descriptor(name, full_name, '',
                                 containing_type=None, fields=[],
                                 nested_types=nested_types, enum_types=[],
                                 extensions=[])

  leaf_desc = make_descriptor('leaf', 'package.parent.child.leaf', [])
  child_desc = make_descriptor('child', 'package.parent.child', [leaf_desc])
  sibling_desc = make_descriptor('sibling', 'package.parent.sibling', [])
  parent_desc = make_descriptor('parent', 'package.parent',
                                [child_desc, sibling_desc])

  generated_class = reflection.MakeClass(parent_desc)
  for nested_name in ('child', 'sibling'):
    self.assertIn(nested_name, generated_class.__dict__)
  self.assertIn('leaf', generated_class.child.__dict__)
2870
2871  def _GetSerializedFileDescriptor(self, name):
2872    """Get a serialized representation of a test FileDescriptorProto.
2873
2874    Args:
2875      name: All calls to this must use a unique message name, to avoid
2876          collisions in the cpp descriptor pool.
2877    Returns:
2878      A string containing the serialized form of a test FileDescriptorProto.
2879    """
2880    file_descriptor_str = (
2881        'message_type {'
2882        '  name: "' + name + '"'
2883        '  field {'
2884        '    name: "flat"'
2885        '    number: 1'
2886        '    label: LABEL_REPEATED'
2887        '    type: TYPE_UINT32'
2888        '  }'
2889        '  field {'
2890        '    name: "bar"'
2891        '    number: 2'
2892        '    label: LABEL_OPTIONAL'
2893        '    type: TYPE_MESSAGE'
2894        '    type_name: "Bar"'
2895        '  }'
2896        '  nested_type {'
2897        '    name: "Bar"'
2898        '    field {'
2899        '      name: "baz"'
2900        '      number: 3'
2901        '      label: LABEL_OPTIONAL'
2902        '      type: TYPE_MESSAGE'
2903        '      type_name: "Baz"'
2904        '    }'
2905        '    nested_type {'
2906        '      name: "Baz"'
2907        '      enum_type {'
2908        '        name: "deep_enum"'
2909        '        value {'
2910        '          name: "VALUE_A"'
2911        '          number: 0'
2912        '        }'
2913        '      }'
2914        '      field {'
2915        '        name: "deep"'
2916        '        number: 4'
2917        '        label: LABEL_OPTIONAL'
2918        '        type: TYPE_UINT32'
2919        '      }'
2920        '    }'
2921        '  }'
2922        '}')
2923    file_descriptor = descriptor_pb2.FileDescriptorProto()
2924    text_format.Merge(file_descriptor_str, file_descriptor)
2925    return file_descriptor.SerializeToString()
2926
2927  def testParsingFlatClassWithExplicitClassDeclaration(self):
2928    """Test that the generated class can parse a flat message."""
2929    # TODO(xiaofeng): This test fails with cpp implemetnation in the call
2930    # of six.with_metaclass(). The other two callsites of with_metaclass
2931    # in this file are both excluded from cpp test, so it might be expected
2932    # to fail. Need someone more familiar with the python code to take a
2933    # look at this.
2934    if api_implementation.Type() != 'python':
2935      return
2936    file_descriptor = descriptor_pb2.FileDescriptorProto()
2937    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A'))
2938    msg_descriptor = descriptor.MakeDescriptor(
2939        file_descriptor.message_type[0])
2940
2941    class MessageClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
2942      DESCRIPTOR = msg_descriptor
2943    msg = MessageClass()
2944    msg_str = (
2945        'flat: 0 '
2946        'flat: 1 '
2947        'flat: 2 ')
2948    text_format.Merge(msg_str, msg)
2949    self.assertEqual(msg.flat, [0, 1, 2])
2950
2951  def testParsingFlatClass(self):
2952    """Test that the generated class can parse a flat message."""
2953    file_descriptor = descriptor_pb2.FileDescriptorProto()
2954    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B'))
2955    msg_descriptor = descriptor.MakeDescriptor(
2956        file_descriptor.message_type[0])
2957    msg_class = reflection.MakeClass(msg_descriptor)
2958    msg = msg_class()
2959    msg_str = (
2960        'flat: 0 '
2961        'flat: 1 '
2962        'flat: 2 ')
2963    text_format.Merge(msg_str, msg)
2964    self.assertEqual(msg.flat, [0, 1, 2])
2965
2966  def testParsingNestedClass(self):
2967    """Test that the generated class can parse a nested message."""
2968    file_descriptor = descriptor_pb2.FileDescriptorProto()
2969    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
2970    msg_descriptor = descriptor.MakeDescriptor(
2971        file_descriptor.message_type[0])
2972    msg_class = reflection.MakeClass(msg_descriptor)
2973    msg = msg_class()
2974    msg_str = (
2975        'bar {'
2976        '  baz {'
2977        '    deep: 4'
2978        '  }'
2979        '}')
2980    text_format.Merge(msg_str, msg)
2981    self.assertEqual(msg.bar.baz.deep, 4)
2982
# Run every test case defined in this module when executed as a script.
if __name__ == '__main__':
  unittest.main()
2985