1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3#
4# Protocol Buffers - Google's data interchange format
5# Copyright 2008 Google Inc.  All rights reserved.
6# https://developers.google.com/protocol-buffers/
7#
8# Redistribution and use in source and binary forms, with or without
9# modification, are permitted provided that the following conditions are
10# met:
11#
12#     * Redistributions of source code must retain the above copyright
13# notice, this list of conditions and the following disclaimer.
14#     * Redistributions in binary form must reproduce the above
15# copyright notice, this list of conditions and the following disclaimer
16# in the documentation and/or other materials provided with the
17# distribution.
18#     * Neither the name of Google Inc. nor the names of its
19# contributors may be used to endorse or promote products derived from
20# this software without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
"""Unit tests for reflection.py.

These tests also indirectly exercise the output of the pure-Python protocol
compiler.
"""

__author__ = 'robinson@google.com (Will Robinson)'
39
40import copy
41import gc
42import operator
43import struct
44
45from google.apputils import basetest
46from google.protobuf import unittest_import_pb2
47from google.protobuf import unittest_mset_pb2
48from google.protobuf import unittest_pb2
49from google.protobuf import descriptor_pb2
50from google.protobuf import descriptor
51from google.protobuf import message
52from google.protobuf import reflection
53from google.protobuf import text_format
54from google.protobuf.internal import api_implementation
55from google.protobuf.internal import more_extensions_pb2
56from google.protobuf.internal import more_messages_pb2
57from google.protobuf.internal import wire_format
58from google.protobuf.internal import test_util
59from google.protobuf.internal import decoder
60
61
class _MiniDecoder(object):
  """Reads values one at a time from a serialized protobuf byte string.

  decoder.Decoder used to provide this, but it was removed when decoding was
  redesigned to be much faster.  A few tests in this file still walk a
  serialized message value by value to check its exact wire form, so this
  class re-implements just the readers those tests call.
  """

  def __init__(self, bytes):
    # The parameter name shadows the builtin inside this method only; it is
    # kept unchanged so existing callers are unaffected.
    self._bytes = bytes
    self._pos = 0  # Offset of the next unread byte.

  def _ReadFixed(self, fmt, size):
    """Unpacks one fixed-width value at the cursor and advances past it."""
    start = self._pos
    (value,) = struct.unpack(fmt, self._bytes[start:start + size])
    self._pos = start + size
    return value

  def ReadVarint(self):
    """Reads a varint with decoder._DecodeVarint and advances the cursor."""
    value, next_pos = decoder._DecodeVarint(self._bytes, self._pos)
    self._pos = next_pos
    return value

  # These integer types are all read as plain varints here.
  ReadInt32 = ReadInt64 = ReadUInt32 = ReadUInt64 = ReadVarint

  def ReadSInt64(self):
    """Reads a varint and undoes its ZigZag encoding."""
    return wire_format.ZigZagDecode(self.ReadVarint())

  ReadSInt32 = ReadSInt64

  def ReadFieldNumberAndWireType(self):
    """Reads a tag varint and splits it with wire_format.UnpackTag."""
    return wire_format.UnpackTag(self.ReadVarint())

  def ReadFloat(self):
    """Reads a little-endian 4-byte float."""
    return self._ReadFixed("<f", 4)

  def ReadDouble(self):
    """Reads a little-endian 8-byte double."""
    return self._ReadFixed("<d", 8)

  def EndOfStream(self):
    """Returns True when the cursor sits exactly at the end of the input."""
    return len(self._bytes) == self._pos
