1#! /usr/bin/python
2#
3# Protocol Buffers - Google's data interchange format
4# Copyright 2008 Google Inc.  All rights reserved.
5# https://developers.google.com/protocol-buffers/
6#
7# Redistribution and use in source and binary forms, with or without
8# modification, are permitted provided that the following conditions are
9# met:
10#
11#     * Redistributions of source code must retain the above copyright
12# notice, this list of conditions and the following disclaimer.
13#     * Redistributions in binary form must reproduce the above
14# copyright notice, this list of conditions and the following disclaimer
15# in the documentation and/or other materials provided with the
16# distribution.
17#     * Neither the name of Google Inc. nor the names of its
18# contributors may be used to endorse or promote products derived from
19# this software without specific prior written permission.
20#
21# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
33"""Tests python protocol buffers against the golden message.
34
35Note that the golden messages exercise every known field type, thus this
36test ends up exercising and verifying nearly all of the parsing and
37serialization code in the whole library.
38
39TODO(kenton):  Merge with wire_format_test?  It doesn't make a whole lot of
40sense to call this a test of the "message" module, which only declares an
41abstract interface.
42"""
43
44__author__ = 'gps@google.com (Gregory P. Smith)'
45
46import copy
47import math
48import operator
49import pickle
50import sys
51
52from google.apputils import basetest
53from google.protobuf import unittest_pb2
54from google.protobuf.internal import api_implementation
55from google.protobuf.internal import test_util
56from google.protobuf import message
57
# Floating-point classification helpers used by the tests below. The standard
# math.isnan()/math.isinf() (available since Python 2.6, which this module
# already exceeds) do the real work; these wrappers keep the historical names.
def isnan(val):
  """Returns True if val is a floating-point NaN."""
  return math.isnan(val)


def isinf(val):
  """Returns True if val is positive or negative infinity."""
  return math.isinf(val)


def IsPosInf(val):
  """Returns True if val is positive infinity."""
  return isinf(val) and (val > 0)


def IsNegInf(val):
  """Returns True if val is negative infinity."""
  return isinf(val) and (val < 0)
