1#! /usr/bin/env python
2#
3# Protocol Buffers - Google's data interchange format
4# Copyright 2008 Google Inc.  All rights reserved.
5# https://developers.google.com/protocol-buffers/
6#
7# Redistribution and use in source and binary forms, with or without
8# modification, are permitted provided that the following conditions are
9# met:
10#
11#     * Redistributions of source code must retain the above copyright
12# notice, this list of conditions and the following disclaimer.
13#     * Redistributions in binary form must reproduce the above
14# copyright notice, this list of conditions and the following disclaimer
15# in the documentation and/or other materials provided with the
16# distribution.
17#     * Neither the name of Google Inc. nor the names of its
18# contributors may be used to endorse or promote products derived from
19# this software without specific prior written permission.
20#
21# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
33"""Tests python protocol buffers against the golden message.
34
35Note that the golden messages exercise every known field type, thus this
36test ends up exercising and verifying nearly all of the parsing and
37serialization code in the whole library.
38
39TODO(kenton):  Merge with wire_format_test?  It doesn't make a whole lot of
40sense to call this a test of the "message" module, which only declares an
41abstract interface.
42"""
43
44__author__ = 'gps@google.com (Gregory P. Smith)'
45
46
47import collections
48import copy
49import math
50import operator
51import pickle
52import six
53import sys
54
55try:
56  import unittest2 as unittest  #PY26
57except ImportError:
58  import unittest
59
60from google.protobuf import map_unittest_pb2
61from google.protobuf import unittest_pb2
62from google.protobuf import unittest_proto3_arena_pb2
63from google.protobuf import descriptor_pb2
64from google.protobuf import descriptor_pool
65from google.protobuf import message_factory
66from google.protobuf import text_format
67from google.protobuf.internal import api_implementation
68from google.protobuf.internal import packed_field_test_pb2
69from google.protobuf.internal import test_util
70from google.protobuf import message
71from google.protobuf.internal import _parameterized
72
# Under Python 3 (as reported by six.PY3) bind the name `long` to `int`.
# NOTE(review): presumably so later code in this file can refer to `long`
# on both major Python versions -- no use is visible in this part of the file.
if six.PY3:
  long = int
75
76
77# Python pre-2.6 does not have isinf() or isnan() functions, so we have
78# to provide our own.
def isnan(val):
  """Returns whether val compares unequal to itself, as a NaN does."""
  differs_from_itself = (val != val)
  return differs_from_itself
def isinf(val):
  """Returns whether val is an infinity of either sign."""
  if isnan(val):
    return False
  # Multiplying an infinity by zero gives NaN; finite values give zero.
  return isnan(val * 0)
def IsPosInf(val):
  """Returns whether val is infinite and greater than zero."""
  return (val > 0) if isinf(val) else False
def IsNegInf(val):
  """Returns whether val is infinite and less than zero."""
  return (val < 0) if isinf(val) else False