105
106
107class ReflectionTest(basetest.TestCase):
108
109  def assertListsEqual(self, values, others):
110    self.assertEqual(len(values), len(others))
111    for i in range(len(values)):
112      self.assertEqual(values[i], others[i])
113
114  def testScalarConstructor(self):
115    # Constructor with only scalar types should succeed.
116    proto = unittest_pb2.TestAllTypes(
117        optional_int32=24,
118        optional_double=54.321,
119        optional_string='optional_string')
120
121    self.assertEqual(24, proto.optional_int32)
122    self.assertEqual(54.321, proto.optional_double)
123    self.assertEqual('optional_string', proto.optional_string)
124
125  def testRepeatedScalarConstructor(self):
126    # Constructor with only repeated scalar types should succeed.
127    proto = unittest_pb2.TestAllTypes(
128        repeated_int32=[1, 2, 3, 4],
129        repeated_double=[1.23, 54.321],
130        repeated_bool=[True, False, False],
131        repeated_string=["optional_string"])
132
133    self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32))
134    self.assertEquals([1.23, 54.321], list(proto.repeated_double))
135    self.assertEquals([True, False, False], list(proto.repeated_bool))
136    self.assertEquals(["optional_string"], list(proto.repeated_string))
137
138  def testRepeatedCompositeConstructor(self):
139    # Constructor with only repeated composite types should succeed.
140    proto = unittest_pb2.TestAllTypes(
141        repeated_nested_message=[
142            unittest_pb2.TestAllTypes.NestedMessage(
143                bb=unittest_pb2.TestAllTypes.FOO),
144            unittest_pb2.TestAllTypes.NestedMessage(
145                bb=unittest_pb2.TestAllTypes.BAR)],
146        repeated_foreign_message=[
147            unittest_pb2.ForeignMessage(c=-43),
148            unittest_pb2.ForeignMessage(c=45324),
149            unittest_pb2.ForeignMessage(c=12)],
150        repeatedgroup=[
151            unittest_pb2.TestAllTypes.RepeatedGroup(),
152            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
153            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
154
155    self.assertEquals(
156        [unittest_pb2.TestAllTypes.NestedMessage(
157            bb=unittest_pb2.TestAllTypes.FOO),
158         unittest_pb2.TestAllTypes.NestedMessage(
159             bb=unittest_pb2.TestAllTypes.BAR)],
160        list(proto.repeated_nested_message))
161    self.assertEquals(
162        [unittest_pb2.ForeignMessage(c=-43),
163         unittest_pb2.ForeignMessage(c=45324),
164         unittest_pb2.ForeignMessage(c=12)],
165        list(proto.repeated_foreign_message))
166    self.assertEquals(
167        [unittest_pb2.TestAllTypes.RepeatedGroup(),
168         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
169         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
170        list(proto.repeatedgroup))
171
172  def testMixedConstructor(self):
173    # Constructor with only mixed types should succeed.
174    proto = unittest_pb2.TestAllTypes(
175        optional_int32=24,
176        optional_string='optional_string',
177        repeated_double=[1.23, 54.321],
178        repeated_bool=[True, False, False],
179        repeated_nested_message=[
180            unittest_pb2.TestAllTypes.NestedMessage(
181                bb=unittest_pb2.TestAllTypes.FOO),
182            unittest_pb2.TestAllTypes.NestedMessage(
183                bb=unittest_pb2.TestAllTypes.BAR)],
184        repeated_foreign_message=[
185            unittest_pb2.ForeignMessage(c=-43),
186            unittest_pb2.ForeignMessage(c=45324),
187            unittest_pb2.ForeignMessage(c=12)])
188
189    self.assertEqual(24, proto.optional_int32)
190    self.assertEqual('optional_string', proto.optional_string)
191    self.assertEquals([1.23, 54.321], list(proto.repeated_double))
192    self.assertEquals([True, False, False], list(proto.repeated_bool))
193    self.assertEquals(
194        [unittest_pb2.TestAllTypes.NestedMessage(
195            bb=unittest_pb2.TestAllTypes.FOO),
196         unittest_pb2.TestAllTypes.NestedMessage(
197             bb=unittest_pb2.TestAllTypes.BAR)],
198        list(proto.repeated_nested_message))
199    self.assertEquals(
200        [unittest_pb2.ForeignMessage(c=-43),
201         unittest_pb2.ForeignMessage(c=45324),
202         unittest_pb2.ForeignMessage(c=12)],
203        list(proto.repeated_foreign_message))
204
205  def testConstructorTypeError(self):
206    self.assertRaises(
207        TypeError, unittest_pb2.TestAllTypes, optional_int32="foo")
208    self.assertRaises(
209        TypeError, unittest_pb2.TestAllTypes, optional_string=1234)
210    self.assertRaises(
211        TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234)
212    self.assertRaises(
213        TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234)
214    self.assertRaises(
215        TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"])
216    self.assertRaises(
217        TypeError, unittest_pb2.TestAllTypes, repeated_string=1234)
218    self.assertRaises(
219        TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234])
220    self.assertRaises(
221        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234)
222    self.assertRaises(
223        TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234])
224
225  def testConstructorInvalidatesCachedByteSize(self):
226    message = unittest_pb2.TestAllTypes(optional_int32 = 12)
227    self.assertEquals(2, message.ByteSize())
228
229    message = unittest_pb2.TestAllTypes(
230        optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage())
231    self.assertEquals(3, message.ByteSize())
232
233    message = unittest_pb2.TestAllTypes(repeated_int32 = [12])
234    self.assertEquals(3, message.ByteSize())
235
236    message = unittest_pb2.TestAllTypes(
237        repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()])
238    self.assertEquals(3, message.ByteSize())
239
240  def testSimpleHasBits(self):
241    # Test a scalar.
242    proto = unittest_pb2.TestAllTypes()
243    self.assertTrue(not proto.HasField('optional_int32'))
244    self.assertEqual(0, proto.optional_int32)
245    # HasField() shouldn't be true if all we've done is
246    # read the default value.
247    self.assertTrue(not proto.HasField('optional_int32'))
248    proto.optional_int32 = 1
249    # Setting a value however *should* set the "has" bit.
250    self.assertTrue(proto.HasField('optional_int32'))
251    proto.ClearField('optional_int32')
252    # And clearing that value should unset the "has" bit.
253    self.assertTrue(not proto.HasField('optional_int32'))
254
255  def testHasBitsWithSinglyNestedScalar(self):
256    # Helper used to test foreign messages and groups.
257    #
258    # composite_field_name should be the name of a non-repeated
259    # composite (i.e., foreign or group) field in TestAllTypes,
260    # and scalar_field_name should be the name of an integer-valued
261    # scalar field within that composite.
262    #
263    # I never thought I'd miss C++ macros and templates so much. :(
264    # This helper is semantically just:
265    #
266    #   assert proto.composite_field.scalar_field == 0
267    #   assert not proto.composite_field.HasField('scalar_field')
268    #   assert not proto.HasField('composite_field')
269    #
270    #   proto.composite_field.scalar_field = 10
271    #   old_composite_field = proto.composite_field
272    #
273    #   assert proto.composite_field.scalar_field == 10
274    #   assert proto.composite_field.HasField('scalar_field')
275    #   assert proto.HasField('composite_field')
276    #
277    #   proto.ClearField('composite_field')
278    #
279    #   assert not proto.composite_field.HasField('scalar_field')
280    #   assert not proto.HasField('composite_field')
281    #   assert proto.composite_field.scalar_field == 0
282    #
283    #   # Now ensure that ClearField('composite_field') disconnected
284    #   # the old field object from the object tree...
285    #   assert old_composite_field is not proto.composite_field
286    #   old_composite_field.scalar_field = 20
287    #   assert not proto.composite_field.HasField('scalar_field')
288    #   assert not proto.HasField('composite_field')
289    def TestCompositeHasBits(composite_field_name, scalar_field_name):
290      proto = unittest_pb2.TestAllTypes()
291      # First, check that we can get the scalar value, and see that it's the
292      # default (0), but that proto.HasField('omposite') and
293      # proto.composite.HasField('scalar') will still return False.
294      composite_field = getattr(proto, composite_field_name)
295      original_scalar_value = getattr(composite_field, scalar_field_name)
296      self.assertEqual(0, original_scalar_value)
297      # Assert that the composite object does not "have" the scalar.
298      self.assertTrue(not composite_field.HasField(scalar_field_name))
299      # Assert that proto does not "have" the composite field.
300      self.assertTrue(not proto.HasField(composite_field_name))
301
302      # Now set the scalar within the composite field.  Ensure that the setting
303      # is reflected, and that proto.HasField('composite') and
304      # proto.composite.HasField('scalar') now both return True.
305      new_val = 20
306      setattr(composite_field, scalar_field_name, new_val)
307      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
308      # Hold on to a reference to the current composite_field object.
309      old_composite_field = composite_field
310      # Assert that the has methods now return true.
311      self.assertTrue(composite_field.HasField(scalar_field_name))
312      self.assertTrue(proto.HasField(composite_field_name))
313
314      # Now call the clear method...
315      proto.ClearField(composite_field_name)
316
317      # ...and ensure that the "has" bits are all back to False...
318      composite_field = getattr(proto, composite_field_name)
319      self.assertTrue(not composite_field.HasField(scalar_field_name))
320      self.assertTrue(not proto.HasField(composite_field_name))
321      # ...and ensure that the scalar field has returned to its default.
322      self.assertEqual(0, getattr(composite_field, scalar_field_name))
323
324      self.assertTrue(old_composite_field is not composite_field)
325      setattr(old_composite_field, scalar_field_name, new_val)
326      self.assertTrue(not composite_field.HasField(scalar_field_name))
327      self.assertTrue(not proto.HasField(composite_field_name))
328      self.assertEqual(0, getattr(composite_field, scalar_field_name))
329
330    # Test simple, single-level nesting when we set a scalar.
331    TestCompositeHasBits('optionalgroup', 'a')
332    TestCompositeHasBits('optional_nested_message', 'bb')
333    TestCompositeHasBits('optional_foreign_message', 'c')
334    TestCompositeHasBits('optional_import_message', 'd')
335
336  def testReferencesToNestedMessage(self):
337    proto = unittest_pb2.TestAllTypes()
338    nested = proto.optional_nested_message
339    del proto
340    # A previous version had a bug where this would raise an exception when
341    # hitting a now-dead weak reference.
342    nested.bb = 23
343
344  def testDisconnectingNestedMessageBeforeSettingField(self):
345    proto = unittest_pb2.TestAllTypes()
346    nested = proto.optional_nested_message
347    proto.ClearField('optional_nested_message')  # Should disconnect from parent
348    self.assertTrue(nested is not proto.optional_nested_message)
349    nested.bb = 23
350    self.assertTrue(not proto.HasField('optional_nested_message'))
351    self.assertEqual(0, proto.optional_nested_message.bb)
352
353  def testGetDefaultMessageAfterDisconnectingDefaultMessage(self):
354    proto = unittest_pb2.TestAllTypes()
355    nested = proto.optional_nested_message
356    proto.ClearField('optional_nested_message')
357    del proto
358    del nested
359    # Force a garbage collect so that the underlying CMessages are freed along
360    # with the Messages they point to. This is to make sure we're not deleting
361    # default message instances.
362    gc.collect()
363    proto = unittest_pb2.TestAllTypes()
364    nested = proto.optional_nested_message
365
366  def testDisconnectingNestedMessageAfterSettingField(self):
367    proto = unittest_pb2.TestAllTypes()
368    nested = proto.optional_nested_message
369    nested.bb = 5
370    self.assertTrue(proto.HasField('optional_nested_message'))
371    proto.ClearField('optional_nested_message')  # Should disconnect from parent
372    self.assertEqual(5, nested.bb)
373    self.assertEqual(0, proto.optional_nested_message.bb)
374    self.assertTrue(nested is not proto.optional_nested_message)
375    nested.bb = 23
376    self.assertTrue(not proto.HasField('optional_nested_message'))
377    self.assertEqual(0, proto.optional_nested_message.bb)
378
379  def testDisconnectingNestedMessageBeforeGettingField(self):
380    proto = unittest_pb2.TestAllTypes()
381    self.assertTrue(not proto.HasField('optional_nested_message'))
382    proto.ClearField('optional_nested_message')
383    self.assertTrue(not proto.HasField('optional_nested_message'))
384
385  def testDisconnectingNestedMessageAfterMerge(self):
386    # This test exercises the code path that does not use ReleaseMessage().
387    # The underlying fear is that if we use ReleaseMessage() incorrectly,
388    # we will have memory leaks.  It's hard to check that that doesn't happen,
389    # but at least we can exercise that code path to make sure it works.
390    proto1 = unittest_pb2.TestAllTypes()
391    proto2 = unittest_pb2.TestAllTypes()
392    proto2.optional_nested_message.bb = 5
393    proto1.MergeFrom(proto2)
394    self.assertTrue(proto1.HasField('optional_nested_message'))
395    proto1.ClearField('optional_nested_message')
396    self.assertTrue(not proto1.HasField('optional_nested_message'))
397
398  def testDisconnectingLazyNestedMessage(self):
399    # This test exercises releasing a nested message that is lazy. This test
400    # only exercises real code in the C++ implementation as Python does not
401    # support lazy parsing, but the current C++ implementation results in
402    # memory corruption and a crash.
403    if api_implementation.Type() != 'python':
404      return
405    proto = unittest_pb2.TestAllTypes()
406    proto.optional_lazy_message.bb = 5
407    proto.ClearField('optional_lazy_message')
408    del proto
409    gc.collect()
410
411  def testHasBitsWhenModifyingRepeatedFields(self):
412    # Test nesting when we add an element to a repeated field in a submessage.
413    proto = unittest_pb2.TestNestedMessageHasBits()
414    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
415    self.assertEqual(
416        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
417    self.assertTrue(proto.HasField('optional_nested_message'))
418
419    # Do the same test, but with a repeated composite field within the
420    # submessage.
421    proto.ClearField('optional_nested_message')
422    self.assertTrue(not proto.HasField('optional_nested_message'))
423    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
424    self.assertTrue(proto.HasField('optional_nested_message'))
425
426  def testHasBitsForManyLevelsOfNesting(self):
427    # Test nesting many levels deep.
428    recursive_proto = unittest_pb2.TestMutualRecursionA()
429    self.assertTrue(not recursive_proto.HasField('bb'))
430    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
431    self.assertTrue(not recursive_proto.HasField('bb'))
432    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
433    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
434    self.assertTrue(recursive_proto.HasField('bb'))
435    self.assertTrue(recursive_proto.bb.HasField('a'))
436    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
437    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
438    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
439    self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a'))
440    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
441
442  def testSingularListFields(self):
443    proto = unittest_pb2.TestAllTypes()
444    proto.optional_fixed32 = 1
445    proto.optional_int32 = 5
446    proto.optional_string = 'foo'
447    # Access sub-message but don't set it yet.
448    nested_message = proto.optional_nested_message
449    self.assertEqual(
450      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
451        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
452        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
453      proto.ListFields())
454
455    proto.optional_nested_message.bb = 123
456    self.assertEqual(
457      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
458        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
459        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
460        (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
461             nested_message) ],
462      proto.ListFields())
463
464  def testRepeatedListFields(self):
465    proto = unittest_pb2.TestAllTypes()
466    proto.repeated_fixed32.append(1)
467    proto.repeated_int32.append(5)
468    proto.repeated_int32.append(11)
469    proto.repeated_string.extend(['foo', 'bar'])
470    proto.repeated_string.extend([])
471    proto.repeated_string.append('baz')
472    proto.repeated_string.extend(str(x) for x in xrange(2))
473    proto.optional_int32 = 21
474    proto.repeated_bool  # Access but don't set anything; should not be listed.
475    self.assertEqual(
476      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 21),
477        (proto.DESCRIPTOR.fields_by_name['repeated_int32'  ], [5, 11]),
478        (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
479        (proto.DESCRIPTOR.fields_by_name['repeated_string' ],
480          ['foo', 'bar', 'baz', '0', '1']) ],
481      proto.ListFields())
482
483  def testSingularListExtensions(self):
484    proto = unittest_pb2.TestAllExtensions()
485    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
486    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 5
487    proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo'
488    self.assertEqual(
489      [ (unittest_pb2.optional_int32_extension  , 5),
490        (unittest_pb2.optional_fixed32_extension, 1),
491        (unittest_pb2.optional_string_extension , 'foo') ],
492      proto.ListFields())
493
494  def testRepeatedListExtensions(self):
495    proto = unittest_pb2.TestAllExtensions()
496    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
497    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(5)
498    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(11)
499    proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo')
500    proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar')
501    proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz')
502    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 21
503    self.assertEqual(
504      [ (unittest_pb2.optional_int32_extension  , 21),
505        (unittest_pb2.repeated_int32_extension  , [5, 11]),
506        (unittest_pb2.repeated_fixed32_extension, [1]),
507        (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ],
508      proto.ListFields())
509
510  def testListFieldsAndExtensions(self):
511    proto = unittest_pb2.TestFieldOrderings()
512    test_util.SetAllFieldsAndExtensions(proto)
513    unittest_pb2.my_extension_int
514    self.assertEqual(
515      [ (proto.DESCRIPTOR.fields_by_name['my_int'   ], 1),
516        (unittest_pb2.my_extension_int               , 23),
517        (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
518        (unittest_pb2.my_extension_string            , 'bar'),
519        (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ],
520      proto.ListFields())
521
522  def testDefaultValues(self):
523    proto = unittest_pb2.TestAllTypes()
524    self.assertEqual(0, proto.optional_int32)
525    self.assertEqual(0, proto.optional_int64)
526    self.assertEqual(0, proto.optional_uint32)
527    self.assertEqual(0, proto.optional_uint64)
528    self.assertEqual(0, proto.optional_sint32)
529    self.assertEqual(0, proto.optional_sint64)
530    self.assertEqual(0, proto.optional_fixed32)
531    self.assertEqual(0, proto.optional_fixed64)
532    self.assertEqual(0, proto.optional_sfixed32)
533    self.assertEqual(0, proto.optional_sfixed64)
534    self.assertEqual(0.0, proto.optional_float)
535    self.assertEqual(0.0, proto.optional_double)
536    self.assertEqual(False, proto.optional_bool)
537    self.assertEqual('', proto.optional_string)
538    self.assertEqual(b'', proto.optional_bytes)
539
540    self.assertEqual(41, proto.default_int32)
541    self.assertEqual(42, proto.default_int64)
542    self.assertEqual(43, proto.default_uint32)
543    self.assertEqual(44, proto.default_uint64)
544    self.assertEqual(-45, proto.default_sint32)
545    self.assertEqual(46, proto.default_sint64)
546    self.assertEqual(47, proto.default_fixed32)
547    self.assertEqual(48, proto.default_fixed64)
548    self.assertEqual(49, proto.default_sfixed32)
549    self.assertEqual(-50, proto.default_sfixed64)
550    self.assertEqual(51.5, proto.default_float)
551    self.assertEqual(52e3, proto.default_double)
552    self.assertEqual(True, proto.default_bool)
553    self.assertEqual('hello', proto.default_string)
554    self.assertEqual(b'world', proto.default_bytes)
555    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
556    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
557    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
558                     proto.default_import_enum)
559
560    proto = unittest_pb2.TestExtremeDefaultValues()
561    self.assertEqual(u'\u1234', proto.utf8_string)
562
563  def testHasFieldWithUnknownFieldName(self):
564    proto = unittest_pb2.TestAllTypes()
565    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')
566
567  def testClearFieldWithUnknownFieldName(self):
568    proto = unittest_pb2.TestAllTypes()
569    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
570
571  def testClearRemovesChildren(self):
572    # Make sure there aren't any implementation bugs that are only partially
573    # clearing the message (which can happen in the more complex C++
574    # implementation which has parallel message lists).
575    proto = unittest_pb2.TestRequiredForeign()
576    for i in range(10):
577      proto.repeated_message.add()
578    proto2 = unittest_pb2.TestRequiredForeign()
579    proto.CopyFrom(proto2)
580    self.assertRaises(IndexError, lambda: proto.repeated_message[5])
581
582  def testDisallowedAssignments(self):
583    # It's illegal to assign values directly to repeated fields
584    # or to nonrepeated composite fields.  Ensure that this fails.
585    proto = unittest_pb2.TestAllTypes()
586    # Repeated fields.
587    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
588    # Lists shouldn't work, either.
589    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
590    # Composite fields.
591    self.assertRaises(AttributeError, setattr, proto,
592                      'optional_nested_message', 23)
593    # Assignment to a repeated nested message field without specifying
594    # the index in the array of nested messages.
595    self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
596                      'bb', 34)
597    # Assignment to an attribute of a repeated field.
598    self.assertRaises(AttributeError, setattr, proto.repeated_float,
599                      'some_attribute', 34)
600    # proto.nonexistent_field = 23 should fail as well.
601    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
602
603  def testSingleScalarTypeSafety(self):
604    proto = unittest_pb2.TestAllTypes()
605    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
606    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
607    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
608    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
609
610  def testIntegerTypes(self):
611    def TestGetAndDeserialize(field_name, value, expected_type):
612      proto = unittest_pb2.TestAllTypes()
613      setattr(proto, field_name, value)
614      self.assertTrue(isinstance(getattr(proto, field_name), expected_type))
615      proto2 = unittest_pb2.TestAllTypes()
616      proto2.ParseFromString(proto.SerializeToString())
617      self.assertTrue(isinstance(getattr(proto2, field_name), expected_type))
618
619    TestGetAndDeserialize('optional_int32', 1, int)
620    TestGetAndDeserialize('optional_int32', 1 << 30, int)
621    TestGetAndDeserialize('optional_uint32', 1 << 30, int)
622    if struct.calcsize('L') == 4:
623      # Python only has signed ints, so 32-bit python can't fit an uint32
624      # in an int.
625      TestGetAndDeserialize('optional_uint32', 1 << 31, long)
626    else:
627      # 64-bit python can fit uint32 inside an int
628      TestGetAndDeserialize('optional_uint32', 1 << 31, int)
629    TestGetAndDeserialize('optional_int64', 1 << 30, long)
630    TestGetAndDeserialize('optional_int64', 1 << 60, long)
631    TestGetAndDeserialize('optional_uint64', 1 << 30, long)
632    TestGetAndDeserialize('optional_uint64', 1 << 60, long)
633
634  def testSingleScalarBoundsChecking(self):
635    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
636      pb = unittest_pb2.TestAllTypes()
637      setattr(pb, field_name, expected_min)
638      self.assertEqual(expected_min, getattr(pb, field_name))
639      setattr(pb, field_name, expected_max)
640      self.assertEqual(expected_max, getattr(pb, field_name))
641      self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
642      self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)
643
644    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
645    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
646    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
647    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
648
649    pb = unittest_pb2.TestAllTypes()
650    pb.optional_nested_enum = 1
651    self.assertEqual(1, pb.optional_nested_enum)
652
653  def testRepeatedScalarTypeSafety(self):
654    proto = unittest_pb2.TestAllTypes()
655    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
656    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
657    self.assertRaises(TypeError, proto.repeated_string, 10)
658    self.assertRaises(TypeError, proto.repeated_bytes, 10)
659
660    proto.repeated_int32.append(10)
661    proto.repeated_int32[0] = 23
662    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
663    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
664
665    # Repeated enums tests.
666    #proto.repeated_nested_enum.append(0)
667
668  def testSingleScalarGettersAndSetters(self):
669    proto = unittest_pb2.TestAllTypes()
670    self.assertEqual(0, proto.optional_int32)
671    proto.optional_int32 = 1
672    self.assertEqual(1, proto.optional_int32)
673
674    proto.optional_uint64 = 0xffffffffffff
675    self.assertEqual(0xffffffffffff, proto.optional_uint64)
676    proto.optional_uint64 = 0xffffffffffffffff
677    self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
678    # TODO(robinson): Test all other scalar field types.
679
680  def testSingleScalarClearField(self):
681    proto = unittest_pb2.TestAllTypes()
682    # Should be allowed to clear something that's not there (a no-op).
683    proto.ClearField('optional_int32')
684    proto.optional_int32 = 1
685    self.assertTrue(proto.HasField('optional_int32'))
686    proto.ClearField('optional_int32')
687    self.assertEqual(0, proto.optional_int32)
688    self.assertTrue(not proto.HasField('optional_int32'))
689    # TODO(robinson): Test all other scalar field types.
690
691  def testEnums(self):
692    proto = unittest_pb2.TestAllTypes()
693    self.assertEqual(1, proto.FOO)
694    self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
695    self.assertEqual(2, proto.BAR)
696    self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
697    self.assertEqual(3, proto.BAZ)
698    self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
699
700  def testEnum_Name(self):
701    self.assertEqual('FOREIGN_FOO',
702                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO))
703    self.assertEqual('FOREIGN_BAR',
704                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR))
705    self.assertEqual('FOREIGN_BAZ',
706                     unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ))
707    self.assertRaises(ValueError,
708                      unittest_pb2.ForeignEnum.Name, 11312)
709
710    proto = unittest_pb2.TestAllTypes()
711    self.assertEqual('FOO',
712                     proto.NestedEnum.Name(proto.FOO))
713    self.assertEqual('FOO',
714                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO))
715    self.assertEqual('BAR',
716                     proto.NestedEnum.Name(proto.BAR))
717    self.assertEqual('BAR',
718                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR))
719    self.assertEqual('BAZ',
720                     proto.NestedEnum.Name(proto.BAZ))
721    self.assertEqual('BAZ',
722                     unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ))
723    self.assertRaises(ValueError,
724                      proto.NestedEnum.Name, 11312)
725    self.assertRaises(ValueError,
726                      unittest_pb2.TestAllTypes.NestedEnum.Name, 11312)
727
728  def testEnum_Value(self):
729    self.assertEqual(unittest_pb2.FOREIGN_FOO,
730                     unittest_pb2.ForeignEnum.Value('FOREIGN_FOO'))
731    self.assertEqual(unittest_pb2.FOREIGN_BAR,
732                     unittest_pb2.ForeignEnum.Value('FOREIGN_BAR'))
733    self.assertEqual(unittest_pb2.FOREIGN_BAZ,
734                     unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ'))
735    self.assertRaises(ValueError,
736                      unittest_pb2.ForeignEnum.Value, 'FO')
737
738    proto = unittest_pb2.TestAllTypes()
739    self.assertEqual(proto.FOO,
740                     proto.NestedEnum.Value('FOO'))
741    self.assertEqual(proto.FOO,
742                     unittest_pb2.TestAllTypes.NestedEnum.Value('FOO'))
743    self.assertEqual(proto.BAR,
744                     proto.NestedEnum.Value('BAR'))
745    self.assertEqual(proto.BAR,
746                     unittest_pb2.TestAllTypes.NestedEnum.Value('BAR'))
747    self.assertEqual(proto.BAZ,
748                     proto.NestedEnum.Value('BAZ'))
749    self.assertEqual(proto.BAZ,
750                     unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ'))
751    self.assertRaises(ValueError,
752                      proto.NestedEnum.Value, 'Foo')
753    self.assertRaises(ValueError,
754                      unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo')
755
756  def testEnum_KeysAndValues(self):
757    self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'],
758                     unittest_pb2.ForeignEnum.keys())
759    self.assertEqual([4, 5, 6],
760                     unittest_pb2.ForeignEnum.values())
761    self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5),
762                      ('FOREIGN_BAZ', 6)],
763                     unittest_pb2.ForeignEnum.items())
764
765    proto = unittest_pb2.TestAllTypes()
766    self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys())
767    self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values())
768    self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)],
769                     proto.NestedEnum.items())
770
  def testRepeatedScalars(self):
    """Exercises the list-like API of a repeated int32 field.

    Covers truthiness and len(), append/insert/extend, indexing and slicing
    (reads, writes and deletes, including negative indices), iteration,
    ClearField(), and the exceptions raised for bad indices or index types.
    Each step builds on the container state left by the previous one.
    """
    proto = unittest_pb2.TestAllTypes()

    # A fresh repeated field is empty and falsy.
    self.assertTrue(not proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))
    proto.repeated_int32.append(5)
    proto.repeated_int32.append(10)
    proto.repeated_int32.append(15)
    self.assertTrue(proto.repeated_int32)
    self.assertEqual(3, len(proto.repeated_int32))

    # The container compares equal to a plain list with the same elements.
    self.assertEqual([5, 10, 15], proto.repeated_int32)

    # Test single retrieval.
    self.assertEqual(5, proto.repeated_int32[0])
    self.assertEqual(15, proto.repeated_int32[-1])
    # Test out-of-bounds indices.
    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
    # Test incorrect types passed to __getitem__.
    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)

    # Test single assignment.
    proto.repeated_int32[1] = 20
    self.assertEqual([5, 20, 15], proto.repeated_int32)

    # Test insertion.
    proto.repeated_int32.insert(1, 25)
    self.assertEqual([5, 25, 20, 15], proto.repeated_int32)

    # Test slice retrieval.
    proto.repeated_int32.append(30)
    self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
    self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])

    # Test slice assignment with an iterator.  The right-hand side is a
    # generator, so the container must consume it rather than index into it.
    proto.repeated_int32[1:4] = (i for i in xrange(3))
    self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)

    # Test slice assignment.
    proto.repeated_int32[1:4] = [35, 40, 45]
    self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)

    # Test that we can use the field as an iterator.
    result = []
    for i in proto.repeated_int32:
      result.append(i)
    self.assertEqual([5, 35, 40, 45, 30], result)

    # Test single deletion.
    del proto.repeated_int32[2]
    self.assertEqual([5, 35, 45, 30], proto.repeated_int32)

    # Test slice deletion.
    del proto.repeated_int32[2:]
    self.assertEqual([5, 35], proto.repeated_int32)

    # Test extending.
    proto.repeated_int32.extend([3, 13])
    self.assertEqual([5, 35, 3, 13], proto.repeated_int32)

    # Test clearing.
    proto.ClearField('repeated_int32')
    self.assertTrue(not proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))

    proto.repeated_int32.append(1)
    self.assertEqual(1, proto.repeated_int32[-1])
    # Test assignment to a negative index.
    proto.repeated_int32[-1] = 2
    self.assertEqual(2, proto.repeated_int32[-1])

    # Test deletion at negative indices.
    proto.repeated_int32[:] = [0, 1, 2, 3]
    del proto.repeated_int32[-1]
    self.assertEqual([0, 1, 2], proto.repeated_int32)

    del proto.repeated_int32[-2]
    self.assertEqual([0, 2], proto.repeated_int32)

    # With two elements left, both -3 and 300 are out of range.
    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3)
    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300)

    del proto.repeated_int32[-2:-1]
    self.assertEqual([2], proto.repeated_int32)

    # Deleting a slice that lies entirely past the end changes nothing, as
    # with a Python list.
    del proto.repeated_int32[100:10000]
    self.assertEqual([2], proto.repeated_int32)