70
71
class MessageTest(basetest.TestCase):
  """Tests generated messages against golden data and the container APIs.

  Local message variables are named `msg` (or `m`) rather than `message` so
  they never shadow the imported google.protobuf.message module, which
  testPickleIncompleteProto relies on for message.EncodeError.
  """

  def testBadUtf8String(self):
    """Bad UTF-8 in a string field raises an error naming that field."""
    if api_implementation.Type() != 'python':
      self.skipTest("Skipping testBadUtf8String, currently only the python "
                    "api implementation raises UnicodeDecodeError when a "
                    "string field contains bad utf-8.")
    bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
    with self.assertRaises(UnicodeDecodeError) as context:
      unittest_pb2.TestAllTypes.FromString(bad_utf8_data)
    self.assertIn('field: protobuf_unittest.TestAllTypes.optional_string',
                  str(context.exception))

  def testGoldenMessage(self):
    """Golden TestAllTypes data parses, re-serializes and deep-copies."""
    golden_data = test_util.GoldenFileData(
        'golden_message_oneof_implemented')
    golden_message = unittest_pb2.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    test_util.ExpectAllFieldsSet(self, golden_message)
    self.assertEqual(golden_data, golden_message.SerializeToString())
    golden_copy = copy.deepcopy(golden_message)
    self.assertEqual(golden_data, golden_copy.SerializeToString())

  def testGoldenExtensions(self):
    """Golden data parsed as extensions equals SetAllExtensions()."""
    golden_data = test_util.GoldenFileData('golden_message')
    golden_message = unittest_pb2.TestAllExtensions()
    golden_message.ParseFromString(golden_data)
    all_set = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(all_set)
    self.assertEqual(all_set, golden_message)
    self.assertEqual(golden_data, golden_message.SerializeToString())
    golden_copy = copy.deepcopy(golden_message)
    self.assertEqual(golden_data, golden_copy.SerializeToString())

  def testGoldenPackedMessage(self):
    """Golden packed data equals SetAllPackedFields() byte-for-byte."""
    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
    golden_message = unittest_pb2.TestPackedTypes()
    golden_message.ParseFromString(golden_data)
    all_set = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(all_set)
    self.assertEqual(all_set, golden_message)
    # Serializing the independently built message must reproduce the file.
    self.assertEqual(golden_data, all_set.SerializeToString())
    golden_copy = copy.deepcopy(golden_message)
    self.assertEqual(golden_data, golden_copy.SerializeToString())

  def testGoldenPackedExtensions(self):
    """Golden packed data parsed as extensions equals SetAllPackedExtensions()."""
    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
    golden_message = unittest_pb2.TestPackedExtensions()
    golden_message.ParseFromString(golden_data)
    all_set = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(all_set)
    self.assertEqual(all_set, golden_message)
    self.assertEqual(golden_data, all_set.SerializeToString())
    golden_copy = copy.deepcopy(golden_message)
    self.assertEqual(golden_data, golden_copy.SerializeToString())

  def testPickleSupport(self):
    """A fully populated message survives a pickle round trip."""
    golden_data = test_util.GoldenFileData('golden_message')
    golden_message = unittest_pb2.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    pickled_message = pickle.dumps(golden_message)

    unpickled_message = pickle.loads(pickled_message)
    self.assertEqual(unpickled_message, golden_message)

  def testPickleIncompleteProto(self):
    """A message missing required fields can still be pickled."""
    golden_message = unittest_pb2.TestRequired(a=1)
    pickled_message = pickle.dumps(golden_message)

    unpickled_message = pickle.loads(pickled_message)
    self.assertEqual(unpickled_message, golden_message)
    self.assertEqual(unpickled_message.a, 1)
    # This is still an incomplete proto - so serializing should fail
    self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)

  def testPositiveInfinity(self):
    """+inf in float/double fields parses and re-serializes unchanged."""
    # Tag byte(s) followed by the little-endian IEEE 754 bits of +inf for, in
    # order: optional_float, optional_double, repeated_float, repeated_double.
    golden_data = (b'\x5D\x00\x00\x80\x7F'
                   b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
                   b'\xCD\x02\x00\x00\x80\x7F'
                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
    golden_message = unittest_pb2.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    self.assertTrue(IsPosInf(golden_message.optional_float))
    self.assertTrue(IsPosInf(golden_message.optional_double))
    self.assertTrue(IsPosInf(golden_message.repeated_float[0]))
    self.assertTrue(IsPosInf(golden_message.repeated_double[0]))
    self.assertEqual(golden_data, golden_message.SerializeToString())

  def testNegativeInfinity(self):
    """-inf in float/double fields parses and re-serializes unchanged."""
    # Same layout as testPositiveInfinity, with the sign bit set.
    golden_data = (b'\x5D\x00\x00\x80\xFF'
                   b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
                   b'\xCD\x02\x00\x00\x80\xFF'
                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
    golden_message = unittest_pb2.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    self.assertTrue(IsNegInf(golden_message.optional_float))
    self.assertTrue(IsNegInf(golden_message.optional_double))
    self.assertTrue(IsNegInf(golden_message.repeated_float[0]))
    self.assertTrue(IsNegInf(golden_message.repeated_double[0]))
    self.assertEqual(golden_data, golden_message.SerializeToString())

  def testNotANumber(self):
    """NaN in float/double fields survives parse/serialize/parse."""
    # Quiet-NaN bit patterns (0x7FC00000 float, 0x7FF8000000000000 double).
    golden_data = (b'\x5D\x00\x00\xC0\x7F'
                   b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
                   b'\xCD\x02\x00\x00\xC0\x7F'
                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
    golden_message = unittest_pb2.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    self.assertTrue(isnan(golden_message.optional_float))
    self.assertTrue(isnan(golden_message.optional_double))
    self.assertTrue(isnan(golden_message.repeated_float[0]))
    self.assertTrue(isnan(golden_message.repeated_double[0]))

    # The protocol buffer may serialize to any one of multiple different
    # representations of a NaN.  Rather than verify a specific representation,
    # verify the serialized string can be converted into a correctly
    # behaving protocol buffer.
    serialized = golden_message.SerializeToString()
    msg = unittest_pb2.TestAllTypes()
    msg.ParseFromString(serialized)
    self.assertTrue(isnan(msg.optional_float))
    self.assertTrue(isnan(msg.optional_double))
    self.assertTrue(isnan(msg.repeated_float[0]))
    self.assertTrue(isnan(msg.repeated_double[0]))

  def testPositiveInfinityPacked(self):
    """+inf in packed float/double fields round-trips byte-for-byte."""
    # Length-delimited (packed) entries: tag, payload length, IEEE 754 bits.
    golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
    golden_message = unittest_pb2.TestPackedTypes()
    golden_message.ParseFromString(golden_data)
    self.assertTrue(IsPosInf(golden_message.packed_float[0]))
    self.assertTrue(IsPosInf(golden_message.packed_double[0]))
    self.assertEqual(golden_data, golden_message.SerializeToString())

  def testNegativeInfinityPacked(self):
    """-inf in packed float/double fields round-trips byte-for-byte."""
    golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
    golden_message = unittest_pb2.TestPackedTypes()
    golden_message.ParseFromString(golden_data)
    self.assertTrue(IsNegInf(golden_message.packed_float[0]))
    self.assertTrue(IsNegInf(golden_message.packed_double[0]))
    self.assertEqual(golden_data, golden_message.SerializeToString())

  def testNotANumberPacked(self):
    """NaN in packed float/double fields survives parse/serialize/parse."""
    golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
    golden_message = unittest_pb2.TestPackedTypes()
    golden_message.ParseFromString(golden_data)
    self.assertTrue(isnan(golden_message.packed_float[0]))
    self.assertTrue(isnan(golden_message.packed_double[0]))

    # NaN may re-serialize to a different bit pattern; only check semantics.
    serialized = golden_message.SerializeToString()
    msg = unittest_pb2.TestPackedTypes()
    msg.ParseFromString(serialized)
    self.assertTrue(isnan(msg.packed_float[0]))
    self.assertTrue(isnan(msg.packed_double[0]))

  def testExtremeFloatValues(self):
    """Float values at the ends of the float32 exponent range round-trip."""
    msg = unittest_pb2.TestAllTypes()
    magnitudes = [
        # Largest float32 exponent, no significand bits set.
        math.pow(2, 127),
        # Largest float32 exponent, top significand bit set.
        1.5 * math.pow(2, 127),
        # 2**-127 lies just below float32's smallest normal value (2**-126),
        # so these two round-trip as subnormals.
        math.pow(2, -127),
        1.5 * math.pow(2, -127),
    ]
    for magnitude in magnitudes:
      # Check each magnitude with both signs.
      for value in (magnitude, -magnitude):
        msg.optional_float = value
        msg.ParseFromString(msg.SerializeToString())
        self.assertEqual(msg.optional_float, value)

  def testExtremeDoubleValues(self):
    """Double values at the ends of the float64 exponent range round-trip."""
    msg = unittest_pb2.TestAllTypes()
    magnitudes = [
        # Largest double exponent, no significand bits set.
        math.pow(2, 1023),
        # Largest double exponent, top significand bit set.
        1.5 * math.pow(2, 1023),
        # 2**-1023 lies just below the smallest normal double (2**-1022),
        # so these two round-trip as subnormals.
        math.pow(2, -1023),
        1.5 * math.pow(2, -1023),
    ]
    for magnitude in magnitudes:
      # Check each magnitude with both signs.
      for value in (magnitude, -magnitude):
        msg.optional_double = value
        msg.ParseFromString(msg.SerializeToString())
        self.assertEqual(msg.optional_double, value)

  def testFloatPrinting(self):
    """A whole-valued float prints with a trailing '.0'."""
    msg = unittest_pb2.TestAllTypes()
    msg.optional_float = 2.0
    self.assertEqual(str(msg), 'optional_float: 2.0\n')

  def testHighPrecisionFloatPrinting(self):
    """Doubles print with the interpreter's str(float) precision."""
    msg = unittest_pb2.TestAllTypes()
    msg.optional_double = 0.12345678912345678
    # Python 2's str(float) keeps only 12 significant digits; Python 3 prints
    # the shortest string that round-trips.
    if sys.version_info.major >= 3:
      self.assertEqual(str(msg), 'optional_double: 0.12345678912345678\n')
    else:
      self.assertEqual(str(msg), 'optional_double: 0.123456789123\n')

  def testUnknownFieldPrinting(self):
    """Unknown fields are not included in the text format output."""
    populated = unittest_pb2.TestAllTypes()
    test_util.SetAllNonLazyFields(populated)
    empty = unittest_pb2.TestEmptyMessage()
    empty.ParseFromString(populated.SerializeToString())
    self.assertEqual(str(empty), '')

  def testSortingRepeatedScalarFieldsDefaultComparator(self):
    """Check some different types with the default comparator."""
    msg = unittest_pb2.TestAllTypes()

    # TODO(mattp): would testing more scalar types strengthen test?
    msg.repeated_int32.append(1)
    msg.repeated_int32.append(3)
    msg.repeated_int32.append(2)
    msg.repeated_int32.sort()
    self.assertEqual(msg.repeated_int32[0], 1)
    self.assertEqual(msg.repeated_int32[1], 2)
    self.assertEqual(msg.repeated_int32[2], 3)

    msg.repeated_float.append(1.1)
    msg.repeated_float.append(1.3)
    msg.repeated_float.append(1.2)
    msg.repeated_float.sort()
    # float32 storage loses precision, so compare approximately.
    self.assertAlmostEqual(msg.repeated_float[0], 1.1)
    self.assertAlmostEqual(msg.repeated_float[1], 1.2)
    self.assertAlmostEqual(msg.repeated_float[2], 1.3)

    msg.repeated_string.append('a')
    msg.repeated_string.append('c')
    msg.repeated_string.append('b')
    msg.repeated_string.sort()
    self.assertEqual(msg.repeated_string[0], 'a')
    self.assertEqual(msg.repeated_string[1], 'b')
    self.assertEqual(msg.repeated_string[2], 'c')

    msg.repeated_bytes.append(b'a')
    msg.repeated_bytes.append(b'c')
    msg.repeated_bytes.append(b'b')
    msg.repeated_bytes.sort()
    self.assertEqual(msg.repeated_bytes[0], b'a')
    self.assertEqual(msg.repeated_bytes[1], b'b')
    self.assertEqual(msg.repeated_bytes[2], b'c')

  def testSortingRepeatedScalarFieldsCustomComparator(self):
    """Check some different types with custom comparator."""
    msg = unittest_pb2.TestAllTypes()

    msg.repeated_int32.append(-3)
    msg.repeated_int32.append(-2)
    msg.repeated_int32.append(-1)
    msg.repeated_int32.sort(key=abs)
    self.assertEqual(msg.repeated_int32[0], -1)
    self.assertEqual(msg.repeated_int32[1], -2)
    self.assertEqual(msg.repeated_int32[2], -3)

    msg.repeated_string.append('aaa')
    msg.repeated_string.append('bb')
    msg.repeated_string.append('c')
    msg.repeated_string.sort(key=len)
    self.assertEqual(msg.repeated_string[0], 'c')
    self.assertEqual(msg.repeated_string[1], 'bb')
    self.assertEqual(msg.repeated_string[2], 'aaa')

  def testSortingRepeatedCompositeFieldsCustomComparator(self):
    """Check passing a custom comparator to sort a repeated composite field."""
    msg = unittest_pb2.TestAllTypes()

    for bb in (1, 3, 2, 6, 5, 4):
      msg.repeated_nested_message.add().bb = bb
    msg.repeated_nested_message.sort(key=operator.attrgetter('bb'))
    self.assertEqual(msg.repeated_nested_message[0].bb, 1)
    self.assertEqual(msg.repeated_nested_message[1].bb, 2)
    self.assertEqual(msg.repeated_nested_message[2].bb, 3)
    self.assertEqual(msg.repeated_nested_message[3].bb, 4)
    self.assertEqual(msg.repeated_nested_message[4].bb, 5)
    self.assertEqual(msg.repeated_nested_message[5].bb, 6)

  def testRepeatedCompositeFieldSortArguments(self):
    """Check sorting a repeated composite field using list.sort() arguments."""
    msg = unittest_pb2.TestAllTypes()

    get_bb = operator.attrgetter('bb')
    for bb in (1, 3, 2, 6, 5, 4):
      msg.repeated_nested_message.add().bb = bb
    msg.repeated_nested_message.sort(key=get_bb)
    self.assertEqual([k.bb for k in msg.repeated_nested_message],
                     [1, 2, 3, 4, 5, 6])
    msg.repeated_nested_message.sort(key=get_bb, reverse=True)
    self.assertEqual([k.bb for k in msg.repeated_nested_message],
                     [6, 5, 4, 3, 2, 1])
    if sys.version_info.major < 3:  # No cmp sorting in PY3.
      cmp_bb = lambda a, b: cmp(a.bb, b.bb)
      msg.repeated_nested_message.sort(sort_function=cmp_bb)
      self.assertEqual([k.bb for k in msg.repeated_nested_message],
                       [1, 2, 3, 4, 5, 6])
      msg.repeated_nested_message.sort(cmp=cmp_bb, reverse=True)
      self.assertEqual([k.bb for k in msg.repeated_nested_message],
                       [6, 5, 4, 3, 2, 1])

  def testRepeatedScalarFieldSortArguments(self):
    """Check sorting a scalar field using list.sort() arguments."""
    msg = unittest_pb2.TestAllTypes()

    msg.repeated_int32.append(-3)
    msg.repeated_int32.append(-2)
    msg.repeated_int32.append(-1)
    msg.repeated_int32.sort(key=abs)
    self.assertEqual(list(msg.repeated_int32), [-1, -2, -3])
    msg.repeated_int32.sort(key=abs, reverse=True)
    self.assertEqual(list(msg.repeated_int32), [-3, -2, -1])
    if sys.version_info.major < 3:  # No cmp sorting in PY3.
      abs_cmp = lambda a, b: cmp(abs(a), abs(b))
      msg.repeated_int32.sort(sort_function=abs_cmp)
      self.assertEqual(list(msg.repeated_int32), [-1, -2, -3])
      msg.repeated_int32.sort(cmp=abs_cmp, reverse=True)
      self.assertEqual(list(msg.repeated_int32), [-3, -2, -1])

    msg.repeated_string.append('aaa')
    msg.repeated_string.append('bb')
    msg.repeated_string.append('c')
    msg.repeated_string.sort(key=len)
    self.assertEqual(list(msg.repeated_string), ['c', 'bb', 'aaa'])
    msg.repeated_string.sort(key=len, reverse=True)
    self.assertEqual(list(msg.repeated_string), ['aaa', 'bb', 'c'])
    if sys.version_info.major < 3:  # No cmp sorting in PY3.
      len_cmp = lambda a, b: cmp(len(a), len(b))
      msg.repeated_string.sort(sort_function=len_cmp)
      self.assertEqual(list(msg.repeated_string), ['c', 'bb', 'aaa'])
      msg.repeated_string.sort(cmp=len_cmp, reverse=True)
      self.assertEqual(list(msg.repeated_string), ['aaa', 'bb', 'c'])

  def testRepeatedFieldsComparable(self):
    """Messages and repeated containers support Python 2 cmp() and ordering."""
    if sys.version_info.major >= 3:
      # Report the test as skipped instead of silently passing.
      self.skipTest('cmp() and message ordering only exist in Python 2.')

    m1 = unittest_pb2.TestAllTypes()
    m2 = unittest_pb2.TestAllTypes()
    for value in (0, 1, 2):
      m1.repeated_int32.append(value)
      m2.repeated_int32.append(value)
    for bb in (1, 2, 3):
      m1.repeated_nested_message.add().bb = bb
      m2.repeated_nested_message.add().bb = bb

    # These comparisons should not raise errors.
    _ = m1 < m2
    _ = m1.repeated_nested_message < m2.repeated_nested_message

    # Make sure cmp always works. If it wasn't defined, these would be
    # id() comparisons and would all fail.
    self.assertEqual(cmp(m1, m2), 0)
    self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0)
    self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0)
    self.assertEqual(cmp(m1.repeated_nested_message,
                         m2.repeated_nested_message), 0)
    with self.assertRaises(TypeError):
      # Can't compare repeated composite containers to lists.
      cmp(m1.repeated_nested_message, m2.repeated_nested_message[:])

    # TODO(anuraag): Implement extensiondict comparison in C++ and then add test

  def testParsingMerge(self):
    """Check merging when a singular field appears multiple times in input.

    Required and optional fields (including groups and extensions) that occur
    more than once on the wire must be merged; repeated ones must not.
    """
    messages = [unittest_pb2.TestAllTypes() for _ in range(3)]
    messages[0].optional_int32 = 1
    messages[1].optional_int64 = 2
    messages[2].optional_int32 = 3
    messages[2].optional_string = 'hello'

    merged_message = unittest_pb2.TestAllTypes()
    merged_message.optional_int32 = 3
    merged_message.optional_int64 = 2
    merged_message.optional_string = 'hello'

    generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
    generator.field1.extend(messages)
    generator.field2.extend(messages)
    generator.field3.extend(messages)
    generator.ext1.extend(messages)
    generator.ext2.extend(messages)
    for m in messages:
      generator.group1.add().field1.MergeFrom(m)
    for m in messages:
      generator.group2.add().field1.MergeFrom(m)

    data = generator.SerializeToString()
    parsing_merge = unittest_pb2.TestParsingMerge()
    parsing_merge.ParseFromString(data)

    # Required and optional fields should be merged.
    self.assertEqual(parsing_merge.required_all_types, merged_message)
    self.assertEqual(parsing_merge.optional_all_types, merged_message)
    self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
                     merged_message)
    self.assertEqual(parsing_merge.Extensions[
                     unittest_pb2.TestParsingMerge.optional_ext],
                     merged_message)

    # Repeated fields should not be merged.
    self.assertEqual(len(parsing_merge.repeated_all_types), 3)
    self.assertEqual(len(parsing_merge.repeatedgroup), 3)
    self.assertEqual(len(parsing_merge.Extensions[
        unittest_pb2.TestParsingMerge.repeated_ext]), 3)

  def ensureNestedMessageExists(self, msg, attribute):
    """Make sure that a nested message object exists.

    As soon as a nested message attribute is accessed, it will be present in the
    _fields dict, without being marked as actually being set.

    Args:
      msg: The message whose sub-message attribute is read.
      attribute: Name of the nested message field to read.
    """
    getattr(msg, attribute)
    # Merely reading the attribute must not make HasField() report True.
    self.assertFalse(msg.HasField(attribute))

  def testOneofGetCaseNonexistingField(self):
    """WhichOneof() rejects a name that is not a oneof."""
    m = unittest_pb2.TestAllTypes()
    self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')

  def testOneofSemantics(self):
    """Setting one oneof member clears whichever member was set before."""
    m = unittest_pb2.TestAllTypes()
    self.assertIs(None, m.WhichOneof('oneof_field'))

    m.oneof_uint32 = 11
    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    self.assertTrue(m.HasField('oneof_uint32'))

    m.oneof_string = u'foo'
    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
    self.assertFalse(m.HasField('oneof_uint32'))
    self.assertTrue(m.HasField('oneof_string'))

    m.oneof_nested_message.bb = 11
    self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
    self.assertFalse(m.HasField('oneof_string'))
    self.assertTrue(m.HasField('oneof_nested_message'))

    m.oneof_bytes = b'bb'
    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
    self.assertFalse(m.HasField('oneof_nested_message'))
    self.assertTrue(m.HasField('oneof_bytes'))

  def testOneofCompositeFieldReadAccess(self):
    """Reading an unset oneof sub-message does not change the active case."""
    m = unittest_pb2.TestAllTypes()
    m.oneof_uint32 = 11

    self.ensureNestedMessageExists(m, 'oneof_nested_message')
    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    self.assertEqual(11, m.oneof_uint32)

  def testOneofHasField(self):
    """HasField() on the oneof name reports whether any member is set."""
    m = unittest_pb2.TestAllTypes()
    self.assertFalse(m.HasField('oneof_field'))
    m.oneof_uint32 = 11
    self.assertTrue(m.HasField('oneof_field'))
    m.oneof_bytes = b'bb'
    self.assertTrue(m.HasField('oneof_field'))
    m.ClearField('oneof_bytes')
    self.assertFalse(m.HasField('oneof_field'))

  def testOneofClearField(self):
    """ClearField() on the oneof name clears the active member."""
    m = unittest_pb2.TestAllTypes()
    m.oneof_uint32 = 11
    m.ClearField('oneof_field')
    self.assertFalse(m.HasField('oneof_field'))
    self.assertFalse(m.HasField('oneof_uint32'))
    self.assertIs(None, m.WhichOneof('oneof_field'))

  def testOneofClearSetField(self):
    """ClearField() on the active member leaves the oneof unset."""
    m = unittest_pb2.TestAllTypes()
    m.oneof_uint32 = 11
    m.ClearField('oneof_uint32')
    self.assertFalse(m.HasField('oneof_field'))
    self.assertFalse(m.HasField('oneof_uint32'))
    self.assertIs(None, m.WhichOneof('oneof_field'))

  def testOneofClearUnsetField(self):
    """ClearField() on an inactive member leaves the active one alone."""
    m = unittest_pb2.TestAllTypes()
    m.oneof_uint32 = 11
    self.ensureNestedMessageExists(m, 'oneof_nested_message')
    m.ClearField('oneof_nested_message')
    self.assertEqual(11, m.oneof_uint32)
    self.assertTrue(m.HasField('oneof_field'))
    self.assertTrue(m.HasField('oneof_uint32'))
    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))

  def testOneofDeserialize(self):
    """The active oneof member is preserved through serialization."""
    m = unittest_pb2.TestAllTypes()
    m.oneof_uint32 = 11
    m2 = unittest_pb2.TestAllTypes()
    m2.ParseFromString(m.SerializeToString())
    self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))

  def testSortEmptyRepeatedCompositeContainer(self):
    """Exercise a scenario that has led to segfaults in the past."""
    m = unittest_pb2.TestAllTypes()
    m.repeated_nested_message.sort()

  def testHasFieldOnRepeatedField(self):
    """Using HasField on a repeated field should raise an exception."""
    m = unittest_pb2.TestAllTypes()
    with self.assertRaises(ValueError):
      m.HasField('repeated_int32')