89
90
91@_parameterized.Parameters(
92    (unittest_pb2),
93    (unittest_proto3_arena_pb2))
94class MessageTest(unittest.TestCase):
95
96  def testBadUtf8String(self, message_module):
97    if api_implementation.Type() != 'python':
98      self.skipTest("Skipping testBadUtf8String, currently only the python "
99                    "api implementation raises UnicodeDecodeError when a "
100                    "string field contains bad utf-8.")
101    bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
102    with self.assertRaises(UnicodeDecodeError) as context:
103      message_module.TestAllTypes.FromString(bad_utf8_data)
104    self.assertIn('TestAllTypes.optional_string', str(context.exception))
105
106  def testGoldenMessage(self, message_module):
107    # Proto3 doesn't have the "default_foo" members or foreign enums,
108    # and doesn't preserve unknown fields, so for proto3 we use a golden
109    # message that doesn't have these fields set.
110    if message_module is unittest_pb2:
111      golden_data = test_util.GoldenFileData(
112          'golden_message_oneof_implemented')
113    else:
114      golden_data = test_util.GoldenFileData('golden_message_proto3')
115
116    golden_message = message_module.TestAllTypes()
117    golden_message.ParseFromString(golden_data)
118    if message_module is unittest_pb2:
119      test_util.ExpectAllFieldsSet(self, golden_message)
120    self.assertEqual(golden_data, golden_message.SerializeToString())
121    golden_copy = copy.deepcopy(golden_message)
122    self.assertEqual(golden_data, golden_copy.SerializeToString())
123
124  def testGoldenPackedMessage(self, message_module):
125    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
126    golden_message = message_module.TestPackedTypes()
127    golden_message.ParseFromString(golden_data)
128    all_set = message_module.TestPackedTypes()
129    test_util.SetAllPackedFields(all_set)
130    self.assertEqual(all_set, golden_message)
131    self.assertEqual(golden_data, all_set.SerializeToString())
132    golden_copy = copy.deepcopy(golden_message)
133    self.assertEqual(golden_data, golden_copy.SerializeToString())
134
135  def testPickleSupport(self, message_module):
136    golden_data = test_util.GoldenFileData('golden_message')
137    golden_message = message_module.TestAllTypes()
138    golden_message.ParseFromString(golden_data)
139    pickled_message = pickle.dumps(golden_message)
140
141    unpickled_message = pickle.loads(pickled_message)
142    self.assertEqual(unpickled_message, golden_message)
143
144  def testPositiveInfinity(self, message_module):
145    if message_module is unittest_pb2:
146      golden_data = (b'\x5D\x00\x00\x80\x7F'
147                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
148                     b'\xCD\x02\x00\x00\x80\x7F'
149                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
150    else:
151      golden_data = (b'\x5D\x00\x00\x80\x7F'
152                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
153                     b'\xCA\x02\x04\x00\x00\x80\x7F'
154                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
155
156    golden_message = message_module.TestAllTypes()
157    golden_message.ParseFromString(golden_data)
158    self.assertTrue(IsPosInf(golden_message.optional_float))
159    self.assertTrue(IsPosInf(golden_message.optional_double))
160    self.assertTrue(IsPosInf(golden_message.repeated_float[0]))
161    self.assertTrue(IsPosInf(golden_message.repeated_double[0]))
162    self.assertEqual(golden_data, golden_message.SerializeToString())
163
164  def testNegativeInfinity(self, message_module):
165    if message_module is unittest_pb2:
166      golden_data = (b'\x5D\x00\x00\x80\xFF'
167                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
168                     b'\xCD\x02\x00\x00\x80\xFF'
169                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
170    else:
171      golden_data = (b'\x5D\x00\x00\x80\xFF'
172                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
173                     b'\xCA\x02\x04\x00\x00\x80\xFF'
174                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
175
176    golden_message = message_module.TestAllTypes()
177    golden_message.ParseFromString(golden_data)
178    self.assertTrue(IsNegInf(golden_message.optional_float))
179    self.assertTrue(IsNegInf(golden_message.optional_double))
180    self.assertTrue(IsNegInf(golden_message.repeated_float[0]))
181    self.assertTrue(IsNegInf(golden_message.repeated_double[0]))
182    self.assertEqual(golden_data, golden_message.SerializeToString())
183
184  def testNotANumber(self, message_module):
185    golden_data = (b'\x5D\x00\x00\xC0\x7F'
186                   b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
187                   b'\xCD\x02\x00\x00\xC0\x7F'
188                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
189    golden_message = message_module.TestAllTypes()
190    golden_message.ParseFromString(golden_data)
191    self.assertTrue(isnan(golden_message.optional_float))
192    self.assertTrue(isnan(golden_message.optional_double))
193    self.assertTrue(isnan(golden_message.repeated_float[0]))
194    self.assertTrue(isnan(golden_message.repeated_double[0]))
195
196    # The protocol buffer may serialize to any one of multiple different
197    # representations of a NaN.  Rather than verify a specific representation,
198    # verify the serialized string can be converted into a correctly
199    # behaving protocol buffer.
200    serialized = golden_message.SerializeToString()
201    message = message_module.TestAllTypes()
202    message.ParseFromString(serialized)
203    self.assertTrue(isnan(message.optional_float))
204    self.assertTrue(isnan(message.optional_double))
205    self.assertTrue(isnan(message.repeated_float[0]))
206    self.assertTrue(isnan(message.repeated_double[0]))
207
208  def testPositiveInfinityPacked(self, message_module):
209    golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
210                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
211    golden_message = message_module.TestPackedTypes()
212    golden_message.ParseFromString(golden_data)
213    self.assertTrue(IsPosInf(golden_message.packed_float[0]))
214    self.assertTrue(IsPosInf(golden_message.packed_double[0]))
215    self.assertEqual(golden_data, golden_message.SerializeToString())
216
217  def testNegativeInfinityPacked(self, message_module):
218    golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
219                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
220    golden_message = message_module.TestPackedTypes()
221    golden_message.ParseFromString(golden_data)
222    self.assertTrue(IsNegInf(golden_message.packed_float[0]))
223    self.assertTrue(IsNegInf(golden_message.packed_double[0]))
224    self.assertEqual(golden_data, golden_message.SerializeToString())
225
226  def testNotANumberPacked(self, message_module):
227    golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
228                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
229    golden_message = message_module.TestPackedTypes()
230    golden_message.ParseFromString(golden_data)
231    self.assertTrue(isnan(golden_message.packed_float[0]))
232    self.assertTrue(isnan(golden_message.packed_double[0]))
233
234    serialized = golden_message.SerializeToString()
235    message = message_module.TestPackedTypes()
236    message.ParseFromString(serialized)
237    self.assertTrue(isnan(message.packed_float[0]))
238    self.assertTrue(isnan(message.packed_double[0]))
239
240  def testExtremeFloatValues(self, message_module):
241    message = message_module.TestAllTypes()
242
243    # Most positive exponent, no significand bits set.
244    kMostPosExponentNoSigBits = math.pow(2, 127)
245    message.optional_float = kMostPosExponentNoSigBits
246    message.ParseFromString(message.SerializeToString())
247    self.assertTrue(message.optional_float == kMostPosExponentNoSigBits)
248
249    # Most positive exponent, one significand bit set.
250    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127)
251    message.optional_float = kMostPosExponentOneSigBit
252    message.ParseFromString(message.SerializeToString())
253    self.assertTrue(message.optional_float == kMostPosExponentOneSigBit)
254
255    # Repeat last two cases with values of same magnitude, but negative.
256    message.optional_float = -kMostPosExponentNoSigBits
257    message.ParseFromString(message.SerializeToString())
258    self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits)
259
260    message.optional_float = -kMostPosExponentOneSigBit
261    message.ParseFromString(message.SerializeToString())
262    self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit)
263
264    # Most negative exponent, no significand bits set.
265    kMostNegExponentNoSigBits = math.pow(2, -127)
266    message.optional_float = kMostNegExponentNoSigBits
267    message.ParseFromString(message.SerializeToString())
268    self.assertTrue(message.optional_float == kMostNegExponentNoSigBits)
269
270    # Most negative exponent, one significand bit set.
271    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127)
272    message.optional_float = kMostNegExponentOneSigBit
273    message.ParseFromString(message.SerializeToString())
274    self.assertTrue(message.optional_float == kMostNegExponentOneSigBit)
275
276    # Repeat last two cases with values of the same magnitude, but negative.
277    message.optional_float = -kMostNegExponentNoSigBits
278    message.ParseFromString(message.SerializeToString())
279    self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits)
280
281    message.optional_float = -kMostNegExponentOneSigBit
282    message.ParseFromString(message.SerializeToString())
283    self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit)
284
285  def testExtremeDoubleValues(self, message_module):
286    message = message_module.TestAllTypes()
287
288    # Most positive exponent, no significand bits set.
289    kMostPosExponentNoSigBits = math.pow(2, 1023)
290    message.optional_double = kMostPosExponentNoSigBits
291    message.ParseFromString(message.SerializeToString())
292    self.assertTrue(message.optional_double == kMostPosExponentNoSigBits)
293
294    # Most positive exponent, one significand bit set.
295    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023)
296    message.optional_double = kMostPosExponentOneSigBit
297    message.ParseFromString(message.SerializeToString())
298    self.assertTrue(message.optional_double == kMostPosExponentOneSigBit)
299
300    # Repeat last two cases with values of same magnitude, but negative.
301    message.optional_double = -kMostPosExponentNoSigBits
302    message.ParseFromString(message.SerializeToString())
303    self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits)
304
305    message.optional_double = -kMostPosExponentOneSigBit
306    message.ParseFromString(message.SerializeToString())
307    self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit)
308
309    # Most negative exponent, no significand bits set.
310    kMostNegExponentNoSigBits = math.pow(2, -1023)
311    message.optional_double = kMostNegExponentNoSigBits
312    message.ParseFromString(message.SerializeToString())
313    self.assertTrue(message.optional_double == kMostNegExponentNoSigBits)
314
315    # Most negative exponent, one significand bit set.
316    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023)
317    message.optional_double = kMostNegExponentOneSigBit
318    message.ParseFromString(message.SerializeToString())
319    self.assertTrue(message.optional_double == kMostNegExponentOneSigBit)
320
321    # Repeat last two cases with values of the same magnitude, but negative.
322    message.optional_double = -kMostNegExponentNoSigBits
323    message.ParseFromString(message.SerializeToString())
324    self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits)
325
326    message.optional_double = -kMostNegExponentOneSigBit
327    message.ParseFromString(message.SerializeToString())
328    self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
329
330  def testFloatPrinting(self, message_module):
331    message = message_module.TestAllTypes()
332    message.optional_float = 2.0
333    self.assertEqual(str(message), 'optional_float: 2.0\n')
334
335  def testHighPrecisionFloatPrinting(self, message_module):
336    message = message_module.TestAllTypes()
337    message.optional_double = 0.12345678912345678
338    if sys.version_info >= (3,):
339      self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n')
340    else:
341      self.assertEqual(str(message), 'optional_double: 0.123456789123\n')
342
343  def testUnknownFieldPrinting(self, message_module):
344    populated = message_module.TestAllTypes()
345    test_util.SetAllNonLazyFields(populated)
346    empty = message_module.TestEmptyMessage()
347    empty.ParseFromString(populated.SerializeToString())
348    self.assertEqual(str(empty), '')
349
350  def testRepeatedNestedFieldIteration(self, message_module):
351    msg = message_module.TestAllTypes()
352    msg.repeated_nested_message.add(bb=1)
353    msg.repeated_nested_message.add(bb=2)
354    msg.repeated_nested_message.add(bb=3)
355    msg.repeated_nested_message.add(bb=4)
356
357    self.assertEqual([1, 2, 3, 4],
358                     [m.bb for m in msg.repeated_nested_message])
359    self.assertEqual([4, 3, 2, 1],
360                     [m.bb for m in reversed(msg.repeated_nested_message)])
361    self.assertEqual([4, 3, 2, 1],
362                     [m.bb for m in msg.repeated_nested_message[::-1]])
363
364  def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
365    """Check some different types with the default comparator."""
366    message = message_module.TestAllTypes()
367
368    # TODO(mattp): would testing more scalar types strengthen test?
369    message.repeated_int32.append(1)
370    message.repeated_int32.append(3)
371    message.repeated_int32.append(2)
372    message.repeated_int32.sort()
373    self.assertEqual(message.repeated_int32[0], 1)
374    self.assertEqual(message.repeated_int32[1], 2)
375    self.assertEqual(message.repeated_int32[2], 3)
376
377    message.repeated_float.append(1.1)
378    message.repeated_float.append(1.3)
379    message.repeated_float.append(1.2)
380    message.repeated_float.sort()
381    self.assertAlmostEqual(message.repeated_float[0], 1.1)
382    self.assertAlmostEqual(message.repeated_float[1], 1.2)
383    self.assertAlmostEqual(message.repeated_float[2], 1.3)
384
385    message.repeated_string.append('a')
386    message.repeated_string.append('c')
387    message.repeated_string.append('b')
388    message.repeated_string.sort()
389    self.assertEqual(message.repeated_string[0], 'a')
390    self.assertEqual(message.repeated_string[1], 'b')
391    self.assertEqual(message.repeated_string[2], 'c')
392
393    message.repeated_bytes.append(b'a')
394    message.repeated_bytes.append(b'c')
395    message.repeated_bytes.append(b'b')
396    message.repeated_bytes.sort()
397    self.assertEqual(message.repeated_bytes[0], b'a')
398    self.assertEqual(message.repeated_bytes[1], b'b')
399    self.assertEqual(message.repeated_bytes[2], b'c')
400
401  def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
402    """Check some different types with custom comparator."""
403    message = message_module.TestAllTypes()
404
405    message.repeated_int32.append(-3)
406    message.repeated_int32.append(-2)
407    message.repeated_int32.append(-1)
408    message.repeated_int32.sort(key=abs)
409    self.assertEqual(message.repeated_int32[0], -1)
410    self.assertEqual(message.repeated_int32[1], -2)
411    self.assertEqual(message.repeated_int32[2], -3)
412
413    message.repeated_string.append('aaa')
414    message.repeated_string.append('bb')
415    message.repeated_string.append('c')
416    message.repeated_string.sort(key=len)
417    self.assertEqual(message.repeated_string[0], 'c')
418    self.assertEqual(message.repeated_string[1], 'bb')
419    self.assertEqual(message.repeated_string[2], 'aaa')
420
421  def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
422    """Check passing a custom comparator to sort a repeated composite field."""
423    message = message_module.TestAllTypes()
424
425    message.repeated_nested_message.add().bb = 1
426    message.repeated_nested_message.add().bb = 3
427    message.repeated_nested_message.add().bb = 2
428    message.repeated_nested_message.add().bb = 6
429    message.repeated_nested_message.add().bb = 5
430    message.repeated_nested_message.add().bb = 4
431    message.repeated_nested_message.sort(key=operator.attrgetter('bb'))
432    self.assertEqual(message.repeated_nested_message[0].bb, 1)
433    self.assertEqual(message.repeated_nested_message[1].bb, 2)
434    self.assertEqual(message.repeated_nested_message[2].bb, 3)
435    self.assertEqual(message.repeated_nested_message[3].bb, 4)
436    self.assertEqual(message.repeated_nested_message[4].bb, 5)
437    self.assertEqual(message.repeated_nested_message[5].bb, 6)
438
439  def testSortingRepeatedCompositeFieldsStable(self, message_module):
440    """Check passing a custom comparator to sort a repeated composite field."""
441    message = message_module.TestAllTypes()
442
443    message.repeated_nested_message.add().bb = 21
444    message.repeated_nested_message.add().bb = 20
445    message.repeated_nested_message.add().bb = 13
446    message.repeated_nested_message.add().bb = 33
447    message.repeated_nested_message.add().bb = 11
448    message.repeated_nested_message.add().bb = 24
449    message.repeated_nested_message.add().bb = 10
450    message.repeated_nested_message.sort(key=lambda z: z.bb // 10)
451    self.assertEqual(
452        [13, 11, 10, 21, 20, 24, 33],
453        [n.bb for n in message.repeated_nested_message])
454
455    # Make sure that for the C++ implementation, the underlying fields
456    # are actually reordered.
457    pb = message.SerializeToString()
458    message.Clear()
459    message.MergeFromString(pb)
460    self.assertEqual(
461        [13, 11, 10, 21, 20, 24, 33],
462        [n.bb for n in message.repeated_nested_message])
463
464  def testRepeatedCompositeFieldSortArguments(self, message_module):
465    """Check sorting a repeated composite field using list.sort() arguments."""
466    message = message_module.TestAllTypes()
467
468    get_bb = operator.attrgetter('bb')
469    cmp_bb = lambda a, b: cmp(a.bb, b.bb)
470    message.repeated_nested_message.add().bb = 1
471    message.repeated_nested_message.add().bb = 3
472    message.repeated_nested_message.add().bb = 2
473    message.repeated_nested_message.add().bb = 6
474    message.repeated_nested_message.add().bb = 5
475    message.repeated_nested_message.add().bb = 4
476    message.repeated_nested_message.sort(key=get_bb)
477    self.assertEqual([k.bb for k in message.repeated_nested_message],
478                     [1, 2, 3, 4, 5, 6])
479    message.repeated_nested_message.sort(key=get_bb, reverse=True)
480    self.assertEqual([k.bb for k in message.repeated_nested_message],
481                     [6, 5, 4, 3, 2, 1])
482    if sys.version_info >= (3,): return  # No cmp sorting in PY3.
483    message.repeated_nested_message.sort(sort_function=cmp_bb)
484    self.assertEqual([k.bb for k in message.repeated_nested_message],
485                     [1, 2, 3, 4, 5, 6])
486    message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True)
487    self.assertEqual([k.bb for k in message.repeated_nested_message],
488                     [6, 5, 4, 3, 2, 1])
489
490  def testRepeatedScalarFieldSortArguments(self, message_module):
491    """Check sorting a scalar field using list.sort() arguments."""
492    message = message_module.TestAllTypes()
493
494    message.repeated_int32.append(-3)
495    message.repeated_int32.append(-2)
496    message.repeated_int32.append(-1)
497    message.repeated_int32.sort(key=abs)
498    self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
499    message.repeated_int32.sort(key=abs, reverse=True)
500    self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
501    if sys.version_info < (3,):  # No cmp sorting in PY3.
502      abs_cmp = lambda a, b: cmp(abs(a), abs(b))
503      message.repeated_int32.sort(sort_function=abs_cmp)
504      self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
505      message.repeated_int32.sort(cmp=abs_cmp, reverse=True)
506      self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
507
508    message.repeated_string.append('aaa')
509    message.repeated_string.append('bb')
510    message.repeated_string.append('c')
511    message.repeated_string.sort(key=len)
512    self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
513    message.repeated_string.sort(key=len, reverse=True)
514    self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
515    if sys.version_info < (3,):  # No cmp sorting in PY3.
516      len_cmp = lambda a, b: cmp(len(a), len(b))
517      message.repeated_string.sort(sort_function=len_cmp)
518      self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
519      message.repeated_string.sort(cmp=len_cmp, reverse=True)
520      self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
521
522  def testRepeatedFieldsComparable(self, message_module):
523    m1 = message_module.TestAllTypes()
524    m2 = message_module.TestAllTypes()
525    m1.repeated_int32.append(0)
526    m1.repeated_int32.append(1)
527    m1.repeated_int32.append(2)
528    m2.repeated_int32.append(0)
529    m2.repeated_int32.append(1)
530    m2.repeated_int32.append(2)
531    m1.repeated_nested_message.add().bb = 1
532    m1.repeated_nested_message.add().bb = 2
533    m1.repeated_nested_message.add().bb = 3
534    m2.repeated_nested_message.add().bb = 1
535    m2.repeated_nested_message.add().bb = 2
536    m2.repeated_nested_message.add().bb = 3
537
538    if sys.version_info >= (3,): return  # No cmp() in PY3.
539
540    # These comparisons should not raise errors.
541    _ = m1 < m2
542    _ = m1.repeated_nested_message < m2.repeated_nested_message
543
544    # Make sure cmp always works. If it wasn't defined, these would be
545    # id() comparisons and would all fail.
546    self.assertEqual(cmp(m1, m2), 0)
547    self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0)
548    self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0)
549    self.assertEqual(cmp(m1.repeated_nested_message,
550                         m2.repeated_nested_message), 0)
551    with self.assertRaises(TypeError):
552      # Can't compare repeated composite containers to lists.
553      cmp(m1.repeated_nested_message, m2.repeated_nested_message[:])
554
555    # TODO(anuraag): Implement extensiondict comparison in C++ and then add test
556
557  def testRepeatedFieldsAreSequences(self, message_module):
558    m = message_module.TestAllTypes()
559    self.assertIsInstance(m.repeated_int32, collections.MutableSequence)
560    self.assertIsInstance(m.repeated_nested_message,
561                          collections.MutableSequence)
562
563  def ensureNestedMessageExists(self, msg, attribute):
564    """Make sure that a nested message object exists.
565
566    As soon as a nested message attribute is accessed, it will be present in the
567    _fields dict, without being marked as actually being set.
568    """
569    getattr(msg, attribute)
570    self.assertFalse(msg.HasField(attribute))
571
572  def testOneofGetCaseNonexistingField(self, message_module):
573    m = message_module.TestAllTypes()
574    self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
575
576  def testOneofDefaultValues(self, message_module):
577    m = message_module.TestAllTypes()
578    self.assertIs(None, m.WhichOneof('oneof_field'))
579    self.assertFalse(m.HasField('oneof_uint32'))
580
581    # Oneof is set even when setting it to a default value.
582    m.oneof_uint32 = 0
583    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
584    self.assertTrue(m.HasField('oneof_uint32'))
585    self.assertFalse(m.HasField('oneof_string'))
586
587    m.oneof_string = ""
588    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
589    self.assertTrue(m.HasField('oneof_string'))
590    self.assertFalse(m.HasField('oneof_uint32'))
591
592  def testOneofSemantics(self, message_module):
593    m = message_module.TestAllTypes()
594    self.assertIs(None, m.WhichOneof('oneof_field'))
595
596    m.oneof_uint32 = 11
597    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
598    self.assertTrue(m.HasField('oneof_uint32'))
599
600    m.oneof_string = u'foo'
601    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
602    self.assertFalse(m.HasField('oneof_uint32'))
603    self.assertTrue(m.HasField('oneof_string'))
604
605    # Read nested message accessor without accessing submessage.
606    m.oneof_nested_message
607    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
608    self.assertTrue(m.HasField('oneof_string'))
609    self.assertFalse(m.HasField('oneof_nested_message'))
610
611    # Read accessor of nested message without accessing submessage.
612    m.oneof_nested_message.bb
613    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
614    self.assertTrue(m.HasField('oneof_string'))
615    self.assertFalse(m.HasField('oneof_nested_message'))
616
617    m.oneof_nested_message.bb = 11
618    self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
619    self.assertFalse(m.HasField('oneof_string'))
620    self.assertTrue(m.HasField('oneof_nested_message'))
621
622    m.oneof_bytes = b'bb'
623    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
624    self.assertFalse(m.HasField('oneof_nested_message'))
625    self.assertTrue(m.HasField('oneof_bytes'))
626
627  def testOneofCompositeFieldReadAccess(self, message_module):
628    m = message_module.TestAllTypes()
629    m.oneof_uint32 = 11
630
631    self.ensureNestedMessageExists(m, 'oneof_nested_message')
632    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
633    self.assertEqual(11, m.oneof_uint32)
634
635  def testOneofWhichOneof(self, message_module):
636    m = message_module.TestAllTypes()
637    self.assertIs(None, m.WhichOneof('oneof_field'))
638    if message_module is unittest_pb2:
639      self.assertFalse(m.HasField('oneof_field'))
640
641    m.oneof_uint32 = 11
642    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
643    if message_module is unittest_pb2:
644      self.assertTrue(m.HasField('oneof_field'))
645
646    m.oneof_bytes = b'bb'
647    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
648
649    m.ClearField('oneof_bytes')
650    self.assertIs(None, m.WhichOneof('oneof_field'))
651    if message_module is unittest_pb2:
652      self.assertFalse(m.HasField('oneof_field'))
653
654  def testOneofClearField(self, message_module):
655    m = message_module.TestAllTypes()
656    m.oneof_uint32 = 11
657    m.ClearField('oneof_field')
658    if message_module is unittest_pb2:
659      self.assertFalse(m.HasField('oneof_field'))
660    self.assertFalse(m.HasField('oneof_uint32'))
661    self.assertIs(None, m.WhichOneof('oneof_field'))
662
663  def testOneofClearSetField(self, message_module):
664    m = message_module.TestAllTypes()
665    m.oneof_uint32 = 11
666    m.ClearField('oneof_uint32')
667    if message_module is unittest_pb2:
668      self.assertFalse(m.HasField('oneof_field'))
669    self.assertFalse(m.HasField('oneof_uint32'))
670    self.assertIs(None, m.WhichOneof('oneof_field'))
671
672  def testOneofClearUnsetField(self, message_module):
673    m = message_module.TestAllTypes()
674    m.oneof_uint32 = 11
675    self.ensureNestedMessageExists(m, 'oneof_nested_message')
676    m.ClearField('oneof_nested_message')
677    self.assertEqual(11, m.oneof_uint32)
678    if message_module is unittest_pb2:
679      self.assertTrue(m.HasField('oneof_field'))
680    self.assertTrue(m.HasField('oneof_uint32'))
681    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
682
683  def testOneofDeserialize(self, message_module):
684    m = message_module.TestAllTypes()
685    m.oneof_uint32 = 11
686    m2 = message_module.TestAllTypes()
687    m2.ParseFromString(m.SerializeToString())
688    self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
689
690  def testOneofCopyFrom(self, message_module):
691    m = message_module.TestAllTypes()
692    m.oneof_uint32 = 11
693    m2 = message_module.TestAllTypes()
694    m2.CopyFrom(m)
695    self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
696
697  def testOneofNestedMergeFrom(self, message_module):
698    m = message_module.NestedTestAllTypes()
699    m.payload.oneof_uint32 = 11
700    m2 = message_module.NestedTestAllTypes()
701    m2.payload.oneof_bytes = b'bb'
702    m2.child.payload.oneof_bytes = b'bb'
703    m2.MergeFrom(m)
704    self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field'))
705    self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field'))
706
707  def testOneofMessageMergeFrom(self, message_module):
708    m = message_module.NestedTestAllTypes()
709    m.payload.oneof_nested_message.bb = 11
710    m.child.payload.oneof_nested_message.bb = 12
711    m2 = message_module.NestedTestAllTypes()
712    m2.payload.oneof_uint32 = 13
713    m2.MergeFrom(m)
714    self.assertEqual('oneof_nested_message',
715                     m2.payload.WhichOneof('oneof_field'))
716    self.assertEqual('oneof_nested_message',
717                     m2.child.payload.WhichOneof('oneof_field'))
718
719  def testOneofNestedMessageInit(self, message_module):
720    m = message_module.TestAllTypes(
721        oneof_nested_message=message_module.TestAllTypes.NestedMessage())
722    self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
723
724  def testOneofClear(self, message_module):
725    m = message_module.TestAllTypes()
726    m.oneof_uint32 = 11
727    m.Clear()
728    self.assertIsNone(m.WhichOneof('oneof_field'))
729    m.oneof_bytes = b'bb'
730    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
731
732  def testAssignByteStringToUnicodeField(self, message_module):
733    """Assigning a byte string to a string field should result
734    in the value being converted to a Unicode string."""
735    m = message_module.TestAllTypes()
736    m.optional_string = str('')
737    self.assertIsInstance(m.optional_string, six.text_type)
738
739  def testLongValuedSlice(self, message_module):
740    """It should be possible to use long-valued indicies in slices
741
742    This didn't used to work in the v2 C++ implementation.
743    """
744    m = message_module.TestAllTypes()
745
746    # Repeated scalar
747    m.repeated_int32.append(1)
748    sl = m.repeated_int32[long(0):long(len(m.repeated_int32))]
749    self.assertEqual(len(m.repeated_int32), len(sl))
750
751    # Repeated composite
752    m.repeated_nested_message.add().bb = 3
753    sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))]
754    self.assertEqual(len(m.repeated_nested_message), len(sl))
755
756  def testExtendShouldNotSwallowExceptions(self, message_module):
757    """This didn't use to work in the v2 C++ implementation."""
758    m = message_module.TestAllTypes()
759    with self.assertRaises(NameError) as _:
760      m.repeated_int32.extend(a for i in range(10))  # pylint: disable=undefined-variable
761    with self.assertRaises(NameError) as _:
762      m.repeated_nested_enum.extend(
763          a for i in range(10))  # pylint: disable=undefined-variable
764
  # Falsy inputs passed to extend() by the testExtend*WithNothing tests, which
  # assert that each one leaves the repeated field empty.
  FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()]