860
861  def testRepeatedScalarsRemove(self):
862    proto = unittest_pb2.TestAllTypes()
863
864    self.assertTrue(not proto.repeated_int32)
865    self.assertEqual(0, len(proto.repeated_int32))
866    proto.repeated_int32.append(5)
867    proto.repeated_int32.append(10)
868    proto.repeated_int32.append(5)
869    proto.repeated_int32.append(5)
870
871    self.assertEqual(4, len(proto.repeated_int32))
872    proto.repeated_int32.remove(5)
873    self.assertEqual(3, len(proto.repeated_int32))
874    self.assertEqual(10, proto.repeated_int32[0])
875    self.assertEqual(5, proto.repeated_int32[1])
876    self.assertEqual(5, proto.repeated_int32[2])
877
878    proto.repeated_int32.remove(5)
879    self.assertEqual(2, len(proto.repeated_int32))
880    self.assertEqual(10, proto.repeated_int32[0])
881    self.assertEqual(5, proto.repeated_int32[1])
882
883    proto.repeated_int32.remove(10)
884    self.assertEqual(1, len(proto.repeated_int32))
885    self.assertEqual(5, proto.repeated_int32[0])
886
887    # Remove a non-existent element.
888    self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
889
  def testRepeatedComposites(self):
    """Exercises the list-like API of a repeated nested-message field.

    Covers add() (with and without keyword initializers), truthiness and
    len(), indexing and slicing, iteration, deletion, extend(), ClearField(),
    and the exceptions raised for bad indices or index types.  Each step
    builds on the container state left by the previous one.
    """
    proto = unittest_pb2.TestAllTypes()
    # A fresh repeated message field is empty and falsy.
    self.assertTrue(not proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))
    m0 = proto.repeated_nested_message.add()
    m1 = proto.repeated_nested_message.add()
    self.assertTrue(proto.repeated_nested_message)
    self.assertEqual(2, len(proto.repeated_nested_message))
    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
    # add() returns an instance of the field's message type.
    self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage))

    # Test out-of-bounds indices.
    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
                      1234)
    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
                      -1234)

    # Test incorrect types passed to __getitem__.
    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
                      'foo')
    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
                      None)

    # Test slice retrieval.
    m2 = proto.repeated_nested_message.add()
    m3 = proto.repeated_nested_message.add()
    m4 = proto.repeated_nested_message.add()
    self.assertListsEqual(
        [m1, m2, m3], proto.repeated_nested_message[1:4])
    self.assertListsEqual(
        [m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
    self.assertListsEqual(
        [m0, m1], proto.repeated_nested_message[:2])
    self.assertListsEqual(
        [m2, m3, m4], proto.repeated_nested_message[2:])
    self.assertEqual(
        m0, proto.repeated_nested_message[0])
    self.assertListsEqual(
        [m0], proto.repeated_nested_message[:1])

    # Test that we can use the field as an iterator.
    result = []
    for i in proto.repeated_nested_message:
      result.append(i)
    self.assertListsEqual([m0, m1, m2, m3, m4], result)

    # Test single deletion.
    del proto.repeated_nested_message[2]
    self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message)

    # Test slice deletion.
    del proto.repeated_nested_message[2:]
    self.assertListsEqual([m0, m1], proto.repeated_nested_message)

    # Test extending.  Elements are appended after the two already present
    # and compare equal to the messages passed in.
    n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
    n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2)
    proto.repeated_nested_message.extend([n1,n2])
    self.assertEqual(4, len(proto.repeated_nested_message))
    self.assertEqual(n1, proto.repeated_nested_message[2])
    self.assertEqual(n2, proto.repeated_nested_message[3])

    # Test clearing.
    proto.ClearField('repeated_nested_message')
    self.assertTrue(not proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))

    # Test constructing an element while adding it.
    proto.repeated_nested_message.add(bb=23)
    self.assertEqual(1, len(proto.repeated_nested_message))
    self.assertEqual(23, proto.repeated_nested_message[0].bb)