657
class ValidTypeNamesTest(basetest.TestCase):
  """Checks that repeated-field container types have importable names."""

  def assertImportFromName(self, msg, base_name):
    """Asserts msg's container type has an expected, importable name.

    Args:
      msg: A repeated field container instance.
      base_name: Container flavour embedded in the type name, e.g. 'Scalar'
        or 'Composite'.
    """
    # str(type(x)) is "<type 'pkg.Name'>" on Python 2 and "<class 'pkg.Name'>"
    # on Python 3; extract the quoted dotted name 'pkg.Name'.
    tp_name = str(type(msg)).split("'")[1]
    valid_names = ('Repeated%sContainer' % base_name,
                   'Repeated%sFieldContainer' % base_name)
    self.assertTrue(tp_name.endswith(valid_names),
                    '%r does not end with any of %r' % (tp_name, valid_names))

    module_name, _, class_name = tp_name.rpartition('.')
    # Pickle locates classes by module + name, so the module must import.
    __import__(module_name, fromlist=[class_name])

  def testTypeNamesCanBeImported(self):
    """Both scalar and composite repeated containers pass the check."""
    # If import doesn't work, pickling won't work either.
    pb = unittest_pb2.TestAllTypes()
    self.assertImportFromName(pb.repeated_int32, 'Scalar')
    self.assertImportFromName(pb.repeated_nested_message, 'Composite')
678
679
if __name__ == '__main__':
  # Run the test suite only when this file is executed directly.
  basetest.main()
682