766
767  def testExtendInt32WithNothing(self, message_module):
768    """Test no-ops extending repeated int32 fields."""
769    m = message_module.TestAllTypes()
770    self.assertSequenceEqual([], m.repeated_int32)
771
772    # TODO(ptucker): Deprecate this behavior. b/18413862
773    for falsy_value in MessageTest.FALSY_VALUES:
774      m.repeated_int32.extend(falsy_value)
775      self.assertSequenceEqual([], m.repeated_int32)
776
777    m.repeated_int32.extend([])
778    self.assertSequenceEqual([], m.repeated_int32)
779
780  def testExtendFloatWithNothing(self, message_module):
781    """Test no-ops extending repeated float fields."""
782    m = message_module.TestAllTypes()
783    self.assertSequenceEqual([], m.repeated_float)
784
785    # TODO(ptucker): Deprecate this behavior. b/18413862
786    for falsy_value in MessageTest.FALSY_VALUES:
787      m.repeated_float.extend(falsy_value)
788      self.assertSequenceEqual([], m.repeated_float)
789
790    m.repeated_float.extend([])
791    self.assertSequenceEqual([], m.repeated_float)
792
793  def testExtendStringWithNothing(self, message_module):
794    """Test no-ops extending repeated string fields."""
795    m = message_module.TestAllTypes()
796    self.assertSequenceEqual([], m.repeated_string)
797
798    # TODO(ptucker): Deprecate this behavior. b/18413862
799    for falsy_value in MessageTest.FALSY_VALUES:
800      m.repeated_string.extend(falsy_value)
801      self.assertSequenceEqual([], m.repeated_string)
802
803    m.repeated_string.extend([])
804    self.assertSequenceEqual([], m.repeated_string)
805
806  def testExtendInt32WithPythonList(self, message_module):
807    """Test extending repeated int32 fields with python lists."""
808    m = message_module.TestAllTypes()
809    self.assertSequenceEqual([], m.repeated_int32)
810    m.repeated_int32.extend([0])
811    self.assertSequenceEqual([0], m.repeated_int32)
812    m.repeated_int32.extend([1, 2])
813    self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
814    m.repeated_int32.extend([3, 4])
815    self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
816
817  def testExtendFloatWithPythonList(self, message_module):
818    """Test extending repeated float fields with python lists."""
819    m = message_module.TestAllTypes()
820    self.assertSequenceEqual([], m.repeated_float)
821    m.repeated_float.extend([0.0])
822    self.assertSequenceEqual([0.0], m.repeated_float)
823    m.repeated_float.extend([1.0, 2.0])
824    self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
825    m.repeated_float.extend([3.0, 4.0])
826    self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
827
828  def testExtendStringWithPythonList(self, message_module):
829    """Test extending repeated string fields with python lists."""
830    m = message_module.TestAllTypes()
831    self.assertSequenceEqual([], m.repeated_string)
832    m.repeated_string.extend([''])
833    self.assertSequenceEqual([''], m.repeated_string)
834    m.repeated_string.extend(['11', '22'])
835    self.assertSequenceEqual(['', '11', '22'], m.repeated_string)
836    m.repeated_string.extend(['33', '44'])
837    self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string)
838
839  def testExtendStringWithString(self, message_module):
840    """Test extending repeated string fields with characters from a string."""
841    m = message_module.TestAllTypes()
842    self.assertSequenceEqual([], m.repeated_string)
843    m.repeated_string.extend('abc')
844    self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string)
845
846  class TestIterable(object):
847    """This iterable object mimics the behavior of numpy.array.
848
849    __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1.
850
851    """
852
853    def __init__(self, values=None):
854      self._list = values or []
855
856    def __nonzero__(self):
857      size = len(self._list)
858      if size == 0:
859        return False
860      if size == 1:
861        return bool(self._list[0])
862      raise ValueError('Truth value is ambiguous.')
863
864    def __len__(self):
865      return len(self._list)
866
867    def __iter__(self):
868      return self._list.__iter__()
869
870  def testExtendInt32WithIterable(self, message_module):
871    """Test extending repeated int32 fields with iterable."""
872    m = message_module.TestAllTypes()
873    self.assertSequenceEqual([], m.repeated_int32)
874    m.repeated_int32.extend(MessageTest.TestIterable([]))
875    self.assertSequenceEqual([], m.repeated_int32)
876    m.repeated_int32.extend(MessageTest.TestIterable([0]))
877    self.assertSequenceEqual([0], m.repeated_int32)
878    m.repeated_int32.extend(MessageTest.TestIterable([1, 2]))
879    self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
880    m.repeated_int32.extend(MessageTest.TestIterable([3, 4]))
881    self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
882
883  def testExtendFloatWithIterable(self, message_module):
884    """Test extending repeated float fields with iterable."""
885    m = message_module.TestAllTypes()
886    self.assertSequenceEqual([], m.repeated_float)
887    m.repeated_float.extend(MessageTest.TestIterable([]))
888    self.assertSequenceEqual([], m.repeated_float)
889    m.repeated_float.extend(MessageTest.TestIterable([0.0]))
890    self.assertSequenceEqual([0.0], m.repeated_float)
891    m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0]))
892    self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
893    m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0]))
894    self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
895
896  def testExtendStringWithIterable(self, message_module):
897    """Test extending repeated string fields with iterable."""
898    m = message_module.TestAllTypes()
899    self.assertSequenceEqual([], m.repeated_string)
900    m.repeated_string.extend(MessageTest.TestIterable([]))
901    self.assertSequenceEqual([], m.repeated_string)
902    m.repeated_string.extend(MessageTest.TestIterable(['']))
903    self.assertSequenceEqual([''], m.repeated_string)
904    m.repeated_string.extend(MessageTest.TestIterable(['1', '2']))
905    self.assertSequenceEqual(['', '1', '2'], m.repeated_string)
906    m.repeated_string.extend(MessageTest.TestIterable(['3', '4']))
907    self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string)
908
909  def testPickleRepeatedScalarContainer(self, message_module):
910    # TODO(tibell): The pure-Python implementation support pickling of
911    #   scalar containers in *some* cases. For now the cpp2 version
912    #   throws an exception to avoid a segfault. Investigate if we
913    #   want to support pickling of these fields.
914    #
915    # For more information see: https://b2.corp.google.com/u/0/issues/18677897
916    if (api_implementation.Type() != 'cpp' or
917        api_implementation.Version() == 2):
918      return
919    m = message_module.TestAllTypes()
920    with self.assertRaises(pickle.PickleError) as _:
921      pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL)
922
923  def testSortEmptyRepeatedCompositeContainer(self, message_module):
924    """Exercise a scenario that has led to segfaults in the past.
925    """
926    m = message_module.TestAllTypes()
927    m.repeated_nested_message.sort()
928
929  def testHasFieldOnRepeatedField(self, message_module):
930    """Using HasField on a repeated field should raise an exception.
931    """
932    m = message_module.TestAllTypes()
933    with self.assertRaises(ValueError) as _:
934      m.HasField('repeated_int32')
935
936  def testRepeatedScalarFieldPop(self, message_module):
937    m = message_module.TestAllTypes()
938    with self.assertRaises(IndexError) as _:
939      m.repeated_int32.pop()
940    m.repeated_int32.extend(range(5))
941    self.assertEqual(4, m.repeated_int32.pop())
942    self.assertEqual(0, m.repeated_int32.pop(0))
943    self.assertEqual(2, m.repeated_int32.pop(1))
944    self.assertEqual([1, 3], m.repeated_int32)
945
946  def testRepeatedCompositeFieldPop(self, message_module):
947    m = message_module.TestAllTypes()
948    with self.assertRaises(IndexError) as _:
949      m.repeated_nested_message.pop()
950    for i in range(5):
951      n = m.repeated_nested_message.add()
952      n.bb = i
953    self.assertEqual(4, m.repeated_nested_message.pop().bb)
954    self.assertEqual(0, m.repeated_nested_message.pop(0).bb)
955    self.assertEqual(2, m.repeated_nested_message.pop(1).bb)
956    self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message])
957
958
959# Class to test proto2-only features (required, extensions, etc.)
class Proto2Test(unittest.TestCase):
  """Tests run only against unittest_pb2 (required fields, extensions, ...)."""

  def testFieldPresence(self):
    """HasField() follows explicit assignment and ClearField() on proto2."""
    msg = unittest_pb2.TestAllTypes()
    singular_fields = (
        "optional_int32", "optional_bool", "optional_nested_message")

    def check_defaults():
      self.assertEqual(0, msg.optional_int32)
      self.assertEqual(False, msg.optional_bool)
      self.assertEqual(0, msg.optional_nested_message.bb)

    for field_name in singular_fields:
      self.assertFalse(msg.HasField(field_name))

    # Unknown names and repeated fields cannot be presence-tested.
    for bad_name in ("field_doesnt_exist", "repeated_int32",
                     "repeated_nested_message"):
      with self.assertRaises(ValueError):
        msg.HasField(bad_name)

    check_defaults()

    # Fields are set even when setting the values to default values.
    msg.optional_int32 = 0
    msg.optional_bool = False
    msg.optional_nested_message.bb = 0
    for field_name in singular_fields:
      self.assertTrue(msg.HasField(field_name))

    # Set the fields to non-default values.
    msg.optional_int32 = 5
    msg.optional_bool = True
    msg.optional_nested_message.bb = 15
    for field_name in singular_fields:
      self.assertTrue(msg.HasField(field_name))

    # Clearing the fields unsets them and resets their value to default.
    for field_name in singular_fields:
      msg.ClearField(field_name)
    for field_name in singular_fields:
      self.assertFalse(msg.HasField(field_name))
    check_defaults()

  # TODO(tibell): The C++ implementations actually allows assignment
  # of unknown enum values to *scalar* fields (but not repeated
  # fields). Once checked enum fields becomes the default in the
  # Python implementation, the C++ implementation should follow suit.
  def testAssignInvalidEnum(self):
    """It should not be possible to assign an invalid enum number to an
    enum field."""
    msg = unittest_pb2.TestAllTypes()
    invalid_number = 1234567

    with self.assertRaises(ValueError):
      msg.optional_nested_enum = invalid_number
    # Look the bound method up outside the context manager so only the call
    # itself is expected to raise.
    append_enum = msg.repeated_nested_enum.append
    with self.assertRaises(ValueError):
      append_enum(invalid_number)

  def testGoldenExtensions(self):
    """TestAllExtensions round-trips the 'golden_message' data."""
    golden_bytes = test_util.GoldenFileData('golden_message')
    parsed = unittest_pb2.TestAllExtensions()
    parsed.ParseFromString(golden_bytes)

    populated = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(populated)
    self.assertEqual(populated, parsed)
    self.assertEqual(golden_bytes, parsed.SerializeToString())

    # A deep copy must serialize to the same bytes.
    duplicate = copy.deepcopy(parsed)
    self.assertEqual(golden_bytes, duplicate.SerializeToString())

  def testGoldenPackedExtensions(self):
    """TestPackedExtensions round-trips 'golden_packed_fields_message'."""
    golden_bytes = test_util.GoldenFileData('golden_packed_fields_message')
    parsed = unittest_pb2.TestPackedExtensions()
    parsed.ParseFromString(golden_bytes)

    populated = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(populated)
    self.assertEqual(populated, parsed)
    self.assertEqual(golden_bytes, populated.SerializeToString())

    # A deep copy must serialize to the same bytes.
    duplicate = copy.deepcopy(parsed)
    self.assertEqual(golden_bytes, duplicate.SerializeToString())

  def testPickleIncompleteProto(self):
    """A TestRequired built with only a=1 survives pickling unchanged."""
    incomplete = unittest_pb2.TestRequired(a=1)
    restored = pickle.loads(pickle.dumps(incomplete))

    self.assertEqual(restored, incomplete)
    self.assertEqual(restored.a, 1)
    # This is still an incomplete proto - so serializing should fail
    self.assertRaises(message.EncodeError, restored.SerializeToString)


  # TODO(haberman): this isn't really a proto2-specific test except that this
  # message has a required field in it.  Should probably be factored out so
  # that we can test the other parts with proto3.
  def testParsingMerge(self):
    """Check the merge behavior when a required or optional field appears
    multiple times in the input."""
    first = unittest_pb2.TestAllTypes()
    first.optional_int32 = 1
    second = unittest_pb2.TestAllTypes()
    second.optional_int64 = 2
    third = unittest_pb2.TestAllTypes()
    third.optional_int32 = 3
    third.optional_string = 'hello'
    parts = [first, second, third]

    # What merging the three parts in order is expected to produce.
    expected_merge = unittest_pb2.TestAllTypes()
    expected_merge.optional_int32 = 3
    expected_merge.optional_int64 = 2
    expected_merge.optional_string = 'hello'

    generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
    for repeated_name in ('field1', 'field2', 'field3', 'ext1', 'ext2'):
      getattr(generator, repeated_name).extend(parts)
    for part in parts:
      generator.group1.add().field1.MergeFrom(part)
    for part in parts:
      generator.group2.add().field1.MergeFrom(part)

    # Parse the generator's bytes as a TestParsingMerge message.
    parsed = unittest_pb2.TestParsingMerge()
    parsed.ParseFromString(generator.SerializeToString())

    # Required and optional fields should be merged.
    self.assertEqual(parsed.required_all_types, expected_merge)
    self.assertEqual(parsed.optional_all_types, expected_merge)
    self.assertEqual(parsed.optionalgroup.optional_group_all_types,
                     expected_merge)
    optional_ext = parsed.Extensions[
        unittest_pb2.TestParsingMerge.optional_ext]
    self.assertEqual(optional_ext, expected_merge)

    # Repeated fields should not be merged.
    self.assertEqual(len(parsed.repeated_all_types), 3)
    self.assertEqual(len(parsed.repeatedgroup), 3)
    repeated_ext = parsed.Extensions[
        unittest_pb2.TestParsingMerge.repeated_ext]
    self.assertEqual(len(repeated_ext), 3)

  def testPythonicInit(self):
    """Constructor keyword arguments populate fields, dicts and enum labels."""
    all_types = unittest_pb2.TestAllTypes
    msg = all_types(
        optional_int32=100,
        optional_fixed32=200,
        optional_float=300.5,
        optional_bytes=b'x',
        optionalgroup={'a': 400},
        optional_nested_message={'bb': 500},
        optional_nested_enum='BAZ',
        repeatedgroup=[{'a': 600},
                       {'a': 700}],
        repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
        default_int32=800,
        oneof_string='y')
    self.assertIsInstance(msg, unittest_pb2.TestAllTypes)

    # Scalars, group and nested message given by keyword or dict.
    self.assertEqual(100, msg.optional_int32)
    self.assertEqual(200, msg.optional_fixed32)
    self.assertEqual(300.5, msg.optional_float)
    self.assertEqual(b'x', msg.optional_bytes)
    self.assertEqual(400, msg.optionalgroup.a)
    self.assertIsInstance(msg.optional_nested_message,
                          unittest_pb2.TestAllTypes.NestedMessage)
    self.assertEqual(500, msg.optional_nested_message.bb)
    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
                     msg.optional_nested_enum)

    # Repeated group and repeated enum (label and number mixed).
    self.assertEqual(2, len(msg.repeatedgroup))
    self.assertEqual(600, msg.repeatedgroup[0].a)
    self.assertEqual(700, msg.repeatedgroup[1].a)
    self.assertEqual(2, len(msg.repeated_nested_enum))
    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
                     msg.repeated_nested_enum[0])
    self.assertEqual(unittest_pb2.TestAllTypes.BAR,
                     msg.repeated_nested_enum[1])

    self.assertEqual(800, msg.default_int32)
    self.assertEqual('y', msg.oneof_string)

    # Fields that were not passed stay unset / empty / at their default.
    self.assertFalse(msg.HasField('optional_int64'))
    self.assertEqual(0, len(msg.repeated_float))
    self.assertEqual(42, msg.default_int64)

    # A unicode enum label works as well.
    msg = all_types(optional_nested_enum=u'BAZ')
    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
                     msg.optional_nested_enum)

    with self.assertRaises(ValueError):
      unittest_pb2.TestAllTypes(
          optional_nested_message={'INVALID_NESTED_FIELD': 17})

    with self.assertRaises(TypeError):
      unittest_pb2.TestAllTypes(
          optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})

    with self.assertRaises(ValueError):
      unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL')

    with self.assertRaises(ValueError):
      unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