961
962  def testRepeatedCompositeRemove(self):
963    proto = unittest_pb2.TestAllTypes()
964
965    self.assertEqual(0, len(proto.repeated_nested_message))
966    m0 = proto.repeated_nested_message.add()
967    # Need to set some differentiating variable so m0 != m1 != m2:
968    m0.bb = len(proto.repeated_nested_message)
969    m1 = proto.repeated_nested_message.add()
970    m1.bb = len(proto.repeated_nested_message)
971    self.assertTrue(m0 != m1)
972    m2 = proto.repeated_nested_message.add()
973    m2.bb = len(proto.repeated_nested_message)
974    self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
975
976    self.assertEqual(3, len(proto.repeated_nested_message))
977    proto.repeated_nested_message.remove(m0)
978    self.assertEqual(2, len(proto.repeated_nested_message))
979    self.assertEqual(m1, proto.repeated_nested_message[0])
980    self.assertEqual(m2, proto.repeated_nested_message[1])
981
982    # Removing m0 again or removing None should raise error
983    self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0)
984    self.assertRaises(ValueError, proto.repeated_nested_message.remove, None)
985    self.assertEqual(2, len(proto.repeated_nested_message))
986
987    proto.repeated_nested_message.remove(m2)
988    self.assertEqual(1, len(proto.repeated_nested_message))
989    self.assertEqual(m1, proto.repeated_nested_message[0])
990
991  def testHandWrittenReflection(self):
992    # Hand written extensions are only supported by the pure-Python
993    # implementation of the API.
994    if api_implementation.Type() != 'python':
995      return
996
997    FieldDescriptor = descriptor.FieldDescriptor
998    foo_field_descriptor = FieldDescriptor(
999        name='foo_field', full_name='MyProto.foo_field',
1000        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
1001        cpp_type=FieldDescriptor.CPPTYPE_INT64,
1002        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
1003        containing_type=None, message_type=None, enum_type=None,
1004        is_extension=False, extension_scope=None,
1005        options=descriptor_pb2.FieldOptions())
1006    mydescriptor = descriptor.Descriptor(
1007        name='MyProto', full_name='MyProto', filename='ignored',
1008        containing_type=None, nested_types=[], enum_types=[],
1009        fields=[foo_field_descriptor], extensions=[],
1010        options=descriptor_pb2.MessageOptions())
1011    class MyProtoClass(message.Message):
1012      DESCRIPTOR = mydescriptor
1013      __metaclass__ = reflection.GeneratedProtocolMessageType
1014    myproto_instance = MyProtoClass()
1015    self.assertEqual(0, myproto_instance.foo_field)
1016    self.assertTrue(not myproto_instance.HasField('foo_field'))
1017    myproto_instance.foo_field = 23
1018    self.assertEqual(23, myproto_instance.foo_field)
1019    self.assertTrue(myproto_instance.HasField('foo_field'))
1020
1021  def testDescriptorProtoSupport(self):
1022    # Hand written descriptors/reflection are only supported by the pure-Python
1023    # implementation of the API.
1024    if api_implementation.Type() != 'python':
1025      return
1026
1027    def AddDescriptorField(proto, field_name, field_type):
1028      AddDescriptorField.field_index += 1
1029      new_field = proto.field.add()
1030      new_field.name = field_name
1031      new_field.type = field_type
1032      new_field.number = AddDescriptorField.field_index
1033      new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
1034
1035    AddDescriptorField.field_index = 0
1036
1037    desc_proto = descriptor_pb2.DescriptorProto()
1038    desc_proto.name = 'Car'
1039    fdp = descriptor_pb2.FieldDescriptorProto
1040    AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING)
1041    AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64)
1042    AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL)
1043    AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE)
1044    # Add a repeated field
1045    AddDescriptorField.field_index += 1
1046    new_field = desc_proto.field.add()
1047    new_field.name = 'owners'
1048    new_field.type = fdp.TYPE_STRING
1049    new_field.number = AddDescriptorField.field_index
1050    new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
1051
1052    desc = descriptor.MakeDescriptor(desc_proto)
1053    self.assertTrue(desc.fields_by_name.has_key('name'))
1054    self.assertTrue(desc.fields_by_name.has_key('year'))
1055    self.assertTrue(desc.fields_by_name.has_key('automatic'))
1056    self.assertTrue(desc.fields_by_name.has_key('price'))
1057    self.assertTrue(desc.fields_by_name.has_key('owners'))
1058
1059    class CarMessage(message.Message):
1060      __metaclass__ = reflection.GeneratedProtocolMessageType
1061      DESCRIPTOR = desc
1062
1063    prius = CarMessage()
1064    prius.name = 'prius'
1065    prius.year = 2010
1066    prius.automatic = True
1067    prius.price = 25134.75
1068    prius.owners.extend(['bob', 'susan'])
1069
1070    serialized_prius = prius.SerializeToString()
1071    new_prius = reflection.ParseMessage(desc, serialized_prius)
1072    self.assertTrue(new_prius is not prius)
1073    self.assertEqual(prius, new_prius)
1074
1075    # these are unnecessary assuming message equality works as advertised but
1076    # explicitly check to be safe since we're mucking about in metaclass foo
1077    self.assertEqual(prius.name, new_prius.name)
1078    self.assertEqual(prius.year, new_prius.year)
1079    self.assertEqual(prius.automatic, new_prius.automatic)
1080    self.assertEqual(prius.price, new_prius.price)
1081    self.assertEqual(prius.owners, new_prius.owners)
1082
1083  def testTopLevelExtensionsForOptionalScalar(self):
1084    extendee_proto = unittest_pb2.TestAllExtensions()
1085    extension = unittest_pb2.optional_int32_extension
1086    self.assertTrue(not extendee_proto.HasExtension(extension))
1087    self.assertEqual(0, extendee_proto.Extensions[extension])
1088    # As with normal scalar fields, just doing a read doesn't actually set the
1089    # "has" bit.
1090    self.assertTrue(not extendee_proto.HasExtension(extension))
1091    # Actually set the thing.
1092    extendee_proto.Extensions[extension] = 23
1093    self.assertEqual(23, extendee_proto.Extensions[extension])
1094    self.assertTrue(extendee_proto.HasExtension(extension))
1095    # Ensure that clearing works as well.
1096    extendee_proto.ClearExtension(extension)
1097    self.assertEqual(0, extendee_proto.Extensions[extension])
1098    self.assertTrue(not extendee_proto.HasExtension(extension))
1099
1100  def testTopLevelExtensionsForRepeatedScalar(self):
1101    extendee_proto = unittest_pb2.TestAllExtensions()
1102    extension = unittest_pb2.repeated_string_extension
1103    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1104    extendee_proto.Extensions[extension].append('foo')
1105    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
1106    string_list = extendee_proto.Extensions[extension]
1107    extendee_proto.ClearExtension(extension)
1108    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1109    self.assertTrue(string_list is not extendee_proto.Extensions[extension])
1110    # Shouldn't be allowed to do Extensions[extension] = 'a'
1111    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1112                      extension, 'a')
1113
1114  def testTopLevelExtensionsForOptionalMessage(self):
1115    extendee_proto = unittest_pb2.TestAllExtensions()
1116    extension = unittest_pb2.optional_foreign_message_extension
1117    self.assertTrue(not extendee_proto.HasExtension(extension))
1118    self.assertEqual(0, extendee_proto.Extensions[extension].c)
1119    # As with normal (non-extension) fields, merely reading from the
1120    # thing shouldn't set the "has" bit.
1121    self.assertTrue(not extendee_proto.HasExtension(extension))
1122    extendee_proto.Extensions[extension].c = 23
1123    self.assertEqual(23, extendee_proto.Extensions[extension].c)
1124    self.assertTrue(extendee_proto.HasExtension(extension))
1125    # Save a reference here.
1126    foreign_message = extendee_proto.Extensions[extension]
1127    extendee_proto.ClearExtension(extension)
1128    self.assertTrue(foreign_message is not extendee_proto.Extensions[extension])
1129    # Setting a field on foreign_message now shouldn't set
1130    # any "has" bits on extendee_proto.
1131    foreign_message.c = 42
1132    self.assertEqual(42, foreign_message.c)
1133    self.assertTrue(foreign_message.HasField('c'))
1134    self.assertTrue(not extendee_proto.HasExtension(extension))
1135    # Shouldn't be allowed to do Extensions[extension] = 'a'
1136    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1137                      extension, 'a')
1138
1139  def testTopLevelExtensionsForRepeatedMessage(self):
1140    extendee_proto = unittest_pb2.TestAllExtensions()
1141    extension = unittest_pb2.repeatedgroup_extension
1142    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1143    group = extendee_proto.Extensions[extension].add()
1144    group.a = 23
1145    self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
1146    group.a = 42
1147    self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
1148    group_list = extendee_proto.Extensions[extension]
1149    extendee_proto.ClearExtension(extension)
1150    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1151    self.assertTrue(group_list is not extendee_proto.Extensions[extension])
1152    # Shouldn't be allowed to do Extensions[extension] = 'a'
1153    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1154                      extension, 'a')
1155
1156  def testNestedExtensions(self):
1157    extendee_proto = unittest_pb2.TestAllExtensions()
1158    extension = unittest_pb2.TestRequired.single
1159
1160    # We just test the non-repeated case.
1161    self.assertTrue(not extendee_proto.HasExtension(extension))
1162    required = extendee_proto.Extensions[extension]
1163    self.assertEqual(0, required.a)
1164    self.assertTrue(not extendee_proto.HasExtension(extension))
1165    required.a = 23
1166    self.assertEqual(23, extendee_proto.Extensions[extension].a)
1167    self.assertTrue(extendee_proto.HasExtension(extension))
1168    extendee_proto.ClearExtension(extension)
1169    self.assertTrue(required is not extendee_proto.Extensions[extension])
1170    self.assertTrue(not extendee_proto.HasExtension(extension))
1171
1172  def testRegisteredExtensions(self):
1173    self.assertTrue('protobuf_unittest.optional_int32_extension' in
1174                    unittest_pb2.TestAllExtensions._extensions_by_name)
1175    self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number)
1176    # Make sure extensions haven't been registered into types that shouldn't
1177    # have any.
1178    self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name))
1179
1180  # If message A directly contains message B, and
1181  # a.HasField('b') is currently False, then mutating any
1182  # extension in B should change a.HasField('b') to True
1183  # (and so on up the object tree).
1184  def testHasBitsForAncestorsOfExtendedMessage(self):
1185    # Optional scalar extension.
1186    toplevel = more_extensions_pb2.TopLevelMessage()
1187    self.assertTrue(not toplevel.HasField('submessage'))
1188    self.assertEqual(0, toplevel.submessage.Extensions[
1189        more_extensions_pb2.optional_int_extension])
1190    self.assertTrue(not toplevel.HasField('submessage'))
1191    toplevel.submessage.Extensions[
1192        more_extensions_pb2.optional_int_extension] = 23
1193    self.assertEqual(23, toplevel.submessage.Extensions[
1194        more_extensions_pb2.optional_int_extension])
1195    self.assertTrue(toplevel.HasField('submessage'))
1196
1197    # Repeated scalar extension.
1198    toplevel = more_extensions_pb2.TopLevelMessage()
1199    self.assertTrue(not toplevel.HasField('submessage'))
1200    self.assertEqual([], toplevel.submessage.Extensions[
1201        more_extensions_pb2.repeated_int_extension])
1202    self.assertTrue(not toplevel.HasField('submessage'))
1203    toplevel.submessage.Extensions[
1204        more_extensions_pb2.repeated_int_extension].append(23)
1205    self.assertEqual([23], toplevel.submessage.Extensions[
1206        more_extensions_pb2.repeated_int_extension])
1207    self.assertTrue(toplevel.HasField('submessage'))
1208
1209    # Optional message extension.
1210    toplevel = more_extensions_pb2.TopLevelMessage()
1211    self.assertTrue(not toplevel.HasField('submessage'))
1212    self.assertEqual(0, toplevel.submessage.Extensions[
1213        more_extensions_pb2.optional_message_extension].foreign_message_int)
1214    self.assertTrue(not toplevel.HasField('submessage'))
1215    toplevel.submessage.Extensions[
1216        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
1217    self.assertEqual(23, toplevel.submessage.Extensions[
1218        more_extensions_pb2.optional_message_extension].foreign_message_int)
1219    self.assertTrue(toplevel.HasField('submessage'))
1220
1221    # Repeated message extension.
1222    toplevel = more_extensions_pb2.TopLevelMessage()
1223    self.assertTrue(not toplevel.HasField('submessage'))
1224    self.assertEqual(0, len(toplevel.submessage.Extensions[
1225        more_extensions_pb2.repeated_message_extension]))
1226    self.assertTrue(not toplevel.HasField('submessage'))
1227    foreign = toplevel.submessage.Extensions[
1228        more_extensions_pb2.repeated_message_extension].add()
1229    self.assertEqual(foreign, toplevel.submessage.Extensions[
1230        more_extensions_pb2.repeated_message_extension][0])
1231    self.assertTrue(toplevel.HasField('submessage'))
1232
1233  def testDisconnectionAfterClearingEmptyMessage(self):
1234    toplevel = more_extensions_pb2.TopLevelMessage()
1235    extendee_proto = toplevel.submessage
1236    extension = more_extensions_pb2.optional_message_extension
1237    extension_proto = extendee_proto.Extensions[extension]
1238    extendee_proto.ClearExtension(extension)
1239    extension_proto.foreign_message_int = 23
1240
1241    self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
1242
1243  def testExtensionFailureModes(self):
1244    extendee_proto = unittest_pb2.TestAllExtensions()
1245
1246    # Try non-extension-handle arguments to HasExtension,
1247    # ClearExtension(), and Extensions[]...
1248    self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
1249    self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
1250    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
1251    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
1252
1253    # Try something that *is* an extension handle, just not for
1254    # this message...
1255    unknown_handle = more_extensions_pb2.optional_int_extension
1256    self.assertRaises(KeyError, extendee_proto.HasExtension,
1257                      unknown_handle)
1258    self.assertRaises(KeyError, extendee_proto.ClearExtension,
1259                      unknown_handle)
1260    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
1261                      unknown_handle)
1262    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
1263                      unknown_handle, 5)
1264
1265    # Try call HasExtension() with a valid handle, but for a
1266    # *repeated* field.  (Just as with non-extension repeated
1267    # fields, Has*() isn't supported for extension repeated fields).
1268    self.assertRaises(KeyError, extendee_proto.HasExtension,
1269                      unittest_pb2.repeated_string_extension)
1270
1271  def testStaticParseFrom(self):
1272    proto1 = unittest_pb2.TestAllTypes()
1273    test_util.SetAllFields(proto1)
1274
1275    string1 = proto1.SerializeToString()
1276    proto2 = unittest_pb2.TestAllTypes.FromString(string1)
1277
1278    # Messages should be equal.
1279    self.assertEqual(proto2, proto1)
1280
1281  def testMergeFromSingularField(self):
1282    # Test merge with just a singular field.
1283    proto1 = unittest_pb2.TestAllTypes()
1284    proto1.optional_int32 = 1
1285
1286    proto2 = unittest_pb2.TestAllTypes()
1287    # This shouldn't get overwritten.
1288    proto2.optional_string = 'value'
1289
1290    proto2.MergeFrom(proto1)
1291    self.assertEqual(1, proto2.optional_int32)
1292    self.assertEqual('value', proto2.optional_string)
1293
1294  def testMergeFromRepeatedField(self):
1295    # Test merge with just a repeated field.
1296    proto1 = unittest_pb2.TestAllTypes()
1297    proto1.repeated_int32.append(1)
1298    proto1.repeated_int32.append(2)
1299
1300    proto2 = unittest_pb2.TestAllTypes()
1301    proto2.repeated_int32.append(0)
1302    proto2.MergeFrom(proto1)
1303
1304    self.assertEqual(0, proto2.repeated_int32[0])
1305    self.assertEqual(1, proto2.repeated_int32[1])
1306    self.assertEqual(2, proto2.repeated_int32[2])
1307
1308  def testMergeFromOptionalGroup(self):
1309    # Test merge with an optional group.
1310    proto1 = unittest_pb2.TestAllTypes()
1311    proto1.optionalgroup.a = 12
1312    proto2 = unittest_pb2.TestAllTypes()
1313    proto2.MergeFrom(proto1)
1314    self.assertEqual(12, proto2.optionalgroup.a)
1315
1316  def testMergeFromRepeatedNestedMessage(self):
1317    # Test merge with a repeated nested message.
1318    proto1 = unittest_pb2.TestAllTypes()
1319    m = proto1.repeated_nested_message.add()
1320    m.bb = 123
1321    m = proto1.repeated_nested_message.add()
1322    m.bb = 321
1323
1324    proto2 = unittest_pb2.TestAllTypes()
1325    m = proto2.repeated_nested_message.add()
1326    m.bb = 999
1327    proto2.MergeFrom(proto1)
1328    self.assertEqual(999, proto2.repeated_nested_message[0].bb)
1329    self.assertEqual(123, proto2.repeated_nested_message[1].bb)
1330    self.assertEqual(321, proto2.repeated_nested_message[2].bb)
1331
1332    proto3 = unittest_pb2.TestAllTypes()
1333    proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message)
1334    self.assertEqual(999, proto3.repeated_nested_message[0].bb)
1335    self.assertEqual(123, proto3.repeated_nested_message[1].bb)
1336    self.assertEqual(321, proto3.repeated_nested_message[2].bb)
1337
1338  def testMergeFromAllFields(self):
1339    # With all fields set.
1340    proto1 = unittest_pb2.TestAllTypes()
1341    test_util.SetAllFields(proto1)
1342    proto2 = unittest_pb2.TestAllTypes()
1343    proto2.MergeFrom(proto1)
1344
1345    # Messages should be equal.
1346    self.assertEqual(proto2, proto1)
1347
1348    # Serialized string should be equal too.
1349    string1 = proto1.SerializeToString()
1350    string2 = proto2.SerializeToString()
1351    self.assertEqual(string1, string2)
1352
1353  def testMergeFromExtensionsSingular(self):
1354    proto1 = unittest_pb2.TestAllExtensions()
1355    proto1.Extensions[unittest_pb2.optional_int32_extension] = 1
1356
1357    proto2 = unittest_pb2.TestAllExtensions()
1358    proto2.MergeFrom(proto1)
1359    self.assertEqual(
1360        1, proto2.Extensions[unittest_pb2.optional_int32_extension])
1361
1362  def testMergeFromExtensionsRepeated(self):
1363    proto1 = unittest_pb2.TestAllExtensions()
1364    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
1365    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)
1366
1367    proto2 = unittest_pb2.TestAllExtensions()
1368    proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
1369    proto2.MergeFrom(proto1)
1370    self.assertEqual(
1371        3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
1372    self.assertEqual(
1373        0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
1374    self.assertEqual(
1375        1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
1376    self.assertEqual(
1377        2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])
1378
1379  def testMergeFromExtensionsNestedMessage(self):
1380    proto1 = unittest_pb2.TestAllExtensions()
1381    ext1 = proto1.Extensions[
1382        unittest_pb2.repeated_nested_message_extension]
1383    m = ext1.add()
1384    m.bb = 222
1385    m = ext1.add()
1386    m.bb = 333
1387
1388    proto2 = unittest_pb2.TestAllExtensions()
1389    ext2 = proto2.Extensions[
1390        unittest_pb2.repeated_nested_message_extension]
1391    m = ext2.add()
1392    m.bb = 111
1393
1394    proto2.MergeFrom(proto1)
1395    ext2 = proto2.Extensions[
1396        unittest_pb2.repeated_nested_message_extension]
1397    self.assertEqual(3, len(ext2))
1398    self.assertEqual(111, ext2[0].bb)
1399    self.assertEqual(222, ext2[1].bb)
1400    self.assertEqual(333, ext2[2].bb)
1401
1402  def testMergeFromBug(self):
1403    message1 = unittest_pb2.TestAllTypes()
1404    message2 = unittest_pb2.TestAllTypes()
1405
1406    # Cause optional_nested_message to be instantiated within message1, even
1407    # though it is not considered to be "present".
1408    message1.optional_nested_message
1409    self.assertFalse(message1.HasField('optional_nested_message'))
1410
1411    # Merge into message2.  This should not instantiate the field is message2.
1412    message2.MergeFrom(message1)
1413    self.assertFalse(message2.HasField('optional_nested_message'))
1414
1415  def testCopyFromSingularField(self):
1416    # Test copy with just a singular field.
1417    proto1 = unittest_pb2.TestAllTypes()
1418    proto1.optional_int32 = 1
1419    proto1.optional_string = 'important-text'
1420
1421    proto2 = unittest_pb2.TestAllTypes()
1422    proto2.optional_string = 'value'
1423
1424    proto2.CopyFrom(proto1)
1425    self.assertEqual(1, proto2.optional_int32)
1426    self.assertEqual('important-text', proto2.optional_string)
1427
1428  def testCopyFromRepeatedField(self):
1429    # Test copy with a repeated field.
1430    proto1 = unittest_pb2.TestAllTypes()
1431    proto1.repeated_int32.append(1)
1432    proto1.repeated_int32.append(2)
1433
1434    proto2 = unittest_pb2.TestAllTypes()
1435    proto2.repeated_int32.append(0)
1436    proto2.CopyFrom(proto1)
1437
1438    self.assertEqual(1, proto2.repeated_int32[0])
1439    self.assertEqual(2, proto2.repeated_int32[1])
1440
1441  def testCopyFromAllFields(self):
1442    # With all fields set.
1443    proto1 = unittest_pb2.TestAllTypes()
1444    test_util.SetAllFields(proto1)
1445    proto2 = unittest_pb2.TestAllTypes()
1446    proto2.CopyFrom(proto1)
1447
1448    # Messages should be equal.
1449    self.assertEqual(proto2, proto1)
1450
1451    # Serialized string should be equal too.
1452    string1 = proto1.SerializeToString()
1453    string2 = proto2.SerializeToString()
1454    self.assertEqual(string1, string2)
1455
1456  def testCopyFromSelf(self):
1457    proto1 = unittest_pb2.TestAllTypes()
1458    proto1.repeated_int32.append(1)
1459    proto1.optional_int32 = 2
1460    proto1.optional_string = 'important-text'
1461
1462    proto1.CopyFrom(proto1)
1463    self.assertEqual(1, proto1.repeated_int32[0])
1464    self.assertEqual(2, proto1.optional_int32)
1465    self.assertEqual('important-text', proto1.optional_string)
1466
1467  def testCopyFromBadType(self):
1468    # The python implementation doesn't raise an exception in this
1469    # case. In theory it should.
1470    if api_implementation.Type() == 'python':
1471      return
1472    proto1 = unittest_pb2.TestAllTypes()
1473    proto2 = unittest_pb2.TestAllExtensions()
1474    self.assertRaises(TypeError, proto1.CopyFrom, proto2)
1475
1476  def testDeepCopy(self):
1477    proto1 = unittest_pb2.TestAllTypes()
1478    proto1.optional_int32 = 1
1479    proto2 = copy.deepcopy(proto1)
1480    self.assertEqual(1, proto2.optional_int32)
1481
1482    proto1.repeated_int32.append(2)
1483    proto1.repeated_int32.append(3)
1484    container = copy.deepcopy(proto1.repeated_int32)
1485    self.assertEqual([2, 3], container)
1486
1487    # TODO(anuraag): Implement deepcopy for repeated composite / extension dict
1488
1489  def testClear(self):
1490    proto = unittest_pb2.TestAllTypes()
1491    # C++ implementation does not support lazy fields right now so leave it
1492    # out for now.
1493    if api_implementation.Type() == 'python':
1494      test_util.SetAllFields(proto)
1495    else:
1496      test_util.SetAllNonLazyFields(proto)
1497    # Clear the message.
1498    proto.Clear()
1499    self.assertEquals(proto.ByteSize(), 0)
1500    empty_proto = unittest_pb2.TestAllTypes()
1501    self.assertEquals(proto, empty_proto)
1502
1503    # Test if extensions which were set are cleared.
1504    proto = unittest_pb2.TestAllExtensions()
1505    test_util.SetAllExtensions(proto)
1506    # Clear the message.
1507    proto.Clear()
1508    self.assertEquals(proto.ByteSize(), 0)
1509    empty_proto = unittest_pb2.TestAllExtensions()
1510    self.assertEquals(proto, empty_proto)
1511
1512  def testDisconnectingBeforeClear(self):
1513    proto = unittest_pb2.TestAllTypes()
1514    nested = proto.optional_nested_message
1515    proto.Clear()
1516    self.assertTrue(nested is not proto.optional_nested_message)
1517    nested.bb = 23
1518    self.assertTrue(not proto.HasField('optional_nested_message'))
1519    self.assertEqual(0, proto.optional_nested_message.bb)
1520
1521    proto = unittest_pb2.TestAllTypes()
1522    nested = proto.optional_nested_message
1523    nested.bb = 5
1524    foreign = proto.optional_foreign_message
1525    foreign.c = 6
1526
1527    proto.Clear()
1528    self.assertTrue(nested is not proto.optional_nested_message)
1529    self.assertTrue(foreign is not proto.optional_foreign_message)
1530    self.assertEqual(5, nested.bb)
1531    self.assertEqual(6, foreign.c)
1532    nested.bb = 15
1533    foreign.c = 16
1534    self.assertFalse(proto.HasField('optional_nested_message'))
1535    self.assertEqual(0, proto.optional_nested_message.bb)
1536    self.assertFalse(proto.HasField('optional_foreign_message'))
1537    self.assertEqual(0, proto.optional_foreign_message.c)
1538
1539  def testOneOf(self):
1540    proto = unittest_pb2.TestAllTypes()
1541    proto.oneof_uint32 = 10
1542    proto.oneof_nested_message.bb = 11
1543    self.assertEqual(11, proto.oneof_nested_message.bb)
1544    self.assertFalse(proto.HasField('oneof_uint32'))
1545    nested = proto.oneof_nested_message
1546    proto.oneof_string = 'abc'
1547    self.assertEqual('abc', proto.oneof_string)
1548    self.assertEqual(11, nested.bb)
1549    self.assertFalse(proto.HasField('oneof_nested_message'))
1550
1551  def assertInitialized(self, proto):
1552    self.assertTrue(proto.IsInitialized())
1553    # Neither method should raise an exception.
1554    proto.SerializeToString()
1555    proto.SerializePartialToString()
1556
1557  def assertNotInitialized(self, proto):
1558    self.assertFalse(proto.IsInitialized())
1559    self.assertRaises(message.EncodeError, proto.SerializeToString)
1560    # "Partial" serialization doesn't care if message is uninitialized.
1561    proto.SerializePartialToString()
1562
  def testIsInitialized(self):
    """IsInitialized() tracks required fields, recursively and in extensions."""
    # Trivial cases - all optional fields and extensions.
    proto = unittest_pb2.TestAllTypes()
    self.assertInitialized(proto)
    proto = unittest_pb2.TestAllExtensions()
    self.assertInitialized(proto)

    # The case of uninitialized required fields.
    proto = unittest_pb2.TestRequired()
    self.assertNotInitialized(proto)
    proto.a = proto.b = proto.c = 2
    self.assertInitialized(proto)

    # The case of uninitialized submessage.  An absent submessage does not
    # count; touching one of its fields makes it present and then all of its
    # required fields must be set.
    proto = unittest_pb2.TestRequiredForeign()
    self.assertInitialized(proto)
    proto.optional_message.a = 1
    self.assertNotInitialized(proto)
    proto.optional_message.b = 0
    proto.optional_message.c = 0
    self.assertInitialized(proto)

    # Uninitialized repeated submessage.  This reuses the TestRequiredForeign
    # proto from the previous section, whose optional_message is now complete.
    message1 = proto.repeated_message.add()
    self.assertNotInitialized(proto)
    message1.a = message1.b = message1.c = 0
    self.assertInitialized(proto)

    # Uninitialized repeated group in an extension.  Every element must be
    # complete before the whole message counts as initialized.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.multi
    message1 = proto.Extensions[extension].add()
    message2 = proto.Extensions[extension].add()
    self.assertNotInitialized(proto)
    message1.a = 1
    message1.b = 1
    message1.c = 1
    self.assertNotInitialized(proto)
    message2.a = 2
    message2.b = 2
    message2.c = 2
    self.assertInitialized(proto)

    # Uninitialized nonrepeated message in an extension.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.single
    proto.Extensions[extension].a = 1
    self.assertNotInitialized(proto)
    proto.Extensions[extension].b = 2
    proto.Extensions[extension].c = 3
    self.assertInitialized(proto)

    # Try passing an errors list; IsInitialized() fills it with the names of
    # the missing required fields.
    errors = []
    proto = unittest_pb2.TestRequired()
    self.assertFalse(proto.IsInitialized(errors))
    self.assertEqual(errors, ['a', 'b', 'c'])
1620
1621  @basetest.unittest.skipIf(
1622      api_implementation.Type() != 'cpp' or api_implementation.Version() != 2,
1623      'Errors are only available from the most recent C++ implementation.')
1624  def testFileDescriptorErrors(self):
1625    file_name = 'test_file_descriptor_errors.proto'
1626    package_name = 'test_file_descriptor_errors.proto'
1627    file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
1628    file_descriptor_proto.name = file_name
1629    file_descriptor_proto.package = package_name
1630    m1 = file_descriptor_proto.message_type.add()
1631    m1.name = 'msg1'
1632    # Compiles the proto into the C++ descriptor pool
1633    descriptor.FileDescriptor(
1634        file_name,
1635        package_name,
1636        serialized_pb=file_descriptor_proto.SerializeToString())
1637    # Add a FileDescriptorProto that has duplicate symbols
1638    another_file_name = 'another_test_file_descriptor_errors.proto'
1639    file_descriptor_proto.name = another_file_name
1640    m2 = file_descriptor_proto.message_type.add()
1641    m2.name = 'msg2'
1642    with self.assertRaises(TypeError) as cm:
1643      descriptor.FileDescriptor(
1644          another_file_name,
1645          package_name,
1646          serialized_pb=file_descriptor_proto.SerializeToString())
1647      self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
1648                      getattr(cm.expected, '__name__', cm.expected))
1649      self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
1650      # Error message will say something about this definition being a
1651      # duplicate, though we don't check the message exactly to avoid a
1652      # dependency on the C++ logging code.
1653      self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
1654
  def testStringUTF8Encoding(self):
    """Checks the type and UTF-8 validation rules for string/bytes fields."""
    proto = unittest_pb2.TestAllTypes()

    # Assignment of a unicode object to a field of type 'bytes' is not allowed.
    self.assertRaises(TypeError,
                      setattr, proto, 'optional_bytes', u'unicode object')

    # Check that the default value is of python's 'unicode' type.
    self.assertEqual(type(proto.optional_string), unicode)

    proto.optional_string = unicode('Testing')
    self.assertEqual(proto.optional_string, str('Testing'))

    # Assign a value of type 'str' which can be encoded in UTF-8.
    proto.optional_string = str('Testing')
    self.assertEqual(proto.optional_string, unicode('Testing'))

    # Try to assign a bytes value that is not valid UTF-8; it must be rejected
    # with ValueError.
    self.assertRaises(ValueError,
                      setattr, proto, 'optional_string', b'a\x80a')
    if str is bytes:  # PY2
      # Assign a 'str' object which contains a UTF-8 encoded string.  Under
      # Python 2 this non-ASCII byte string is rejected with ValueError.
      self.assertRaises(ValueError,
                        setattr, proto, 'optional_string', 'Тест')
    else:
      # Under Python 3 the same literal is a text string and is accepted.
      proto.optional_string = 'Тест'
    # No exception thrown: plain ASCII text is always accepted.
    proto.optional_string = 'abc'
1683
  def testStringUTF8Serialization(self):
    """Round-trips a non-ASCII string through the MessageSet wire format."""
    proto = unittest_mset_pb2.TestMessageSet()
    extension_message = unittest_mset_pb2.TestMessageSetExtension2
    extension = extension_message.message_set_extension

    test_utf8 = u'Тест'
    test_utf8_bytes = test_utf8.encode('utf-8')

    # 'Test' in another language, using UTF-8 charset.
    proto.Extensions[extension].str = test_utf8

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    # Check byte size.
    self.assertEqual(proto.ByteSize(), len(serialized))

    # Re-parse as a RawMessageSet so the individual items can be inspected.
    raw = unittest_mset_pb2.RawMessageSet()
    bytes_read = raw.MergeFromString(serialized)
    self.assertEqual(len(serialized), bytes_read)

    message2 = unittest_mset_pb2.TestMessageSetExtension2()

    self.assertEqual(1, len(raw.item))
    # Check that the type_id is the same as the tag ID in the .proto file.
    self.assertEqual(raw.item[0].type_id, 1547769)

    # Check the actual bytes on the wire.
    self.assertTrue(
        raw.item[0].message.endswith(test_utf8_bytes))
    bytes_read = message2.MergeFromString(raw.item[0].message)
    self.assertEqual(len(raw.item[0].message), bytes_read)

    self.assertEqual(type(message2.str), unicode)
    self.assertEqual(message2.str, test_utf8)

    # The pure Python API throws an exception on MergeFromString(),
    # if any of the string fields of the message can't be UTF-8 decoded.
    # The C++ implementation of the API has no way to check that on
    # MergeFromString and thus has no way to throw the exception.
    #
    # The pure Python API always returns objects of type 'unicode' (UTF-8
    # encoded), or 'bytes' (in 7 bit ASCII).
    #
    # Corrupt the payload by overwriting the string's UTF-8 bytes with 0xff,
    # a byte value that never appears in valid UTF-8.
    badbytes = raw.item[0].message.replace(
        test_utf8_bytes, len(test_utf8_bytes) * b'\xff')

    # Accept either outcome: MergeFromString raising UnicodeDecodeError, or
    # the string field reading back as a bytes object.
    unicode_decode_failed = False
    try:
      message2.MergeFromString(badbytes)
    except UnicodeDecodeError:
      unicode_decode_failed = True
    string_field = message2.str
    self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
1738
1739  def testBytesInTextFormat(self):
1740    proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
1741    self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n',
1742                     unicode(proto))
1743
1744  def testEmptyNestedMessage(self):
1745    proto = unittest_pb2.TestAllTypes()
1746    proto.optional_nested_message.MergeFrom(
1747        unittest_pb2.TestAllTypes.NestedMessage())
1748    self.assertTrue(proto.HasField('optional_nested_message'))
1749
1750    proto = unittest_pb2.TestAllTypes()
1751    proto.optional_nested_message.CopyFrom(
1752        unittest_pb2.TestAllTypes.NestedMessage())
1753    self.assertTrue(proto.HasField('optional_nested_message'))
1754
1755    proto = unittest_pb2.TestAllTypes()
1756    bytes_read = proto.optional_nested_message.MergeFromString(b'')
1757    self.assertEqual(0, bytes_read)
1758    self.assertTrue(proto.HasField('optional_nested_message'))
1759
1760    proto = unittest_pb2.TestAllTypes()
1761    proto.optional_nested_message.ParseFromString(b'')
1762    self.assertTrue(proto.HasField('optional_nested_message'))
1763
1764    serialized = proto.SerializeToString()
1765    proto2 = unittest_pb2.TestAllTypes()
1766    self.assertEqual(
1767        len(serialized),
1768        proto2.MergeFromString(serialized))
1769    self.assertTrue(proto2.HasField('optional_nested_message'))
1770
1771  def testSetInParent(self):
1772    proto = unittest_pb2.TestAllTypes()
1773    self.assertFalse(proto.HasField('optionalgroup'))
1774    proto.optionalgroup.SetInParent()
1775    self.assertTrue(proto.HasField('optionalgroup'))
1776
1777
1778#  Since we had so many tests for protocol buffer equality, we broke these out
1779#  into separate TestCase classes.
1780
1781
class TestAllTypesEqualityTest(basetest.TestCase):

  """Equality and hashing of empty TestAllTypes messages."""

  def setUp(self):
    self.first_proto, self.second_proto = (
        unittest_pb2.TestAllTypes(), unittest_pb2.TestAllTypes())

  def testNotHashable(self):
    # Mutable messages must refuse to be hashed.
    with self.assertRaises(TypeError):
      hash(self.first_proto)

  def testSelfEquality(self):
    proto = self.first_proto
    self.assertEqual(proto, proto)

  def testEmptyProtosEqual(self):
    # Two freshly constructed messages compare equal.
    self.assertEqual(self.first_proto, self.second_proto)
1796
1797
class FullProtosEqualityTest(basetest.TestCase):

  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    # Two independently built TestAllTypes messages with every field filled in
    # by test_util.SetAllFields, so they start out equal.
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.first_proto)
    test_util.SetAllFields(self.second_proto)

  def testNotHashable(self):
    # Populated messages are not hashable either.
    self.assertRaises(TypeError, hash, self.first_proto)

  def testNoneNotEqual(self):
    # Comparison against None is unequal from either side.
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    # A message of another type (an empty TestRequired) is unequal from
    # either side.
    third_proto = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, third_proto)
    self.assertNotEqual(third_proto, self.second_proto)

  def testAllFieldsFilledEquality(self):
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    # Nonrepeated scalar field change should cause inequality.
    self.first_proto.optional_int32 += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    # ...as should clearing a field.
    self.first_proto.ClearField('optional_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedComposite(self):
    # Change a nonrepeated composite field.
    self.first_proto.optional_nested_message.bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Clear a field in the nested message.
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb = (
        self.second_proto.optional_nested_message.bb)
    self.assertEqual(self.first_proto, self.second_proto)
    # Remove the nested message entirely.
    self.first_proto.ClearField('optional_nested_message')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedScalar(self):
    # Change a repeated scalar field.
    self.first_proto.repeated_int32.append(5)
    self.assertNotEqual(self.first_proto, self.second_proto)
    # Clearing the repeated field entirely also breaks equality.
    self.first_proto.ClearField('repeated_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedComposite(self):
    # Change value within a repeated composite field.
    self.first_proto.repeated_nested_message[0].bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.repeated_nested_message[0].bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Add a value to a repeated composite field.
    self.first_proto.repeated_nested_message.add()
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.second_proto.repeated_nested_message.add()
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalarHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated scalar field.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated composite field.
    self.first_proto.ClearField('optional_nested_message')
    self.second_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    # Setting and then clearing bb makes first_proto's sub-message present but
    # empty, which matches second_proto.
    self.first_proto.optional_nested_message.bb = 0
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertEqual(self.first_proto, self.second_proto)
1882
1883
class ExtensionEqualityTest(basetest.TestCase):

  """Equality of messages that carry extensions."""

  def testExtensionEquality(self):
    int32_ext = unittest_pb2.optional_int32_extension

    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(first_proto, second_proto)
    test_util.SetAllExtensions(first_proto)
    self.assertNotEqual(first_proto, second_proto)
    test_util.SetAllExtensions(second_proto)
    self.assertEqual(first_proto, second_proto)

    # Extension values are compared, not only which extensions are set.
    first_proto.Extensions[int32_ext] += 1
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[int32_ext] -= 1
    self.assertEqual(first_proto, second_proto)

    # Presence matters too: a cleared extension differs from one set to 0.
    first_proto.ClearExtension(int32_ext)
    second_proto.Extensions[int32_ext] = 0
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[int32_ext] = 0
    self.assertEqual(first_proto, second_proto)

    # Merely reading an unset extension (caching its default) must not affect
    # equality while neither side has it set.
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, first_proto.Extensions[int32_ext])
    self.assertEqual(first_proto, second_proto)
1915
1916
class MutualRecursionEqualityTest(basetest.TestCase):

  """Equality of mutually recursive message types."""

  def testEqualityWithMutualRecursion(self):
    lhs = unittest_pb2.TestMutualRecursionA()
    rhs = unittest_pb2.TestMutualRecursionA()
    self.assertEqual(lhs, rhs)
    # A difference deep inside the recursion breaks equality...
    lhs.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(lhs, rhs)
    # ...and mirroring it on the other side restores it.
    rhs.bb.a.bb.optional_int32 = 23
    self.assertEqual(lhs, rhs)
1927
1928
1929class ByteSizeTest(basetest.TestCase):
1930
1931  def setUp(self):
1932    self.proto = unittest_pb2.TestAllTypes()
1933    self.extended_proto = more_extensions_pb2.ExtendedMessage()
1934    self.packed_proto = unittest_pb2.TestPackedTypes()
1935    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()
1936
1937  def Size(self):
1938    return self.proto.ByteSize()
1939
1940  def testEmptyMessage(self):
1941    self.assertEqual(0, self.proto.ByteSize())
1942
1943  def testSizedOnKwargs(self):
1944    # Use a separate message to ensure testing right after creation.
1945    proto = unittest_pb2.TestAllTypes()
1946    self.assertEqual(0, proto.ByteSize())
1947    proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1)
1948    # One byte for the tag, one to encode varint 1.
1949    self.assertEqual(2, proto_kwargs.ByteSize())
1950
1951  def testVarints(self):
1952    def Test(i, expected_varint_size):
1953      self.proto.Clear()
1954      self.proto.optional_int64 = i
1955      # Add one to the varint size for the tag info
1956      # for tag 1.
1957      self.assertEqual(expected_varint_size + 1, self.Size())
1958    Test(0, 1)
1959    Test(1, 1)
1960    for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)):
1961      Test((1 << i) - 1, num_bytes)
1962    Test(-1, 10)
1963    Test(-2, 10)
1964    Test(-(1 << 63), 10)
1965
1966  def testStrings(self):
1967    self.proto.optional_string = ''
1968    # Need one byte for tag info (tag #14), and one byte for length.
1969    self.assertEqual(2, self.Size())
1970
1971    self.proto.optional_string = 'abc'
1972    # Need one byte for tag info (tag #14), and one byte for length.
1973    self.assertEqual(2 + len(self.proto.optional_string), self.Size())
1974
1975    self.proto.optional_string = 'x' * 128
1976    # Need one byte for tag info (tag #14), and TWO bytes for length.
1977    self.assertEqual(3 + len(self.proto.optional_string), self.Size())
1978
1979  def testOtherNumerics(self):
1980    self.proto.optional_fixed32 = 1234
1981    # One byte for tag and 4 bytes for fixed32.
1982    self.assertEqual(5, self.Size())
1983    self.proto = unittest_pb2.TestAllTypes()
1984
1985    self.proto.optional_fixed64 = 1234
1986    # One byte for tag and 8 bytes for fixed64.
1987    self.assertEqual(9, self.Size())
1988    self.proto = unittest_pb2.TestAllTypes()
1989
1990    self.proto.optional_float = 1.234
1991    # One byte for tag and 4 bytes for float.
1992    self.assertEqual(5, self.Size())
1993    self.proto = unittest_pb2.TestAllTypes()
1994
1995    self.proto.optional_double = 1.234
1996    # One byte for tag and 8 bytes for float.
1997    self.assertEqual(9, self.Size())
1998    self.proto = unittest_pb2.TestAllTypes()
1999
2000    self.proto.optional_sint32 = 64
2001    # One byte for tag and 2 bytes for zig-zag-encoded 64.
2002    self.assertEqual(3, self.Size())
2003    self.proto = unittest_pb2.TestAllTypes()
2004
2005  def testComposites(self):
2006    # 3 bytes.
2007    self.proto.optional_nested_message.bb = (1 << 14)
2008    # Plus one byte for bb tag.
2009    # Plus 1 byte for optional_nested_message serialized size.
2010    # Plus two bytes for optional_nested_message tag.
2011    self.assertEqual(3 + 1 + 1 + 2, self.Size())
2012
2013  def testGroups(self):
2014    # 4 bytes.
2015    self.proto.optionalgroup.a = (1 << 21)
2016    # Plus two bytes for |a| tag.
2017    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
2018    self.assertEqual(4 + 2 + 2*2, self.Size())
2019
2020  def testRepeatedScalars(self):
2021    self.proto.repeated_int32.append(10)  # 1 byte.
2022    self.proto.repeated_int32.append(128)  # 2 bytes.
2023    # Also need 2 bytes for each entry for tag.
2024    self.assertEqual(1 + 2 + 2*2, self.Size())
2025
2026  def testRepeatedScalarsExtend(self):
2027    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
2028    # Also need 2 bytes for each entry for tag.
2029    self.assertEqual(1 + 2 + 2*2, self.Size())
2030
2031  def testRepeatedScalarsRemove(self):
2032    self.proto.repeated_int32.append(10)  # 1 byte.
2033    self.proto.repeated_int32.append(128)  # 2 bytes.
2034    # Also need 2 bytes for each entry for tag.
2035    self.assertEqual(1 + 2 + 2*2, self.Size())
2036    self.proto.repeated_int32.remove(128)
2037    self.assertEqual(1 + 2, self.Size())
2038
2039  def testRepeatedComposites(self):
2040    # Empty message.  2 bytes tag plus 1 byte length.
2041    foreign_message_0 = self.proto.repeated_nested_message.add()
2042    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2043    foreign_message_1 = self.proto.repeated_nested_message.add()
2044    foreign_message_1.bb = 7
2045    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
2046
2047  def testRepeatedCompositesDelete(self):
2048    # Empty message.  2 bytes tag plus 1 byte length.
2049    foreign_message_0 = self.proto.repeated_nested_message.add()
2050    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2051    foreign_message_1 = self.proto.repeated_nested_message.add()
2052    foreign_message_1.bb = 9
2053    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
2054
2055    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2056    del self.proto.repeated_nested_message[0]
2057    self.assertEqual(2 + 1 + 1 + 1, self.Size())
2058
2059    # Now add a new message.
2060    foreign_message_2 = self.proto.repeated_nested_message.add()
2061    foreign_message_2.bb = 12
2062
2063    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2064    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2065    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())
2066
2067    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2068    del self.proto.repeated_nested_message[1]
2069    self.assertEqual(2 + 1 + 1 + 1, self.Size())
2070
2071    del self.proto.repeated_nested_message[0]
2072    self.assertEqual(0, self.Size())
2073
2074  def testRepeatedGroups(self):
2075    # 2-byte START_GROUP plus 2-byte END_GROUP.
2076    group_0 = self.proto.repeatedgroup.add()
2077    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
2078    # plus 2-byte END_GROUP.
2079    group_1 = self.proto.repeatedgroup.add()
2080    group_1.a =  7
2081    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())
2082
2083  def testExtensions(self):
2084    proto = unittest_pb2.TestAllExtensions()
2085    self.assertEqual(0, proto.ByteSize())
2086    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
2087    proto.Extensions[extension] = 23
2088    # 1 byte for tag, 1 byte for value.
2089    self.assertEqual(2, proto.ByteSize())
2090
2091  def testCacheInvalidationForNonrepeatedScalar(self):
2092    # Test non-extension.
2093    self.proto.optional_int32 = 1
2094    self.assertEqual(2, self.proto.ByteSize())
2095    self.proto.optional_int32 = 128
2096    self.assertEqual(3, self.proto.ByteSize())
2097    self.proto.ClearField('optional_int32')
2098    self.assertEqual(0, self.proto.ByteSize())
2099
2100    # Test within extension.
2101    extension = more_extensions_pb2.optional_int_extension
2102    self.extended_proto.Extensions[extension] = 1
2103    self.assertEqual(2, self.extended_proto.ByteSize())
2104    self.extended_proto.Extensions[extension] = 128
2105    self.assertEqual(3, self.extended_proto.ByteSize())
2106    self.extended_proto.ClearExtension(extension)
2107    self.assertEqual(0, self.extended_proto.ByteSize())
2108
  def testCacheInvalidationForRepeatedScalar(self):
    """Cached sizes refresh when repeated scalars are added, set or cleared."""
    # Test non-extension.  Each repeated_int32 entry costs a 2-byte tag plus
    # its varint (1 byte for 1, 2 bytes for 128).
    self.proto.repeated_int32.append(1)
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_int32.append(1)
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.repeated_int32[1] = 128
    self.assertEqual(7, self.proto.ByteSize())
    self.proto.ClearField('repeated_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.  Here each entry costs one tag byte plus its
    # varint.  The container is fetched once and mutated in place.
    extension = more_extensions_pb2.repeated_int_extension
    repeated = self.extended_proto.Extensions[extension]
    repeated.append(1)
    self.assertEqual(2, self.extended_proto.ByteSize())
    repeated.append(1)
    self.assertEqual(4, self.extended_proto.ByteSize())
    repeated[1] = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())