1163
1164
1165
1166# Class to test proto3-only features/behavior (updated field presence & enums)
1167class Proto3Test(unittest.TestCase):
1168
1169  # Utility method for comparing equality with a map.
1170  def assertMapIterEquals(self, map_iter, dict_value):
1171    # Avoid mutating caller's copy.
1172    dict_value = dict(dict_value)
1173
1174    for k, v in map_iter:
1175      self.assertEqual(v, dict_value[k])
1176      del dict_value[k]
1177
1178    self.assertEqual({}, dict_value)
1179
1180  def testFieldPresence(self):
1181    message = unittest_proto3_arena_pb2.TestAllTypes()
1182
1183    # We can't test presence of non-repeated, non-submessage fields.
1184    with self.assertRaises(ValueError):
1185      message.HasField('optional_int32')
1186    with self.assertRaises(ValueError):
1187      message.HasField('optional_float')
1188    with self.assertRaises(ValueError):
1189      message.HasField('optional_string')
1190    with self.assertRaises(ValueError):
1191      message.HasField('optional_bool')
1192
1193    # But we can still test presence of submessage fields.
1194    self.assertFalse(message.HasField('optional_nested_message'))
1195
1196    # As with proto2, we can't test presence of fields that don't exist, or
1197    # repeated fields.
1198    with self.assertRaises(ValueError):
1199      message.HasField('field_doesnt_exist')
1200
1201    with self.assertRaises(ValueError):
1202      message.HasField('repeated_int32')
1203    with self.assertRaises(ValueError):
1204      message.HasField('repeated_nested_message')
1205
1206    # Fields should default to their type-specific default.
1207    self.assertEqual(0, message.optional_int32)
1208    self.assertEqual(0, message.optional_float)
1209    self.assertEqual('', message.optional_string)
1210    self.assertEqual(False, message.optional_bool)
1211    self.assertEqual(0, message.optional_nested_message.bb)
1212
1213    # Setting a submessage should still return proper presence information.
1214    message.optional_nested_message.bb = 0
1215    self.assertTrue(message.HasField('optional_nested_message'))
1216
1217    # Set the fields to non-default values.
1218    message.optional_int32 = 5
1219    message.optional_float = 1.1
1220    message.optional_string = 'abc'
1221    message.optional_bool = True
1222    message.optional_nested_message.bb = 15
1223
1224    # Clearing the fields unsets them and resets their value to default.
1225    message.ClearField('optional_int32')
1226    message.ClearField('optional_float')
1227    message.ClearField('optional_string')
1228    message.ClearField('optional_bool')
1229    message.ClearField('optional_nested_message')
1230
1231    self.assertEqual(0, message.optional_int32)
1232    self.assertEqual(0, message.optional_float)
1233    self.assertEqual('', message.optional_string)
1234    self.assertEqual(False, message.optional_bool)
1235    self.assertEqual(0, message.optional_nested_message.bb)
1236
1237  def testAssignUnknownEnum(self):
1238    """Assigning an unknown enum value is allowed and preserves the value."""
1239    m = unittest_proto3_arena_pb2.TestAllTypes()
1240
1241    m.optional_nested_enum = 1234567
1242    self.assertEqual(1234567, m.optional_nested_enum)
1243    m.repeated_nested_enum.append(22334455)
1244    self.assertEqual(22334455, m.repeated_nested_enum[0])
1245    # Assignment is a different code path than append for the C++ impl.
1246    m.repeated_nested_enum[0] = 7654321
1247    self.assertEqual(7654321, m.repeated_nested_enum[0])
1248    serialized = m.SerializeToString()
1249
1250    m2 = unittest_proto3_arena_pb2.TestAllTypes()
1251    m2.ParseFromString(serialized)
1252    self.assertEqual(1234567, m2.optional_nested_enum)
1253    self.assertEqual(7654321, m2.repeated_nested_enum[0])
1254
1255  # Map isn't really a proto3-only feature. But there is no proto2 equivalent
1256  # of google/protobuf/map_unittest.proto right now, so it's not easy to
1257  # test both with the same test like we do for the other proto2/proto3 tests.
1258  # (google/protobuf/map_protobuf_unittest.proto is very different in the set
1259  # of messages and fields it contains).
1260  def testScalarMapDefaults(self):
1261    msg = map_unittest_pb2.TestMap()
1262
1263    # Scalars start out unset.
1264    self.assertFalse(-123 in msg.map_int32_int32)
1265    self.assertFalse(-2**33 in msg.map_int64_int64)
1266    self.assertFalse(123 in msg.map_uint32_uint32)
1267    self.assertFalse(2**33 in msg.map_uint64_uint64)
1268    self.assertFalse(123 in msg.map_int32_double)
1269    self.assertFalse(False in msg.map_bool_bool)
1270    self.assertFalse('abc' in msg.map_string_string)
1271    self.assertFalse(111 in msg.map_int32_bytes)
1272    self.assertFalse(888 in msg.map_int32_enum)
1273
1274    # Accessing an unset key returns the default.
1275    self.assertEqual(0, msg.map_int32_int32[-123])
1276    self.assertEqual(0, msg.map_int64_int64[-2**33])
1277    self.assertEqual(0, msg.map_uint32_uint32[123])
1278    self.assertEqual(0, msg.map_uint64_uint64[2**33])
1279    self.assertEqual(0.0, msg.map_int32_double[123])
1280    self.assertTrue(isinstance(msg.map_int32_double[123], float))
1281    self.assertEqual(False, msg.map_bool_bool[False])
1282    self.assertTrue(isinstance(msg.map_bool_bool[False], bool))
1283    self.assertEqual('', msg.map_string_string['abc'])
1284    self.assertEqual(b'', msg.map_int32_bytes[111])
1285    self.assertEqual(0, msg.map_int32_enum[888])
1286
1287    # It also sets the value in the map
1288    self.assertTrue(-123 in msg.map_int32_int32)
1289    self.assertTrue(-2**33 in msg.map_int64_int64)
1290    self.assertTrue(123 in msg.map_uint32_uint32)
1291    self.assertTrue(2**33 in msg.map_uint64_uint64)
1292    self.assertTrue(123 in msg.map_int32_double)
1293    self.assertTrue(False in msg.map_bool_bool)
1294    self.assertTrue('abc' in msg.map_string_string)
1295    self.assertTrue(111 in msg.map_int32_bytes)
1296    self.assertTrue(888 in msg.map_int32_enum)
1297
1298    self.assertIsInstance(msg.map_string_string['abc'], six.text_type)
1299
1300    # Accessing an unset key still throws TypeError if the type of the key
1301    # is incorrect.
1302    with self.assertRaises(TypeError):
1303      msg.map_string_string[123]
1304
1305    with self.assertRaises(TypeError):
1306      123 in msg.map_string_string
1307
1308  def testMapGet(self):
1309    # Need to test that get() properly returns the default, even though the dict
1310    # has defaultdict-like semantics.
1311    msg = map_unittest_pb2.TestMap()
1312
1313    self.assertIsNone(msg.map_int32_int32.get(5))
1314    self.assertEqual(10, msg.map_int32_int32.get(5, 10))
1315    self.assertIsNone(msg.map_int32_int32.get(5))
1316
1317    msg.map_int32_int32[5] = 15
1318    self.assertEqual(15, msg.map_int32_int32.get(5))
1319
1320    self.assertIsNone(msg.map_int32_foreign_message.get(5))
1321    self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))
1322
1323    submsg = msg.map_int32_foreign_message[5]
1324    self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
1325
1326  def testScalarMap(self):
1327    msg = map_unittest_pb2.TestMap()
1328
1329    self.assertEqual(0, len(msg.map_int32_int32))
1330    self.assertFalse(5 in msg.map_int32_int32)
1331
1332    msg.map_int32_int32[-123] = -456
1333    msg.map_int64_int64[-2**33] = -2**34
1334    msg.map_uint32_uint32[123] = 456
1335    msg.map_uint64_uint64[2**33] = 2**34
1336    msg.map_string_string['abc'] = '123'
1337    msg.map_int32_enum[888] = 2
1338
1339    self.assertEqual([], msg.FindInitializationErrors())
1340
1341    self.assertEqual(1, len(msg.map_string_string))
1342
1343    # Bad key.
1344    with self.assertRaises(TypeError):
1345      msg.map_string_string[123] = '123'
1346
1347    # Verify that trying to assign a bad key doesn't actually add a member to
1348    # the map.
1349    self.assertEqual(1, len(msg.map_string_string))
1350
1351    # Bad value.
1352    with self.assertRaises(TypeError):
1353      msg.map_string_string['123'] = 123
1354
1355    serialized = msg.SerializeToString()
1356    msg2 = map_unittest_pb2.TestMap()
1357    msg2.ParseFromString(serialized)
1358
1359    # Bad key.
1360    with self.assertRaises(TypeError):
1361      msg2.map_string_string[123] = '123'
1362
1363    # Bad value.
1364    with self.assertRaises(TypeError):
1365      msg2.map_string_string['123'] = 123
1366
1367    self.assertEqual(-456, msg2.map_int32_int32[-123])
1368    self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
1369    self.assertEqual(456, msg2.map_uint32_uint32[123])
1370    self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
1371    self.assertEqual('123', msg2.map_string_string['abc'])
1372    self.assertEqual(2, msg2.map_int32_enum[888])
1373
1374  def testStringUnicodeConversionInMap(self):
1375    msg = map_unittest_pb2.TestMap()
1376
1377    unicode_obj = u'\u1234'
1378    bytes_obj = unicode_obj.encode('utf8')
1379
1380    msg.map_string_string[bytes_obj] = bytes_obj
1381
1382    (key, value) = list(msg.map_string_string.items())[0]
1383
1384    self.assertEqual(key, unicode_obj)
1385    self.assertEqual(value, unicode_obj)
1386
1387    self.assertIsInstance(key, six.text_type)
1388    self.assertIsInstance(value, six.text_type)
1389
1390  def testMessageMap(self):
1391    msg = map_unittest_pb2.TestMap()
1392
1393    self.assertEqual(0, len(msg.map_int32_foreign_message))
1394    self.assertFalse(5 in msg.map_int32_foreign_message)
1395
1396    msg.map_int32_foreign_message[123]
1397    # get_or_create() is an alias for getitem.
1398    msg.map_int32_foreign_message.get_or_create(-456)
1399
1400    self.assertEqual(2, len(msg.map_int32_foreign_message))
1401    self.assertIn(123, msg.map_int32_foreign_message)
1402    self.assertIn(-456, msg.map_int32_foreign_message)
1403    self.assertEqual(2, len(msg.map_int32_foreign_message))
1404
1405    # Bad key.
1406    with self.assertRaises(TypeError):
1407      msg.map_int32_foreign_message['123']
1408
1409    # Can't assign directly to submessage.
1410    with self.assertRaises(ValueError):
1411      msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123]
1412
1413    # Verify that trying to assign a bad key doesn't actually add a member to
1414    # the map.
1415    self.assertEqual(2, len(msg.map_int32_foreign_message))
1416
1417    serialized = msg.SerializeToString()
1418    msg2 = map_unittest_pb2.TestMap()
1419    msg2.ParseFromString(serialized)
1420
1421    self.assertEqual(2, len(msg2.map_int32_foreign_message))
1422    self.assertIn(123, msg2.map_int32_foreign_message)
1423    self.assertIn(-456, msg2.map_int32_foreign_message)
1424    self.assertEqual(2, len(msg2.map_int32_foreign_message))
1425
1426  def testMergeFrom(self):
1427    msg = map_unittest_pb2.TestMap()
1428    msg.map_int32_int32[12] = 34
1429    msg.map_int32_int32[56] = 78
1430    msg.map_int64_int64[22] = 33
1431    msg.map_int32_foreign_message[111].c = 5
1432    msg.map_int32_foreign_message[222].c = 10
1433
1434    msg2 = map_unittest_pb2.TestMap()
1435    msg2.map_int32_int32[12] = 55
1436    msg2.map_int64_int64[88] = 99
1437    msg2.map_int32_foreign_message[222].c = 15
1438
1439    msg2.MergeFrom(msg)
1440
1441    self.assertEqual(34, msg2.map_int32_int32[12])
1442    self.assertEqual(78, msg2.map_int32_int32[56])
1443    self.assertEqual(33, msg2.map_int64_int64[22])
1444    self.assertEqual(99, msg2.map_int64_int64[88])
1445    self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
1446    self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
1447
1448    # Verify that there is only one entry per key, even though the MergeFrom
1449    # may have internally created multiple entries for a single key in the
1450    # list representation.
1451    as_dict = {}
1452    for key in msg2.map_int32_foreign_message:
1453      self.assertFalse(key in as_dict)
1454      as_dict[key] = msg2.map_int32_foreign_message[key].c
1455
1456    self.assertEqual({111: 5, 222: 10}, as_dict)
1457
1458    # Special case: test that delete of item really removes the item, even if
1459    # there might have physically been duplicate keys due to the previous merge.
1460    # This is only a special case for the C++ implementation which stores the
1461    # map as an array.
1462    del msg2.map_int32_int32[12]
1463    self.assertFalse(12 in msg2.map_int32_int32)
1464
1465    del msg2.map_int32_foreign_message[222]
1466    self.assertFalse(222 in msg2.map_int32_foreign_message)
1467
1468  def testMergeFromBadType(self):
1469    msg = map_unittest_pb2.TestMap()
1470    with self.assertRaisesRegexp(
1471        TypeError,
1472        r'Parameter to MergeFrom\(\) must be instance of same class: expected '
1473        r'.*TestMap got int\.'):
1474      msg.MergeFrom(1)
1475
1476  def testCopyFromBadType(self):
1477    msg = map_unittest_pb2.TestMap()
1478    with self.assertRaisesRegexp(
1479        TypeError,
1480        r'Parameter to [A-Za-z]*From\(\) must be instance of same class: '
1481        r'expected .*TestMap got int\.'):
1482      msg.CopyFrom(1)
1483
1484  def testIntegerMapWithLongs(self):
1485    msg = map_unittest_pb2.TestMap()
1486    msg.map_int32_int32[long(-123)] = long(-456)
1487    msg.map_int64_int64[long(-2**33)] = long(-2**34)
1488    msg.map_uint32_uint32[long(123)] = long(456)
1489    msg.map_uint64_uint64[long(2**33)] = long(2**34)
1490
1491    serialized = msg.SerializeToString()
1492    msg2 = map_unittest_pb2.TestMap()
1493    msg2.ParseFromString(serialized)
1494
1495    self.assertEqual(-456, msg2.map_int32_int32[-123])
1496    self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
1497    self.assertEqual(456, msg2.map_uint32_uint32[123])
1498    self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
1499
1500  def testMapAssignmentCausesPresence(self):
1501    msg = map_unittest_pb2.TestMapSubmessage()
1502    msg.test_map.map_int32_int32[123] = 456
1503
1504    serialized = msg.SerializeToString()
1505    msg2 = map_unittest_pb2.TestMapSubmessage()
1506    msg2.ParseFromString(serialized)
1507
1508    self.assertEqual(msg, msg2)
1509
1510    # Now test that various mutations of the map properly invalidate the
1511    # cached size of the submessage.
1512    msg.test_map.map_int32_int32[888] = 999
1513    serialized = msg.SerializeToString()
1514    msg2.ParseFromString(serialized)
1515    self.assertEqual(msg, msg2)
1516
1517    msg.test_map.map_int32_int32.clear()
1518    serialized = msg.SerializeToString()
1519    msg2.ParseFromString(serialized)
1520    self.assertEqual(msg, msg2)
1521
1522  def testMapAssignmentCausesPresenceForSubmessages(self):
1523    msg = map_unittest_pb2.TestMapSubmessage()
1524    msg.test_map.map_int32_foreign_message[123].c = 5
1525
1526    serialized = msg.SerializeToString()
1527    msg2 = map_unittest_pb2.TestMapSubmessage()
1528    msg2.ParseFromString(serialized)
1529
1530    self.assertEqual(msg, msg2)
1531
1532    # Now test that various mutations of the map properly invalidate the
1533    # cached size of the submessage.
1534    msg.test_map.map_int32_foreign_message[888].c = 7
1535    serialized = msg.SerializeToString()
1536    msg2.ParseFromString(serialized)
1537    self.assertEqual(msg, msg2)
1538
1539    msg.test_map.map_int32_foreign_message[888].MergeFrom(
1540        msg.test_map.map_int32_foreign_message[123])
1541    serialized = msg.SerializeToString()
1542    msg2.ParseFromString(serialized)
1543    self.assertEqual(msg, msg2)
1544
1545    msg.test_map.map_int32_foreign_message.clear()
1546    serialized = msg.SerializeToString()
1547    msg2.ParseFromString(serialized)
1548    self.assertEqual(msg, msg2)
1549
1550  def testModifyMapWhileIterating(self):
1551    msg = map_unittest_pb2.TestMap()
1552
1553    string_string_iter = iter(msg.map_string_string)
1554    int32_foreign_iter = iter(msg.map_int32_foreign_message)
1555
1556    msg.map_string_string['abc'] = '123'
1557    msg.map_int32_foreign_message[5].c = 5
1558
1559    with self.assertRaises(RuntimeError):
1560      for key in string_string_iter:
1561        pass
1562
1563    with self.assertRaises(RuntimeError):
1564      for key in int32_foreign_iter:
1565        pass
1566
1567  def testSubmessageMap(self):
1568    msg = map_unittest_pb2.TestMap()
1569
1570    submsg = msg.map_int32_foreign_message[111]
1571    self.assertIs(submsg, msg.map_int32_foreign_message[111])
1572    self.assertIsInstance(submsg, unittest_pb2.ForeignMessage)
1573
1574    submsg.c = 5
1575
1576    serialized = msg.SerializeToString()
1577    msg2 = map_unittest_pb2.TestMap()
1578    msg2.ParseFromString(serialized)
1579
1580    self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
1581
1582    # Doesn't allow direct submessage assignment.
1583    with self.assertRaises(ValueError):
1584      msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
1585
1586  def testMapIteration(self):
1587    msg = map_unittest_pb2.TestMap()
1588
1589    for k, v in msg.map_int32_int32.items():
1590      # Should not be reached.
1591      self.assertTrue(False)
1592
1593    msg.map_int32_int32[2] = 4
1594    msg.map_int32_int32[3] = 6
1595    msg.map_int32_int32[4] = 8
1596    self.assertEqual(3, len(msg.map_int32_int32))
1597
1598    matching_dict = {2: 4, 3: 6, 4: 8}
1599    self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
1600
1601  def testMapItems(self):
1602    # Map items used to have strange behaviors when use c extension. Because
1603    # [] may reorder the map and invalidate any exsting iterators.
1604    # TODO(jieluo): Check if [] reordering the map is a bug or intended
1605    # behavior.
1606    msg = map_unittest_pb2.TestMap()
1607    msg.map_string_string['local_init_op'] = ''
1608    msg.map_string_string['trainable_variables'] = ''
1609    msg.map_string_string['variables'] = ''
1610    msg.map_string_string['init_op'] = ''
1611    msg.map_string_string['summaries'] = ''
1612    items1 = msg.map_string_string.items()
1613    items2 = msg.map_string_string.items()
1614    self.assertEqual(items1, items2)
1615
1616  def testMapIterationClearMessage(self):
1617    # Iterator needs to work even if message and map are deleted.
1618    msg = map_unittest_pb2.TestMap()
1619
1620    msg.map_int32_int32[2] = 4
1621    msg.map_int32_int32[3] = 6
1622    msg.map_int32_int32[4] = 8
1623
1624    it = msg.map_int32_int32.items()
1625    del msg
1626
1627    matching_dict = {2: 4, 3: 6, 4: 8}
1628    self.assertMapIterEquals(it, matching_dict)
1629
1630  def testMapConstruction(self):
1631    msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
1632    self.assertEqual(2, msg.map_int32_int32[1])
1633    self.assertEqual(4, msg.map_int32_int32[3])
1634
1635    msg = map_unittest_pb2.TestMap(
1636        map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
1637    self.assertEqual(5, msg.map_int32_foreign_message[3].c)
1638
1639  def testMapValidAfterFieldCleared(self):
1640    # Map needs to work even if field is cleared.
1641    # For the C++ implementation this tests the correctness of
1642    # ScalarMapContainer::Release()
1643    msg = map_unittest_pb2.TestMap()
1644    int32_map = msg.map_int32_int32
1645
1646    int32_map[2] = 4
1647    int32_map[3] = 6
1648    int32_map[4] = 8
1649
1650    msg.ClearField('map_int32_int32')
1651    self.assertEqual(b'', msg.SerializeToString())
1652    matching_dict = {2: 4, 3: 6, 4: 8}
1653    self.assertMapIterEquals(int32_map.items(), matching_dict)
1654
1655  def testMessageMapValidAfterFieldCleared(self):
1656    # Map needs to work even if field is cleared.
1657    # For the C++ implementation this tests the correctness of
1658    # ScalarMapContainer::Release()
1659    msg = map_unittest_pb2.TestMap()
1660    int32_foreign_message = msg.map_int32_foreign_message
1661
1662    int32_foreign_message[2].c = 5
1663
1664    msg.ClearField('map_int32_foreign_message')
1665    self.assertEqual(b'', msg.SerializeToString())
1666    self.assertTrue(2 in int32_foreign_message.keys())
1667
1668  def testMapIterInvalidatedByClearField(self):
1669    # Map iterator is invalidated when field is cleared.
1670    # But this case does need to not crash the interpreter.
1671    # For the C++ implementation this tests the correctness of
1672    # ScalarMapContainer::Release()
1673    msg = map_unittest_pb2.TestMap()
1674
1675    it = iter(msg.map_int32_int32)
1676
1677    msg.ClearField('map_int32_int32')
1678    with self.assertRaises(RuntimeError):
1679      for _ in it:
1680        pass
1681
1682    it = iter(msg.map_int32_foreign_message)
1683    msg.ClearField('map_int32_foreign_message')
1684    with self.assertRaises(RuntimeError):
1685      for _ in it:
1686        pass
1687
1688  def testMapDelete(self):
1689    msg = map_unittest_pb2.TestMap()
1690
1691    self.assertEqual(0, len(msg.map_int32_int32))
1692
1693    msg.map_int32_int32[4] = 6
1694    self.assertEqual(1, len(msg.map_int32_int32))
1695
1696    with self.assertRaises(KeyError):
1697      del msg.map_int32_int32[88]
1698
1699    del msg.map_int32_int32[4]
1700    self.assertEqual(0, len(msg.map_int32_int32))
1701
1702  def testMapsAreMapping(self):
1703    msg = map_unittest_pb2.TestMap()
1704    self.assertIsInstance(msg.map_int32_int32, collections.Mapping)
1705    self.assertIsInstance(msg.map_int32_int32, collections.MutableMapping)
1706    self.assertIsInstance(msg.map_int32_foreign_message, collections.Mapping)
1707    self.assertIsInstance(msg.map_int32_foreign_message,
1708                          collections.MutableMapping)
1709
1710  def testMapFindInitializationErrorsSmokeTest(self):
1711    msg = map_unittest_pb2.TestMap()
1712    msg.map_string_string['abc'] = '123'
1713    msg.map_int32_int32[35] = 64
1714    msg.map_string_foreign_message['foo'].c = 5
1715    self.assertEqual(0, len(msg.FindInitializationErrors()))
1716
1717
1718
class ValidTypeNamesTest(unittest.TestCase):
  # Checks that the type names reported for repeated-field containers refer
  # to modules that can actually be imported.

  def assertImportFromName(self, msg, base_name):
    # Asserts that type(msg) is named like a repeated container for
    # base_name and that the module named in its type string can be imported.
    #
    # Parse <type 'module.class_name'> to extract 'module.class_name' as a
    # string.
    tp_name = str(type(msg)).split("'")[1]
    valid_names = ('Repeated%sContainer' % base_name,
                   'Repeated%sFieldContainer' % base_name)
    # str.endswith accepts a tuple of candidate suffixes.
    self.assertTrue(tp_name.endswith(valid_names),
                    '%r does not end with any of %r' % (tp_name, valid_names))

    # Everything before the last dot is the module, the rest the class.
    module_name, _, class_name = tp_name.rpartition('.')
    __import__(module_name, fromlist=[class_name])

  def testTypeNamesCanBeImported(self):
    # If import doesn't work, pickling won't work either.
    pb = unittest_pb2.TestAllTypes()
    self.assertImportFromName(pb.repeated_int32, 'Scalar')
    self.assertImportFromName(pb.repeated_nested_message, 'Composite')
1739
class PackedFieldTest(unittest.TestCase):
  # Compares the serialized form of the packed and the unpacked test
  # messages against fixed byte strings.

  def setMessage(self, message):
    # Appends one element to each repeated field named below.
    integer_fields = (
        'repeated_int32',
        'repeated_int64',
        'repeated_uint32',
        'repeated_uint64',
        'repeated_sint32',
        'repeated_sint64',
        'repeated_fixed32',
        'repeated_fixed64',
        'repeated_sfixed32',
        'repeated_sfixed64',
    )
    for field_name in integer_fields:
      getattr(message, field_name).append(1)
    for field_name in ('repeated_float', 'repeated_double'):
      getattr(message, field_name).append(1.0)
    message.repeated_bool.append(True)
    message.repeated_nested_enum.append(1)

  def testPackedFields(self):
    # Expected serialization of TestPackedTypes after setMessage().
    expected_chunks = (
        b'\x0A\x01\x01',
        b'\x12\x01\x01',
        b'\x1A\x01\x01',
        b'\x22\x01\x01',
        b'\x2A\x01\x02',
        b'\x32\x01\x02',
        b'\x3A\x04\x01\x00\x00\x00',
        b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00',
        b'\x4A\x04\x01\x00\x00\x00',
        b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00',
        b'\x5A\x04\x00\x00\x80\x3f',
        b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f',
        b'\x6A\x01\x01',
        b'\x72\x01\x01',
    )
    golden_data = b''.join(expected_chunks)

    message = packed_field_test_pb2.TestPackedTypes()
    self.setMessage(message)
    self.assertEqual(golden_data, message.SerializeToString())

  def testUnpackedFields(self):
    # Expected serialization of TestUnpackedTypes after setMessage().
    expected_chunks = (
        b'\x08\x01',
        b'\x10\x01',
        b'\x18\x01',
        b'\x20\x01',
        b'\x28\x02',
        b'\x30\x02',
        b'\x3D\x01\x00\x00\x00',
        b'\x41\x01\x00\x00\x00\x00\x00\x00\x00',
        b'\x4D\x01\x00\x00\x00',
        b'\x51\x01\x00\x00\x00\x00\x00\x00\x00',
        b'\x5D\x00\x00\x80\x3f',
        b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f',
        b'\x68\x01',
        b'\x70\x01',
    )
    golden_data = b''.join(expected_chunks)

    message = packed_field_test_pb2.TestUnpackedTypes()
    self.setMessage(message)
    self.assertEqual(golden_data, message.SerializeToString())
1795
1796
@unittest.skipIf(api_implementation.Type() != 'cpp',
                 'explicit tests of the C++ implementation')
class OversizeProtosTest(unittest.TestCase):
  # Parses a message carrying a payload of 64 MiB + 1 characters with the
  # "allow oversize protos" switch of the C++ extension turned off and on.

  def setUp(self):
    # Text-format FileDescriptorProto: msg2 wraps msg1, which has a single
    # string field.
    self.file_desc = """
      name: "f/f.msg2"
      package: "f"
      message_type {
        name: "msg1"
        field {
          name: "payload"
          number: 1
          label: LABEL_OPTIONAL
          type: TYPE_STRING
        }
      }
      message_type {
        name: "msg2"
        field {
          name: "field"
          number: 1
          label: LABEL_OPTIONAL
          type: TYPE_MESSAGE
          type_name: "msg1"
        }
      }
    """
    file_proto = descriptor_pb2.FileDescriptorProto()
    text_format.Parse(self.file_desc, file_proto)

    # Register the file in a fresh pool and build the class for f.msg2.
    private_pool = descriptor_pool.DescriptorPool()
    private_pool.Add(file_proto)
    factory = message_factory.MessageFactory(private_pool)
    self.proto_cls = factory.GetPrototype(
        private_pool.FindMessageTypeByName('f.msg2'))

    # Build the oversized message once and keep its serialized form.
    self.p = self.proto_cls()
    self.p.field.payload = 'c' * (64 * 1024 * 1024 + 1)
    self.p_serialized = self.p.SerializeToString()

  def testAssertOversizeProto(self):
    from google.protobuf.pyext._message import SetAllowOversizeProtos
    SetAllowOversizeProtos(False)
    target = self.proto_cls()
    # NOTE(review): if parsing succeeds without raising, nothing is asserted
    # and the test still passes - confirm whether that is intended.
    try:
      target.ParseFromString(self.p_serialized)
    except message.DecodeError as e:
      self.assertEqual(str(e), 'Error parsing message')

  def testSucceedOversizeProto(self):
    from google.protobuf.pyext._message import SetAllowOversizeProtos
    SetAllowOversizeProtos(True)
    target = self.proto_cls()
    target.ParseFromString(self.p_serialized)
    # The parsed copy must carry the same payload as the original.
    self.assertEqual(self.p.field.payload, target.field.payload)
1850
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()
1853