2131
  def testCacheInvalidationForNonrepeatedMessage(self):
    """Cached sizes refresh when a singular sub-message changes or clears."""
    # Test non-extension.
    self.proto.optional_foreign_message.c = 1
    self.assertEqual(5, self.proto.ByteSize())
    self.proto.optional_foreign_message.c = 128
    self.assertEqual(6, self.proto.ByteSize())
    # Clearing the only field leaves an empty but still-present sub-message,
    # which still costs its tag and length prefix.
    self.proto.optional_foreign_message.ClearField('c')
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    if api_implementation.Type() == 'python':
      # This is only possible in pure-Python implementation of the API.
      # A child detached by ClearField no longer contributes to the parent's
      # size, even when written to afterwards.
      child = self.proto.optional_foreign_message
      self.proto.ClearField('optional_foreign_message')
      child.c = 128
      self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.  Merely obtaining the child does not make it
    # present, so the size stays 0 until a field is set on it.
    extension = more_extensions_pb2.optional_message_extension
    child = self.extended_proto.Extensions[extension]
    self.assertEqual(0, self.extended_proto.ByteSize())
    child.foreign_message_int = 1
    self.assertEqual(4, self.extended_proto.ByteSize())
    child.foreign_message_int = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())
2160
2161  def testCacheInvalidationForRepeatedMessage(self):
2162    # Test non-extension.
2163    child0 = self.proto.repeated_foreign_message.add()
2164    self.assertEqual(3, self.proto.ByteSize())
2165    self.proto.repeated_foreign_message.add()
2166    self.assertEqual(6, self.proto.ByteSize())
2167    child0.c = 1
2168    self.assertEqual(8, self.proto.ByteSize())
2169    self.proto.ClearField('repeated_foreign_message')
2170    self.assertEqual(0, self.proto.ByteSize())
2171
2172    # Test within extension.
2173    extension = more_extensions_pb2.repeated_message_extension
2174    child_list = self.extended_proto.Extensions[extension]
2175    child0 = child_list.add()
2176    self.assertEqual(2, self.extended_proto.ByteSize())
2177    child_list.add()
2178    self.assertEqual(4, self.extended_proto.ByteSize())
2179    child0.foreign_message_int = 1
2180    self.assertEqual(6, self.extended_proto.ByteSize())
2181    child0.ClearField('foreign_message_int')
2182    self.assertEqual(4, self.extended_proto.ByteSize())
2183    self.extended_proto.ClearExtension(extension)
2184    self.assertEqual(0, self.extended_proto.ByteSize())
2185
2186  def testPackedRepeatedScalars(self):
2187    self.assertEqual(0, self.packed_proto.ByteSize())
2188
2189    self.packed_proto.packed_int32.append(10)   # 1 byte.
2190    self.packed_proto.packed_int32.append(128)  # 2 bytes.
2191    # The tag is 2 bytes (the field number is 90), and the varint
2192    # storing the length is 1 byte.
2193    int_size = 1 + 2 + 3
2194    self.assertEqual(int_size, self.packed_proto.ByteSize())
2195
2196    self.packed_proto.packed_double.append(4.2)   # 8 bytes
2197    self.packed_proto.packed_double.append(3.25)  # 8 bytes
2198    # 2 more tag bytes, 1 more length byte.
2199    double_size = 8 + 8 + 3
2200    self.assertEqual(int_size+double_size, self.packed_proto.ByteSize())
2201
2202    self.packed_proto.ClearField('packed_int32')
2203    self.assertEqual(double_size, self.packed_proto.ByteSize())
2204
2205  def testPackedExtensions(self):
2206    self.assertEqual(0, self.packed_extended_proto.ByteSize())
2207    extension = self.packed_extended_proto.Extensions[
2208        unittest_pb2.packed_fixed32_extension]
2209    extension.extend([1, 2, 3, 4])   # 16 bytes
2210    # Tag is 3 bytes.
2211    self.assertEqual(19, self.packed_extended_proto.ByteSize())
2212
2213
2214# Issues to be sure to cover include:
2215#   * Handling of unrecognized tags ("uninterpreted_bytes").
2216#   * Handling of MessageSets.
2217#   * Consistent ordering of tags in the wire format,
2218#     including ordering between extensions and non-extension
2219#     fields.
2220#   * Consistent serialization of negative numbers, especially
2221#     negative int32s.
2222#   * Handling of empty submessages (with and without "has"
2223#     bits set).
2224
class SerializationTest(basetest.TestCase):
  """Tests for serializing and parsing protocol messages.

  Covers round trips (including ByteSize() agreeing with the serialized
  length), truncated input, canonical field ordering, the MessageSet and
  packed wire formats, required-field checks, field-number constants and
  keyword-argument construction.
  """

  def testSerializeEmtpyMessage(self):
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllFields(self):
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllExtensions(self):
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    # As for regular fields, ByteSize() must match the serialized length.
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeWithOptionalGroup(self):
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    first_proto.optionalgroup.a = 242
    serialized = first_proto.SerializeToString()
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeNegativeValues(self):
    first_proto = unittest_pb2.TestAllTypes()

    first_proto.optional_int32 = -1
    first_proto.optional_int64 = -(2 << 40)
    first_proto.optional_sint32 = -3
    first_proto.optional_sint64 = -(4 << 40)
    first_proto.optional_sfixed32 = -5
    first_proto.optional_sfixed64 = -(6 << 40)

    second_proto = unittest_pb2.TestAllTypes.FromString(
        first_proto.SerializeToString())

    self.assertEqual(first_proto, second_proto)

  def testParseTruncated(self):
    # This test is only applicable for the Python implementation of the API;
    # it relies on the private _InternalParse() method.
    if api_implementation.Type() != 'python':
      return

    first_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()

    # Try every prefix length, from empty up to the complete buffer.
    for truncation_point in xrange(len(serialized) + 1):
      try:
        second_proto = unittest_pb2.TestAllTypes()
        unknown_fields = unittest_pb2.TestEmptyMessage()
        pos = second_proto._InternalParse(serialized, 0, truncation_point)
        # If we didn't raise an error then we read exactly the amount expected.
        self.assertEqual(truncation_point, pos)

        # Parsing to unknown fields should not throw if parsing to known fields
        # did not.
        try:
          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
          self.assertEqual(truncation_point, pos2)
        except message.DecodeError:
          self.fail('Parsing unknown fields failed when parsing known fields '
                    'did not.')
      except message.DecodeError:
        # Parsing unknown fields should also fail.
        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
                          serialized, 0, truncation_point)

  def testCanonicalSerializationOrder(self):
    proto = more_messages_pb2.OutOfOrderFields()
    # These are also their tag numbers.  Even though we're setting these in
    # reverse-tag order AND they're listed in reverse tag-order in the .proto
    # file, they should nonetheless be serialized in tag order.
    proto.optional_sint32 = 5
    proto.Extensions[more_messages_pb2.optional_uint64] = 4
    proto.optional_uint32 = 3
    proto.Extensions[more_messages_pb2.optional_int64] = 2
    proto.optional_int32 = 1
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(2, d.ReadInt64())
    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(3, d.ReadUInt32())
    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(4, d.ReadUInt64())
    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(5, d.ReadSInt32())

  def testCanonicalSerializationOrderSameAsCpp(self):
    # Copy of the same test we use for C++.
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    serialized = proto.SerializeToString()
    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)

  def testMergeFromStringWhenFieldsAlreadySet(self):
    first_proto = unittest_pb2.TestAllTypes()
    first_proto.repeated_string.append('foobar')
    first_proto.optional_int32 = 23
    first_proto.optional_nested_message.bb = 42
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestAllTypes()
    second_proto.repeated_string.append('baz')
    second_proto.optional_int32 = 100
    second_proto.optional_nested_message.bb = 999

    bytes_parsed = second_proto.MergeFromString(serialized)
    self.assertEqual(len(serialized), bytes_parsed)

    # Ensure that we append to repeated fields.
    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
    # Ensure that we overwrite non-repeated scalars.
    self.assertEqual(23, second_proto.optional_int32)
    # Ensure that we recursively call MergeFromString() on
    # submessages.
    self.assertEqual(42, second_proto.optional_nested_message.bb)

  def testMessageSetWireFormat(self):
    proto = unittest_mset_pb2.TestMessageSet()
    extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
    extension_message2 = unittest_mset_pb2.TestMessageSetExtension2
    extension1 = extension_message1.message_set_extension
    extension2 = extension_message2.message_set_extension
    proto.Extensions[extension1].i = 123
    proto.Extensions[extension2].str = 'foo'

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    raw = unittest_mset_pb2.RawMessageSet()
    self.assertEqual(False,
                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
    self.assertEqual(
        len(serialized),
        raw.MergeFromString(serialized))
    self.assertEqual(2, len(raw.item))

    message1 = unittest_mset_pb2.TestMessageSetExtension1()
    self.assertEqual(
        len(raw.item[0].message),
        message1.MergeFromString(raw.item[0].message))
    self.assertEqual(123, message1.i)

    message2 = unittest_mset_pb2.TestMessageSetExtension2()
    self.assertEqual(
        len(raw.item[1].message),
        message2.MergeFromString(raw.item[1].message))
    self.assertEqual('foo', message2.str)

    # Deserialize using the MessageSet wire format.
    proto2 = unittest_mset_pb2.TestMessageSet()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(123, proto2.Extensions[extension1].i)
    self.assertEqual('foo', proto2.Extensions[extension2].str)

    # Check byte size.
    self.assertEqual(proto2.ByteSize(), len(serialized))
    self.assertEqual(proto.ByteSize(), len(serialized))

  def testMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message. RawMessageSet does not itself use the MessageSet wire format
    # (see testMessageSetWireFormat), so its items can be written directly.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an item. The final assertion relies on type_id 1545008 resolving
    # to TestMessageSetExtension1.
    item = raw.item.add()
    item.type_id = 1545008
    message1 = unittest_mset_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    # Add a second, unknown extension.
    item = raw.item.add()
    item.type_id = 1545009
    message1 = unittest_mset_pb2.TestMessageSetExtension1()
    message1.i = 12346
    item.message = message1.SerializeToString()

    # Add another unknown extension.
    item = raw.item.add()
    item.type_id = 1545010
    message2 = unittest_mset_pb2.TestMessageSetExtension2()
    message2.str = 'foo'
    item.message = message2.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format; the unknown items must
    # not stop the whole buffer from being consumed.
    proto = unittest_mset_pb2.TestMessageSet()
    self.assertEqual(
        len(serialized),
        proto.MergeFromString(serialized))

    # Check that the known extension parsed well.
    extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
    extension1 = extension_message1.message_set_extension
    self.assertEqual(12345, proto.Extensions[extension1].i)

  def testUnknownFields(self):
    proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(proto)

    serialized = proto.SerializeToString()

    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()

    # Parsing this message should succeed.
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))

    # Now test with a int64 field set.
    proto = unittest_pb2.TestAllTypes()
    proto.optional_int64 = 0x0fffffffffffffff
    serialized = proto.SerializeToString()
    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()
    # Parsing this message should succeed.
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))

  def _CheckRaises(self, exc_class, callable_obj, exception):
    """Checks that callable_obj raises exc_class with the expected message.

    Exceptions of any other type propagate unchanged.

    Args:
      exc_class: The exception type callable_obj is expected to raise.
      callable_obj: A zero-argument callable to invoke.
      exception: The exact expected str() of the raised exception.

    Raises:
      self.failureException: If nothing is raised, or if the message of the
          raised exception differs from `exception`.
    """
    try:
      callable_obj()
    except exc_class as ex:
      # Check if the exception message is the right one.
      self.assertEqual(exception, str(ex))
    else:
      raise self.failureException('%s not raised' % str(exc_class))

  def testSerializeUninitialized(self):
    proto = unittest_pb2.TestRequired()
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: '
        'a,b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertFalse(proto2.HasField('a'))
    # proto2 ParseFromString does not check that required fields are set.
    proto2.ParseFromString(partial)
    self.assertFalse(proto2.HasField('a'))

    proto.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.c = 3
    serialized = proto.SerializeToString()
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)
    self.assertEqual(
        len(partial),
        proto2.MergeFromString(partial))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)

  def testSerializeUninitializedSubMessage(self):
    proto = unittest_pb2.TestRequiredForeign()

    # Sub-message doesn't exist yet, so this succeeds.
    proto.SerializeToString()

    proto.optional_message.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign '
        'is missing required fields: '
        'optional_message.b,optional_message.c')

    proto.optional_message.b = 2
    proto.optional_message.c = 3
    proto.SerializeToString()

    proto.repeated_message.add().a = 1
    proto.repeated_message.add().b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
        'repeated_message[0].b,repeated_message[0].c,'
        'repeated_message[1].a,repeated_message[1].c')

    proto.repeated_message[0].b = 2
    proto.repeated_message[0].c = 3
    proto.repeated_message[1].a = 1
    proto.repeated_message[1].c = 3
    proto.SerializeToString()

  def testSerializeAllPackedFields(self):
    first_proto = unittest_pb2.TestPackedTypes()
    second_proto = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllPackedExtensions(self):
    first_proto = unittest_pb2.TestPackedExtensions()
    second_proto = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    # As for packed regular fields, ByteSize() must match the serialized
    # length.
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
    first_proto = unittest_pb2.TestPackedTypes()
    first_proto.packed_int32.extend([1, 2])
    first_proto.packed_double.append(3.0)
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestPackedTypes()
    second_proto.packed_int32.append(3)
    second_proto.packed_double.extend([1.0, 2.0])
    second_proto.packed_sint32.append(4)

    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    # Merged packed values are appended after the existing ones; fields absent
    # from the input are left untouched.
    self.assertEqual([3, 1, 2], second_proto.packed_int32)
    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
    self.assertEqual([4], second_proto.packed_sint32)

  def testPackedFieldsWireFormat(self):
    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
    proto.packed_float.append(2.0)             # 4 bytes, will be before double
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    # Each packed field is one length-delimited record: tag, byte length,
    # then the concatenated element encodings.
    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(1+1+1+2, d.ReadInt32())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual(2, d.ReadInt32())
    self.assertEqual(150, d.ReadInt32())
    self.assertEqual(3, d.ReadInt32())
    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(4, d.ReadInt32())
    self.assertEqual(2.0, d.ReadFloat())
    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(8+8, d.ReadInt32())
    self.assertEqual(1.0, d.ReadDouble())
    self.assertEqual(1000.0, d.ReadDouble())
    self.assertTrue(d.EndOfStream())

  def testParsePackedFromUnpacked(self):
    unpacked = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(unpacked)
    packed = unittest_pb2.TestPackedTypes()
    serialized = unpacked.SerializeToString()
    self.assertEqual(
        len(serialized),
        packed.MergeFromString(serialized))
    expected = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(expected)
    self.assertEqual(expected, packed)

  def testParseUnpackedFromPacked(self):
    packed = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(packed)
    unpacked = unittest_pb2.TestUnpackedTypes()
    serialized = packed.SerializeToString()
    self.assertEqual(
        len(serialized),
        unpacked.MergeFromString(serialized))
    expected = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(expected)
    self.assertEqual(expected, unpacked)

  def testFieldNumbers(self):
    # The generated <FIELD>_FIELD_NUMBER class constants must match the
    # numbers declared in the .proto file.
    self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
    self.assertEqual(
      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
    self.assertEqual(
      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
    self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
    self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
    self.assertEqual(
      unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
    self.assertEqual(
      unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)

  def testExtensionFieldNumbers(self):
    self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
    self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
    self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
    self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
    self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
    self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
    self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
    self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
    self.assertEqual(
      unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
    self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
    self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
      21)
    self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
    self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
    self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
    self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
    self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
    self.assertEqual(
      unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
    self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
    self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
      51)

  def testInitKwargs(self):
    proto = unittest_pb2.TestAllTypes(
        optional_int32=1,
        optional_string='foo',
        optional_bool=True,
        optional_bytes=b'bar',
        optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
        optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
        optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
        optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
        repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('optional_int32'))
    self.assertTrue(proto.HasField('optional_string'))
    self.assertTrue(proto.HasField('optional_bool'))
    self.assertTrue(proto.HasField('optional_bytes'))
    self.assertTrue(proto.HasField('optional_nested_message'))
    self.assertTrue(proto.HasField('optional_foreign_message'))
    self.assertTrue(proto.HasField('optional_nested_enum'))
    self.assertTrue(proto.HasField('optional_foreign_enum'))
    self.assertEqual(1, proto.optional_int32)
    self.assertEqual('foo', proto.optional_string)
    self.assertEqual(True, proto.optional_bool)
    self.assertEqual(b'bar', proto.optional_bytes)
    self.assertEqual(1, proto.optional_nested_message.bb)
    self.assertEqual(1, proto.optional_foreign_message.c)
    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
                     proto.optional_nested_enum)
    self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
    self.assertEqual([1, 2, 3], proto.repeated_int32)

  def testInitArgsUnknownFieldName(self):
    def InitializeEmptyMessageWithExtraKeywordArg():
      unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
    self._CheckRaises(ValueError,
                      InitializeEmptyMessageWithExtraKeywordArg,
                      'Protocol message has no "unknown" field.')

  def testInitRequiredKwargs(self):
    proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('a'))
    self.assertTrue(proto.HasField('b'))
    self.assertTrue(proto.HasField('c'))
    self.assertFalse(proto.HasField('dummy2'))
    self.assertEqual(1, proto.a)
    self.assertEqual(1, proto.b)
    self.assertEqual(1, proto.c)

  def testInitRequiredForeignKwargs(self):
    proto = unittest_pb2.TestRequiredForeign(
        optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
    self.assertTrue(proto.IsInitialized())
    self.assertTrue(proto.HasField('optional_message'))
    self.assertTrue(proto.optional_message.IsInitialized())
    self.assertTrue(proto.optional_message.HasField('a'))
    self.assertTrue(proto.optional_message.HasField('b'))
    self.assertTrue(proto.optional_message.HasField('c'))
    self.assertFalse(proto.optional_message.HasField('dummy2'))
    self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
                     proto.optional_message)
    self.assertEqual(1, proto.optional_message.a)
    self.assertEqual(1, proto.optional_message.b)
    self.assertEqual(1, proto.optional_message.c)

  def testInitRepeatedKwargs(self):
    proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
    self.assertTrue(proto.IsInitialized())
    self.assertEqual(1, proto.repeated_int32[0])
    self.assertEqual(2, proto.repeated_int32[1])
    self.assertEqual(3, proto.repeated_int32[2])
2774
2775
class OptionsTest(basetest.TestCase):
  """Checks message- and field-level options read back from descriptors."""

  def testMessageOptions(self):
    # Only the MessageSet test type opts into the MessageSet wire format.
    mset_options = unittest_mset_pb2.TestMessageSet().DESCRIPTOR.GetOptions()
    self.assertEqual(True, mset_options.message_set_wire_format)
    plain_options = unittest_pb2.TestAllTypes().DESCRIPTOR.GetOptions()
    self.assertEqual(False, plain_options.message_set_wire_format)

  def testPackedOptions(self):
    # Fields of TestAllTypes that are set must all report packed=False.
    regular = unittest_pb2.TestAllTypes()
    regular.optional_int32 = 1
    regular.optional_double = 3.0
    for field_desc, _ in regular.ListFields():
      self.assertEqual(False, field_desc.GetOptions().packed)

    # Fields of TestPackedTypes that are set must be repeated and packed.
    packed = unittest_pb2.TestPackedTypes()
    packed.packed_int32.append(1)
    packed.packed_double.append(3.0)
    for field_desc, _ in packed.ListFields():
      self.assertEqual(True, field_desc.GetOptions().packed)
      self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED,
                       field_desc.label)
2800
2801
2802
2803class ClassAPITest(basetest.TestCase):
2804
2805  def testMakeClassWithNestedDescriptor(self):
2806    leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '',
2807                                      containing_type=None, fields=[],
2808                                      nested_types=[], enum_types=[],
2809                                      extensions=[])
2810    child_desc = descriptor.Descriptor('child', 'package.parent.child', '',
2811                                       containing_type=None, fields=[],
2812                                       nested_types=[leaf_desc], enum_types=[],
2813                                       extensions=[])
2814    sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling',
2815                                         '', containing_type=None, fields=[],
2816                                         nested_types=[], enum_types=[],
2817                                         extensions=[])
2818    parent_desc = descriptor.Descriptor('parent', 'package.parent', '',
2819                                        containing_type=None, fields=[],
2820                                        nested_types=[child_desc, sibling_desc],
2821                                        enum_types=[], extensions=[])
2822    message_class = reflection.MakeClass(parent_desc)
2823    self.assertIn('child', message_class.__dict__)
2824    self.assertIn('sibling', message_class.__dict__)
2825    self.assertIn('leaf', message_class.child.__dict__)
2826
2827  def _GetSerializedFileDescriptor(self, name):
2828    """Get a serialized representation of a test FileDescriptorProto.
2829
2830    Args:
2831      name: All calls to this must use a unique message name, to avoid
2832          collisions in the cpp descriptor pool.
2833    Returns:
2834      A string containing the serialized form of a test FileDescriptorProto.
2835    """
2836    file_descriptor_str = (
2837        'message_type {'
2838        '  name: "' + name + '"'
2839        '  field {'
2840        '    name: "flat"'
2841        '    number: 1'
2842        '    label: LABEL_REPEATED'
2843        '    type: TYPE_UINT32'
2844        '  }'
2845        '  field {'
2846        '    name: "bar"'
2847        '    number: 2'
2848        '    label: LABEL_OPTIONAL'
2849        '    type: TYPE_MESSAGE'
2850        '    type_name: "Bar"'
2851        '  }'
2852        '  nested_type {'
2853        '    name: "Bar"'
2854        '    field {'
2855        '      name: "baz"'
2856        '      number: 3'
2857        '      label: LABEL_OPTIONAL'
2858        '      type: TYPE_MESSAGE'
2859        '      type_name: "Baz"'
2860        '    }'
2861        '    nested_type {'
2862        '      name: "Baz"'
2863        '      enum_type {'
2864        '        name: "deep_enum"'
2865        '        value {'
2866        '          name: "VALUE_A"'
2867        '          number: 0'
2868        '        }'
2869        '      }'
2870        '      field {'
2871        '        name: "deep"'
2872        '        number: 4'
2873        '        label: LABEL_OPTIONAL'
2874        '        type: TYPE_UINT32'
2875        '      }'
2876        '    }'
2877        '  }'
2878        '}')
2879    file_descriptor = descriptor_pb2.FileDescriptorProto()
2880    text_format.Merge(file_descriptor_str, file_descriptor)
2881    return file_descriptor.SerializeToString()
2882
2883  def testParsingFlatClassWithExplicitClassDeclaration(self):
2884    """Test that the generated class can parse a flat message."""
2885    file_descriptor = descriptor_pb2.FileDescriptorProto()
2886    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A'))
2887    msg_descriptor = descriptor.MakeDescriptor(
2888        file_descriptor.message_type[0])
2889
2890    class MessageClass(message.Message):
2891      __metaclass__ = reflection.GeneratedProtocolMessageType
2892      DESCRIPTOR = msg_descriptor
2893    msg = MessageClass()
2894    msg_str = (
2895        'flat: 0 '
2896        'flat: 1 '
2897        'flat: 2 ')
2898    text_format.Merge(msg_str, msg)
2899    self.assertEqual(msg.flat, [0, 1, 2])
2900
2901  def testParsingFlatClass(self):
2902    """Test that the generated class can parse a flat message."""
2903    file_descriptor = descriptor_pb2.FileDescriptorProto()
2904    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B'))
2905    msg_descriptor = descriptor.MakeDescriptor(
2906        file_descriptor.message_type[0])
2907    msg_class = reflection.MakeClass(msg_descriptor)
2908    msg = msg_class()
2909    msg_str = (
2910        'flat: 0 '
2911        'flat: 1 '
2912        'flat: 2 ')
2913    text_format.Merge(msg_str, msg)
2914    self.assertEqual(msg.flat, [0, 1, 2])
2915
2916  def testParsingNestedClass(self):
2917    """Test that the generated class can parse a nested message."""
2918    file_descriptor = descriptor_pb2.FileDescriptorProto()
2919    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
2920    msg_descriptor = descriptor.MakeDescriptor(
2921        file_descriptor.message_type[0])
2922    msg_class = reflection.MakeClass(msg_descriptor)
2923    msg = msg_class()
2924    msg_str = (
2925        'bar {'
2926        '  baz {'
2927        '    deep: 4'
2928        '  }'
2929        '}')
2930    text_format.Merge(msg_str, msg)
2931    self.assertEqual(msg.bar.baz.deep, 4)
2932
if __name__ == '__main__':
  # When run as a script, hand control to google.apputils' basetest runner.
  basetest.main()
2935