1#! /usr/bin/env python
2#
3# Protocol Buffers - Google's data interchange format
4# Copyright 2008 Google Inc.  All rights reserved.
5# https://developers.google.com/protocol-buffers/
6#
7# Redistribution and use in source and binary forms, with or without
8# modification, are permitted provided that the following conditions are
9# met:
10#
11#     * Redistributions of source code must retain the above copyright
12# notice, this list of conditions and the following disclaimer.
13#     * Redistributions in binary form must reproduce the above
14# copyright notice, this list of conditions and the following disclaimer
15# in the documentation and/or other materials provided with the
16# distribution.
17#     * Neither the name of Google Inc. nor the names of its
18# contributors may be used to endorse or promote products derived from
19# this software without specific prior written permission.
20#
21# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
33"""Test for google.protobuf.text_format."""
34
35__author__ = 'kenton@google.com (Kenton Varda)'
36
37
38import re
39import six
40import string
41
42try:
43  import unittest2 as unittest  # PY26, pylint: disable=g-import-not-at-top
44except ImportError:
45  import unittest  # pylint: disable=g-import-not-at-top
46
47from google.protobuf.internal import _parameterized
48
49from google.protobuf import any_test_pb2
50from google.protobuf import map_unittest_pb2
51from google.protobuf import unittest_mset_pb2
52from google.protobuf import unittest_pb2
53from google.protobuf import unittest_proto3_arena_pb2
54from google.protobuf.internal import api_implementation
55from google.protobuf.internal import test_util
56from google.protobuf.internal import message_set_extensions_pb2
57from google.protobuf import descriptor_pool
58from google.protobuf import text_format
59
60
61# Low-level nuts-n-bolts tests.
class SimpleTextFormatTests(unittest.TestCase):
  """Sanity checks on module-level constants of text_format."""

  def testQuoteMarksAreSingleChars(self):
    # text_format._QUOTES entries get substituted into a regexp template that
    # assumes one character per entry, so a multi-character "quote mark" is
    # both meaningless and an error.
    mark_lengths = [len(mark) for mark in text_format._QUOTES]
    for length in mark_lengths:
      self.assertEqual(1, length)
71
72
73# Base class with some common functionality.
class TextFormatBase(unittest.TestCase):
  """Shared helpers for comparing text-format output with expected text."""

  def ReadGolden(self, golden_filename):
    """Return the lines of a golden file, each keeping its trailing newline."""
    with test_util.GoldenFile(golden_filename) as golden_file:
      if str is bytes:
        # Python 2: the native str type is bytes, so lines are used as-is.
        return golden_file.readlines()
      # Python 3: lines arrive as bytes and are decoded as UTF-8.
      return [raw_line.decode('utf-8') for raw_line in golden_file]

  def CompareToGoldenFile(self, text, golden_filename):
    """Assert that `text` equals the full contents of `golden_filename`."""
    expected = ''.join(self.ReadGolden(golden_filename))
    self.assertMultiLineEqual(text, expected)

  def CompareToGoldenText(self, text, golden_text):
    """Assert that `text` equals `golden_text` exactly."""
    self.assertEqual(text, golden_text)

  def RemoveRedundantZeros(self, text):
    """Normalize float formatting so output matches the golden data.

    Some platforms print 1e+5 as 1e+005, so up to two leading zeros are
    stripped from each exponent. Floating-point fields are printed with a
    ".0" suffix even when they hold integral values, so a ".0" at the end of
    a line is removed as well.
    """
    for sign_prefix in ('e+', 'e-'):
      zero_padded = sign_prefix + '0'
      text = text.replace(zero_padded, sign_prefix)
      text = text.replace(zero_padded, sign_prefix)
    # (?m) makes '$' match at every line end, like re.MULTILINE.
    return re.sub(r'(?m)\.0$', '', text)
97
98
@_parameterized.Parameters((unittest_pb2), (unittest_proto3_arena_pb2))
class TextFormatTest(TextFormatBase):
  """Text-format tests run once per parameterized test-schema module.

  Each test method receives `message_module`, which is either `unittest_pb2`
  or `unittest_proto3_arena_pb2` as listed in the decorator.
  """

  def testPrintExotic(self, message_module):
    """Extreme integers, doubles and escape-heavy strings print as expected."""
    message = message_module.TestAllTypes()
    message.repeated_int64.append(-9223372036854775808)
    message.repeated_uint64.append(18446744073709551615)
    message.repeated_double.append(123.456)
    message.repeated_double.append(1.23e22)
    message.repeated_double.append(1.23e-18)
    message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
    message.repeated_string.append(u'\u00fc\ua71f')
    self.CompareToGoldenText(
        self.RemoveRedundantZeros(text_format.MessageToString(message)),
        'repeated_int64: -9223372036854775808\n'
        'repeated_uint64: 18446744073709551615\n'
        'repeated_double: 123.456\n'
        'repeated_double: 1.23e+22\n'
        'repeated_double: 1.23e-18\n'
        'repeated_string:'
        ' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n'
        'repeated_string: "\\303\\274\\352\\234\\237"\n')

  def testPrintExoticUnicodeSubclass(self, message_module):
    """A subclass of the text type prints the same as the text type itself."""

    class UnicodeSub(six.text_type):
      pass

    message = message_module.TestAllTypes()
    message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f'))
    self.CompareToGoldenText(
        text_format.MessageToString(message),
        'repeated_string: "\\303\\274\\352\\234\\237"\n')

  def testPrintNestedMessageAsOneLine(self, message_module):
    """as_one_line=True prints a nested message inline."""
    message = message_module.TestAllTypes()
    msg = message.repeated_nested_message.add()
    msg.bb = 42
    self.CompareToGoldenText(
        text_format.MessageToString(message, as_one_line=True),
        'repeated_nested_message { bb: 42 }')

  def testPrintRepeatedFieldsAsOneLine(self, message_module):
    """Each repeated element prints as its own `name: value` pair."""
    message = message_module.TestAllTypes()
    message.repeated_int32.append(1)
    message.repeated_int32.append(1)
    message.repeated_int32.append(3)
    message.repeated_string.append('Google')
    message.repeated_string.append('Zurich')
    self.CompareToGoldenText(
        text_format.MessageToString(message, as_one_line=True),
        'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 '
        'repeated_string: "Google" repeated_string: "Zurich"')

  def testPrintNestedNewLineInStringAsOneLine(self, message_module):
    """Newlines inside string values are escaped in one-line output."""
    message = message_module.TestAllTypes()
    message.optional_string = 'a\nnew\nline'
    self.CompareToGoldenText(
        text_format.MessageToString(message, as_one_line=True),
        'optional_string: "a\\nnew\\nline"')

  def testPrintExoticAsOneLine(self, message_module):
    """One-line variant of testPrintExotic."""
    message = message_module.TestAllTypes()
    message.repeated_int64.append(-9223372036854775808)
    message.repeated_uint64.append(18446744073709551615)
    message.repeated_double.append(123.456)
    message.repeated_double.append(1.23e22)
    message.repeated_double.append(1.23e-18)
    message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
    message.repeated_string.append(u'\u00fc\ua71f')
    self.CompareToGoldenText(
        self.RemoveRedundantZeros(text_format.MessageToString(
            message, as_one_line=True)),
        'repeated_int64: -9223372036854775808'
        ' repeated_uint64: 18446744073709551615'
        ' repeated_double: 123.456'
        ' repeated_double: 1.23e+22'
        ' repeated_double: 1.23e-18'
        ' repeated_string: '
        '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""'
        ' repeated_string: "\\303\\274\\352\\234\\237"')

  def testRoundTripExoticAsOneLine(self, message_module):
    """Exotic values survive print-then-parse with as_utf8 off and on."""
    message = message_module.TestAllTypes()
    message.repeated_int64.append(-9223372036854775808)
    message.repeated_uint64.append(18446744073709551615)
    message.repeated_double.append(123.456)
    message.repeated_double.append(1.23e22)
    message.repeated_double.append(1.23e-18)
    message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
    message.repeated_string.append(u'\u00fc\ua71f')

    # Test as_utf8 = False.
    wire_text = text_format.MessageToString(message,
                                            as_one_line=True,
                                            as_utf8=False)
    parsed_message = message_module.TestAllTypes()
    r = text_format.Parse(wire_text, parsed_message)
    self.assertIs(r, parsed_message)
    self.assertEqual(message, parsed_message)

    # Test as_utf8 = True.
    wire_text = text_format.MessageToString(message,
                                            as_one_line=True,
                                            as_utf8=True)
    parsed_message = message_module.TestAllTypes()
    r = text_format.Parse(wire_text, parsed_message)
    self.assertIs(r, parsed_message)
    self.assertEqual(message, parsed_message,
                     '\n%s != %s' % (message, parsed_message))

  def testPrintRawUtf8String(self, message_module):
    """as_utf8=True leaves non-ASCII bytes unescaped, and the text re-parses."""
    message = message_module.TestAllTypes()
    message.repeated_string.append(u'\u00fc\ua71f')
    text = text_format.MessageToString(message, as_utf8=True)
    self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n')
    parsed_message = message_module.TestAllTypes()
    text_format.Parse(text, parsed_message)
    self.assertEqual(message, parsed_message,
                     '\n%s != %s' % (message, parsed_message))

  def testPrintFloatFormat(self, message_module):
    """float_format is passed down to sub-message formatting."""
    message = message_module.NestedTestAllTypes()
    # We use 1.25 as it is a round number in binary.  The proto 32-bit float
    # will not gain additional imprecise digits as a 64-bit Python float and
    # show up in its str.  32-bit 1.2 is noisy when extended to 64-bit:
    #  >>> struct.unpack('f', struct.pack('f', 1.2))[0]
    #  1.2000000476837158
    #  >>> struct.unpack('f', struct.pack('f', 1.25))[0]
    #  1.25
    message.payload.optional_float = 1.25
    # Check rounding at 15 significant digits
    message.payload.optional_double = -.000003456789012345678
    # Check no decimal point.
    message.payload.repeated_float.append(-5642)
    # Check no trailing zeros.
    message.payload.repeated_double.append(.000078900)
    formatted_fields = ['optional_float: 1.25',
                        'optional_double: -3.45678901234568e-6',
                        'repeated_float: -5642', 'repeated_double: 7.89e-5']
    text_message = text_format.MessageToString(message, float_format='.15g')
    self.CompareToGoldenText(
        self.RemoveRedundantZeros(text_message),
        'payload {{\n  {0}\n  {1}\n  {2}\n  {3}\n}}\n'.format(
            *formatted_fields))
    # as_one_line=True is a separate code branch where float_format is passed.
    text_message = text_format.MessageToString(message,
                                               as_one_line=True,
                                               float_format='.15g')
    self.CompareToGoldenText(
        self.RemoveRedundantZeros(text_message),
        'payload {{ {0} {1} {2} {3} }}'.format(*formatted_fields))

  def testMessageToString(self, message_module):
    """str() of a message yields its text format."""
    message = message_module.ForeignMessage()
    message.c = 123
    self.assertEqual('c: 123\n', str(message))

  def testPrintField(self, message_module):
    """PrintField and _Printer.PrintField emit the same `name: value` line."""
    message = message_module.TestAllTypes()
    field = message.DESCRIPTOR.fields_by_name['optional_float']
    value = message.optional_float
    out = text_format.TextWriter(False)
    text_format.PrintField(field, value, out)
    self.assertEqual('optional_float: 0.0\n', out.getvalue())
    out.close()
    # Test Printer
    out = text_format.TextWriter(False)
    printer = text_format._Printer(out)
    printer.PrintField(field, value)
    self.assertEqual('optional_float: 0.0\n', out.getvalue())
    out.close()

  def testPrintFieldValue(self, message_module):
    """PrintFieldValue and _Printer.PrintFieldValue emit only the value."""
    message = message_module.TestAllTypes()
    field = message.DESCRIPTOR.fields_by_name['optional_float']
    value = message.optional_float
    out = text_format.TextWriter(False)
    text_format.PrintFieldValue(field, value, out)
    self.assertEqual('0.0', out.getvalue())
    out.close()
    # Test Printer
    out = text_format.TextWriter(False)
    printer = text_format._Printer(out)
    printer.PrintFieldValue(field, value)
    self.assertEqual('0.0', out.getvalue())
    out.close()

  def testParseAllFields(self, message_module):
    """MessageToString output for a fully populated message parses back."""
    message = message_module.TestAllTypes()
    test_util.SetAllFields(message)
    ascii_text = text_format.MessageToString(message)

    parsed_message = message_module.TestAllTypes()
    text_format.Parse(ascii_text, parsed_message)
    self.assertEqual(message, parsed_message)
    if message_module is unittest_pb2:
      test_util.ExpectAllFieldsSet(self, message)

  def testParseExotic(self, message_module):
    """Extreme values, adjacent-literal concatenation and escapes parse."""
    message = message_module.TestAllTypes()
    text = ('repeated_int64: -9223372036854775808\n'
            'repeated_uint64: 18446744073709551615\n'
            'repeated_double: 123.456\n'
            'repeated_double: 1.23e+22\n'
            'repeated_double: 1.23e-18\n'
            'repeated_string: \n'
            '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n'
            'repeated_string: "foo" \'corge\' "grault"\n'
            'repeated_string: "\\303\\274\\352\\234\\237"\n'
            'repeated_string: "\\xc3\\xbc"\n'
            'repeated_string: "\xc3\xbc"\n')
    text_format.Parse(text, message)

    self.assertEqual(-9223372036854775808, message.repeated_int64[0])
    self.assertEqual(18446744073709551615, message.repeated_uint64[0])
    self.assertEqual(123.456, message.repeated_double[0])
    self.assertEqual(1.23e22, message.repeated_double[1])
    self.assertEqual(1.23e-18, message.repeated_double[2])
    # Five repeated_string entries were given; make sure none was dropped.
    # The value of the last entry is not pinned: the literal "\xc3\xbc" is two
    # bytes on Python 2 but two code points on Python 3.
    self.assertEqual(5, len(message.repeated_string))
    self.assertEqual('\000\001\a\b\f\n\r\t\v\\\'"', message.repeated_string[0])
    self.assertEqual('foocorgegrault', message.repeated_string[1])
    self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2])
    self.assertEqual(u'\u00fc', message.repeated_string[3])

  def testParseTrailingCommas(self, message_module):
    """Fields may be terminated by ',' or ';'."""
    message = message_module.TestAllTypes()
    text = ('repeated_int64: 100;\n'
            'repeated_int64: 200;\n'
            'repeated_int64: 300,\n'
            'repeated_string: "one",\n'
            'repeated_string: "two";\n')
    text_format.Parse(text, message)

    self.assertEqual(100, message.repeated_int64[0])
    self.assertEqual(200, message.repeated_int64[1])
    self.assertEqual(300, message.repeated_int64[2])
    self.assertEqual(u'one', message.repeated_string[0])
    self.assertEqual(u'two', message.repeated_string[1])

  def testParseRepeatedScalarShortFormat(self, message_module):
    """The `[a, b]` list syntax for repeated scalars mixes with single values."""
    message = message_module.TestAllTypes()
    text = ('repeated_int64: [100, 200];\n'
            'repeated_int64: 300,\n'
            'repeated_string: ["one", "two"];\n')
    text_format.Parse(text, message)

    self.assertEqual(100, message.repeated_int64[0])
    self.assertEqual(200, message.repeated_int64[1])
    self.assertEqual(300, message.repeated_int64[2])
    self.assertEqual(u'one', message.repeated_string[0])
    self.assertEqual(u'two', message.repeated_string[1])

  def testParseRepeatedMessageShortFormat(self, message_module):
    """The list syntax works for repeated messages, with or without ':'."""
    message = message_module.TestAllTypes()
    text = ('repeated_nested_message: [{bb: 100}, {bb: 200}],\n'
            'repeated_nested_message: {bb: 300}\n'
            'repeated_nested_message [{bb: 400}];\n')
    text_format.Parse(text, message)

    self.assertEqual(100, message.repeated_nested_message[0].bb)
    self.assertEqual(200, message.repeated_nested_message[1].bb)
    self.assertEqual(300, message.repeated_nested_message[2].bb)
    self.assertEqual(400, message.repeated_nested_message[3].bb)

  def testParseEmptyText(self, message_module):
    """Parsing empty text leaves the message equal to a fresh instance."""
    message = message_module.TestAllTypes()
    text = ''
    text_format.Parse(text, message)
    self.assertEqual(message_module.TestAllTypes(), message)

  def testParseInvalidUtf8(self, message_module):
    """Escapes that decode to invalid UTF-8 in a string field are rejected."""
    message = message_module.TestAllTypes()
    text = 'repeated_string: "\\xc3\\xc3"'
    self.assertRaises(text_format.ParseError, text_format.Parse, text, message)

  def testParseSingleWord(self, message_module):
    """A bare unknown identifier reports 'has no field named'."""
    message = message_module.TestAllTypes()
    text = 'foo'
    six.assertRaisesRegex(self, text_format.ParseError, (
        r'1:1 : Message type "\w+.TestAllTypes" has no field named '
        r'"foo".'), text_format.Parse, text, message)

  def testParseUnknownField(self, message_module):
    """An unknown field name is reported with its line:column."""
    message = message_module.TestAllTypes()
    text = 'unknown_field: 8\n'
    six.assertRaisesRegex(self, text_format.ParseError, (
        r'1:1 : Message type "\w+.TestAllTypes" has no field named '
        r'"unknown_field".'), text_format.Parse, text, message)

  def testParseBadEnumValue(self, message_module):
    """Unknown enum names and numbers are both rejected."""
    message = message_module.TestAllTypes()
    text = 'optional_nested_enum: BARR'
    six.assertRaisesRegex(self, text_format.ParseError,
                          (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
                           r'has no value named BARR.'), text_format.Parse,
                          text, message)

    message = message_module.TestAllTypes()
    text = 'optional_nested_enum: 100'
    six.assertRaisesRegex(self, text_format.ParseError,
                          (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
                           r'has no value with number 100.'), text_format.Parse,
                          text, message)

  def testParseBadIntValue(self, message_module):
    """A non-numeric token for an integer field is rejected."""
    message = message_module.TestAllTypes()
    text = 'optional_int32: bork'
    six.assertRaisesRegex(self, text_format.ParseError,
                          ('1:17 : Couldn\'t parse integer: bork'),
                          text_format.Parse, text, message)

  def testParseStringFieldUnescape(self, message_module):
    """Hex escapes unescape correctly behind varying numbers of backslashes."""
    message = message_module.TestAllTypes()
    text = r'''repeated_string: "\xf\x62"
               repeated_string: "\\xf\\x62"
               repeated_string: "\\\xf\\\x62"
               repeated_string: "\\\\xf\\\\x62"
               repeated_string: "\\\\\xf\\\\\x62"
               repeated_string: "\x5cx20"'''

    text_format.Parse(text, message)

    SLASH = '\\'
    self.assertEqual('\x0fb', message.repeated_string[0])
    self.assertEqual(SLASH + 'xf' + SLASH + 'x62', message.repeated_string[1])
    self.assertEqual(SLASH + '\x0f' + SLASH + 'b', message.repeated_string[2])
    self.assertEqual(SLASH + SLASH + 'xf' + SLASH + SLASH + 'x62',
                     message.repeated_string[3])
    self.assertEqual(SLASH + SLASH + '\x0f' + SLASH + SLASH + 'b',
                     message.repeated_string[4])
    self.assertEqual(SLASH + 'x20', message.repeated_string[5])

  def testMergeDuplicateScalars(self, message_module):
    """Merge keeps the last value given for a duplicated singular scalar."""
    message = message_module.TestAllTypes()
    text = ('optional_int32: 42 ' 'optional_int32: 67')
    r = text_format.Merge(text, message)
    self.assertIs(r, message)
    self.assertEqual(67, message.optional_int32)

  def testMergeDuplicateNestedMessageScalars(self, message_module):
    """Merge keeps the last value for a scalar in a duplicated sub-message."""
    message = message_module.TestAllTypes()
    text = ('optional_nested_message { bb: 1 } '
            'optional_nested_message { bb: 2 }')
    r = text_format.Merge(text, message)
    self.assertIs(r, message)
    self.assertEqual(2, message.optional_nested_message.bb)

  def testParseOneof(self, message_module):
    """A printed oneof member parses back as the active oneof case."""
    m = message_module.TestAllTypes()
    m.oneof_uint32 = 11
    m2 = message_module.TestAllTypes()
    text_format.Parse(text_format.MessageToString(m), m2)
    self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))

  def testParseMultipleOneof(self, message_module):
    """Two members of one oneof: ParseError for unittest_pb2, else last wins."""
    m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
    m2 = message_module.TestAllTypes()
    if message_module is unittest_pb2:
      # TestCase.assertRaisesRegexp is a deprecated alias that was removed in
      # Python 3.12; use six's portable wrapper as the rest of this file does.
      with six.assertRaisesRegex(self, text_format.ParseError,
                                 ' is specified along with field '):
        text_format.Parse(m_string, m2)
    else:
      text_format.Parse(m_string, m2)
      self.assertEqual('oneof_string', m2.WhichOneof('oneof_field'))
464
465
466# These are tests that aren't fundamentally specific to proto2, but are at
467# the moment because of differences between the proto2 and proto3 test schemas.
468# Ideally the schemas would be made more similar so these tests could pass.
class OnlyWorksWithProto2RightNowTests(TextFormatBase):
  """Golden-file, field-ordering and map tests that use the proto2 schemas."""

  def testPrintAllFieldsPointy(self):
    """pointy_brackets=True output matches the pointy golden file."""
    message = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(message)
    self.CompareToGoldenFile(
        self.RemoveRedundantZeros(text_format.MessageToString(
            message, pointy_brackets=True)),
        'text_format_unittest_data_pointy_oneof.txt')

  def testParseGolden(self):
    """Parsing the golden text yields the same message as SetAllFields."""
    # ReadGolden keeps each line's trailing newline, so join with '' (as
    # CompareToGoldenFile does) rather than '\n', which doubled every break.
    golden_text = ''.join(self.ReadGolden(
        'text_format_unittest_data_oneof_implemented.txt'))
    parsed_message = unittest_pb2.TestAllTypes()
    r = text_format.Parse(golden_text, parsed_message)
    self.assertIs(r, parsed_message)

    message = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(message)
    self.assertEqual(message, parsed_message)

  def testPrintAllFields(self):
    """Default MessageToString output matches the golden file."""
    message = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(message)
    self.CompareToGoldenFile(
        self.RemoveRedundantZeros(text_format.MessageToString(message)),
        'text_format_unittest_data_oneof_implemented.txt')

  def testPrintInIndexOrder(self):
    """use_index_order=True prints in declaration order, not number order."""
    message = unittest_pb2.TestFieldOrderings()
    message.my_string = '115'
    message.my_int = 101
    message.my_float = 111
    message.optional_nested_message.oo = 0
    message.optional_nested_message.bb = 1
    self.CompareToGoldenText(
        self.RemoveRedundantZeros(text_format.MessageToString(
            message, use_index_order=True)),
        'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n'
        'optional_nested_message {\n  oo: 0\n  bb: 1\n}\n')
    self.CompareToGoldenText(
        self.RemoveRedundantZeros(text_format.MessageToString(message)),
        'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n'
        'optional_nested_message {\n  bb: 1\n  oo: 0\n}\n')

  def testMergeLinesGolden(self):
    """MergeLines accepts the golden file's lines directly."""
    opened = self.ReadGolden('text_format_unittest_data_oneof_implemented.txt')
    parsed_message = unittest_pb2.TestAllTypes()
    r = text_format.MergeLines(opened, parsed_message)
    self.assertIs(r, parsed_message)

    message = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(message)
    self.assertEqual(message, parsed_message)

  def testParseLinesGolden(self):
    """ParseLines accepts the golden file's lines directly."""
    opened = self.ReadGolden('text_format_unittest_data_oneof_implemented.txt')
    parsed_message = unittest_pb2.TestAllTypes()
    r = text_format.ParseLines(opened, parsed_message)
    self.assertIs(r, parsed_message)

    message = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(message)
    self.assertEqual(message, parsed_message)

  def testPrintMap(self):
    """Map fields print as repeated key/value entries."""
    message = map_unittest_pb2.TestMap()

    message.map_int32_int32[-123] = -456
    message.map_int64_int64[-2**33] = -2**34
    message.map_uint32_uint32[123] = 456
    message.map_uint64_uint64[2**33] = 2**34
    message.map_string_string['abc'] = '123'
    message.map_int32_foreign_message[111].c = 5

    # Maps are serialized to text format using their underlying repeated
    # representation.
    self.CompareToGoldenText(
        text_format.MessageToString(message), 'map_int32_int32 {\n'
        '  key: -123\n'
        '  value: -456\n'
        '}\n'
        'map_int64_int64 {\n'
        '  key: -8589934592\n'
        '  value: -17179869184\n'
        '}\n'
        'map_uint32_uint32 {\n'
        '  key: 123\n'
        '  value: 456\n'
        '}\n'
        'map_uint64_uint64 {\n'
        '  key: 8589934592\n'
        '  value: 17179869184\n'
        '}\n'
        'map_string_string {\n'
        '  key: "abc"\n'
        '  value: "123"\n'
        '}\n'
        'map_int32_foreign_message {\n'
        '  key: 111\n'
        '  value {\n'
        '    c: 5\n'
        '  }\n'
        '}\n')

  def testMapOrderEnforcement(self):
    """Map entries print sorted by key regardless of insertion order."""
    message = map_unittest_pb2.TestMap()
    for letter in string.ascii_uppercase[13:26]:
      message.map_string_string[letter] = 'dummy'
    for letter in reversed(string.ascii_uppercase[0:13]):
      message.map_string_string[letter] = 'dummy'
    golden = ''.join(('map_string_string {\n  key: "%c"\n  value: "dummy"\n}\n'
                      % (letter,) for letter in string.ascii_uppercase))
    self.CompareToGoldenText(text_format.MessageToString(message), golden)

  def testMapOrderSemantics(self):
    """Map output ordering agrees with the shared map_test_data golden file."""
    golden_lines = self.ReadGolden('map_test_data.txt')
    # The C++ implementation emits defaulted-value fields, while the Python
    # implementation does not.  Adjusting for this is awkward, but it is
    # valuable to test against a common golden file.
    line_blacklist = ('  key: 0\n', '  value: 0\n', '  key: false\n',
                      '  value: false\n')
    golden_lines = [line for line in golden_lines if line not in line_blacklist]

    message = map_unittest_pb2.TestMap()
    text_format.ParseLines(golden_lines, message)
    candidate = text_format.MessageToString(message)
    # The Python implementation emits "1.0" for the double value that the C++
    # implementation emits as "1".
    candidate = candidate.replace('1.0', '1', 2)
    self.assertMultiLineEqual(candidate, ''.join(golden_lines))
600
601
602# Tests of proto2-only features (MessageSet, extensions, etc.).
603class Proto2Tests(TextFormatBase):
604
605  def testPrintMessageSet(self):
606    message = unittest_mset_pb2.TestMessageSetContainer()
607    ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
608    ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
609    message.message_set.Extensions[ext1].i = 23
610    message.message_set.Extensions[ext2].str = 'foo'
611    self.CompareToGoldenText(
612        text_format.MessageToString(message), 'message_set {\n'
613        '  [protobuf_unittest.TestMessageSetExtension1] {\n'
614        '    i: 23\n'
615        '  }\n'
616        '  [protobuf_unittest.TestMessageSetExtension2] {\n'
617        '    str: \"foo\"\n'
618        '  }\n'
619        '}\n')
620
621    message = message_set_extensions_pb2.TestMessageSet()
622    ext = message_set_extensions_pb2.message_set_extension3
623    message.Extensions[ext].text = 'bar'
624    self.CompareToGoldenText(
625        text_format.MessageToString(message),
626        '[google.protobuf.internal.TestMessageSetExtension3] {\n'
627        '  text: \"bar\"\n'
628        '}\n')
629
630  def testPrintMessageSetByFieldNumber(self):
631    out = text_format.TextWriter(False)
632    message = unittest_mset_pb2.TestMessageSetContainer()
633    ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
634    ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
635    message.message_set.Extensions[ext1].i = 23
636    message.message_set.Extensions[ext2].str = 'foo'
637    text_format.PrintMessage(message, out, use_field_number=True)
638    self.CompareToGoldenText(out.getvalue(), '1 {\n'
639                             '  1545008 {\n'
640                             '    15: 23\n'
641                             '  }\n'
642                             '  1547769 {\n'
643                             '    25: \"foo\"\n'
644                             '  }\n'
645                             '}\n')
646    out.close()
647
648  def testPrintMessageSetAsOneLine(self):
649    message = unittest_mset_pb2.TestMessageSetContainer()
650    ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
651    ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
652    message.message_set.Extensions[ext1].i = 23
653    message.message_set.Extensions[ext2].str = 'foo'
654    self.CompareToGoldenText(
655        text_format.MessageToString(message, as_one_line=True),
656        'message_set {'
657        ' [protobuf_unittest.TestMessageSetExtension1] {'
658        ' i: 23'
659        ' }'
660        ' [protobuf_unittest.TestMessageSetExtension2] {'
661        ' str: \"foo\"'
662        ' }'
663        ' }')
664
665  def testParseMessageSet(self):
666    message = unittest_pb2.TestAllTypes()
667    text = ('repeated_uint64: 1\n' 'repeated_uint64: 2\n')
668    text_format.Parse(text, message)
669    self.assertEqual(1, message.repeated_uint64[0])
670    self.assertEqual(2, message.repeated_uint64[1])
671
672    message = unittest_mset_pb2.TestMessageSetContainer()
673    text = ('message_set {\n'
674            '  [protobuf_unittest.TestMessageSetExtension1] {\n'
675            '    i: 23\n'
676            '  }\n'
677            '  [protobuf_unittest.TestMessageSetExtension2] {\n'
678            '    str: \"foo\"\n'
679            '  }\n'
680            '}\n')
681    text_format.Parse(text, message)
682    ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
683    ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
684    self.assertEqual(23, message.message_set.Extensions[ext1].i)
685    self.assertEqual('foo', message.message_set.Extensions[ext2].str)
686
687  def testParseMessageByFieldNumber(self):
688    message = unittest_pb2.TestAllTypes()
689    text = ('34: 1\n' 'repeated_uint64: 2\n')
690    text_format.Parse(text, message, allow_field_number=True)
691    self.assertEqual(1, message.repeated_uint64[0])
692    self.assertEqual(2, message.repeated_uint64[1])
693
694    message = unittest_mset_pb2.TestMessageSetContainer()
695    text = ('1 {\n'
696            '  1545008 {\n'
697            '    15: 23\n'
698            '  }\n'
699            '  1547769 {\n'
700            '    25: \"foo\"\n'
701            '  }\n'
702            '}\n')
703    text_format.Parse(text, message, allow_field_number=True)
704    ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
705    ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
706    self.assertEqual(23, message.message_set.Extensions[ext1].i)
707    self.assertEqual('foo', message.message_set.Extensions[ext2].str)
708
709    # Can't parse field number without set allow_field_number=True.
710    message = unittest_pb2.TestAllTypes()
711    text = '34:1\n'
712    six.assertRaisesRegex(self, text_format.ParseError, (
713        r'1:1 : Message type "\w+.TestAllTypes" has no field named '
714        r'"34".'), text_format.Parse, text, message)
715
716    # Can't parse if field number is not found.
717    text = '1234:1\n'
718    six.assertRaisesRegex(
719        self,
720        text_format.ParseError,
721        (r'1:1 : Message type "\w+.TestAllTypes" has no field named '
722         r'"1234".'),
723        text_format.Parse,
724        text,
725        message,
726        allow_field_number=True)
727
728  def testPrintAllExtensions(self):
729    message = unittest_pb2.TestAllExtensions()
730    test_util.SetAllExtensions(message)
731    self.CompareToGoldenFile(
732        self.RemoveRedundantZeros(text_format.MessageToString(message)),
733        'text_format_unittest_extensions_data.txt')
734
735  def testPrintAllExtensionsPointy(self):
736    message = unittest_pb2.TestAllExtensions()
737    test_util.SetAllExtensions(message)
738    self.CompareToGoldenFile(
739        self.RemoveRedundantZeros(text_format.MessageToString(
740            message, pointy_brackets=True)),
741        'text_format_unittest_extensions_data_pointy.txt')
742
743  def testParseGoldenExtensions(self):
744    golden_text = '\n'.join(self.ReadGolden(
745        'text_format_unittest_extensions_data.txt'))
746    parsed_message = unittest_pb2.TestAllExtensions()
747    text_format.Parse(golden_text, parsed_message)
748
749    message = unittest_pb2.TestAllExtensions()
750    test_util.SetAllExtensions(message)
751    self.assertEqual(message, parsed_message)
752
753  def testParseAllExtensions(self):
754    message = unittest_pb2.TestAllExtensions()
755    test_util.SetAllExtensions(message)
756    ascii_text = text_format.MessageToString(message)
757
758    parsed_message = unittest_pb2.TestAllExtensions()
759    text_format.Parse(ascii_text, parsed_message)
760    self.assertEqual(message, parsed_message)
761
  def testParseAllowedUnknownExtension(self):
    """Checks Parse with allow_unknown_extension=True.

    Unknown extensions are skipped, but malformed content inside an unknown
    extension is still reported as a ParseError. Unknown regular
    (non-extension) fields are still rejected, and known extensions still
    parse normally.
    """
    # Skip over unknown extension correctly.
    message = unittest_mset_pb2.TestMessageSetContainer()
    text = ('message_set {\n'
            '  [unknown_extension] {\n'
            '    i: 23\n'
            '    bin: "\xe0"'  # No '\n': this line runs into the next one.
            '    [nested_unknown_ext]: {\n'
            '      i: 23\n'
            '      test: "test_string"\n'
            '      floaty_float: -0.315\n'
            '      num: -inf\n'
            '      multiline_str: "abc"\n'
            '          "def"\n'
            '          "xyz."\n'
            '      [nested_unknown_ext]: <\n'
            '        i: 23\n'
            '        i: 24\n'
            '        pointfloat: .3\n'
            '        test: "test_string"\n'
            '        floaty_float: -0.315\n'
            '        num: -inf\n'
            '        long_string: "test" "test2" \n'
            '      >\n'
            '    }\n'
            '  }\n'
            '  [unknown_extension]: 5\n'
            '}\n')
    text_format.Parse(text, message, allow_unknown_extension=True)
    # The unknown extensions are dropped, leaving an empty but present
    # message_set.
    golden = 'message_set {\n}\n'
    self.CompareToGoldenText(text_format.MessageToString(message), golden)

    # Catch parse errors in unknown extension.
    message = unittest_mset_pb2.TestMessageSetContainer()
    malformed = ('message_set {\n'
                 '  [unknown_extension] {\n'
                 '    i:\n'  # Missing value.
                 '  }\n'
                 '}\n')
    six.assertRaisesRegex(self,
                          text_format.ParseError,
                          'Invalid field value: }',
                          text_format.Parse,
                          malformed,
                          message,
                          allow_unknown_extension=True)

    message = unittest_mset_pb2.TestMessageSetContainer()
    malformed = ('message_set {\n'
                 '  [unknown_extension] {\n'
                 '    str: "malformed string\n'  # Missing closing quote.
                 '  }\n'
                 '}\n')
    six.assertRaisesRegex(self,
                          text_format.ParseError,
                          'Invalid field value: "',
                          text_format.Parse,
                          malformed,
                          message,
                          allow_unknown_extension=True)

    # An unterminated string that spans several lines is also rejected.
    message = unittest_mset_pb2.TestMessageSetContainer()
    malformed = ('message_set {\n'
                 '  [unknown_extension] {\n'
                 '    str: "malformed\n multiline\n string\n'
                 '  }\n'
                 '}\n')
    six.assertRaisesRegex(self,
                          text_format.ParseError,
                          'Invalid field value: "',
                          text_format.Parse,
                          malformed,
                          message,
                          allow_unknown_extension=True)

    message = unittest_mset_pb2.TestMessageSetContainer()
    malformed = ('message_set {\n'
                 '  [malformed_extension] <\n'
                 '    i: -5\n'
                 '  \n'  # Missing '>' here.
                 '}\n')
    six.assertRaisesRegex(self,
                          text_format.ParseError,
                          '5:1 : Expected ">".',
                          text_format.Parse,
                          malformed,
                          message,
                          allow_unknown_extension=True)

    # allow_unknown_extension=True does not make unknown regular
    # (non-extension) fields acceptable.
    message = unittest_mset_pb2.TestMessageSetContainer()
    malformed = ('message_set {\n'
                 '  unknown_field: true\n'
                 '  \n'  # Whitespace only; the error is unknown_field above.
                 '}\n')
    six.assertRaisesRegex(self,
                          text_format.ParseError,
                          ('2:3 : Message type '
                           '"proto2_wireformat_unittest.TestMessageSet" has no'
                           ' field named "unknown_field".'),
                          text_format.Parse,
                          malformed,
                          message,
                          allow_unknown_extension=True)

    # Parse known extension correctly.
    message = unittest_mset_pb2.TestMessageSetContainer()
    text = ('message_set {\n'
            '  [protobuf_unittest.TestMessageSetExtension1] {\n'
            '    i: 23\n'
            '  }\n'
            '  [protobuf_unittest.TestMessageSetExtension2] {\n'
            '    str: \"foo\"\n'
            '  }\n'
            '}\n')
    text_format.Parse(text, message, allow_unknown_extension=True)
    ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
    ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
    self.assertEqual(23, message.message_set.Extensions[ext1].i)
    self.assertEqual('foo', message.message_set.Extensions[ext2].str)
882
883  def testParseBadExtension(self):
884    message = unittest_pb2.TestAllExtensions()
885    text = '[unknown_extension]: 8\n'
886    six.assertRaisesRegex(self, text_format.ParseError,
887                          '1:2 : Extension "unknown_extension" not registered.',
888                          text_format.Parse, text, message)
889    message = unittest_pb2.TestAllTypes()
890    six.assertRaisesRegex(self, text_format.ParseError, (
891        '1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
892        'extensions.'), text_format.Parse, text, message)
893
894  def testMergeDuplicateExtensionScalars(self):
895    message = unittest_pb2.TestAllExtensions()
896    text = ('[protobuf_unittest.optional_int32_extension]: 42 '
897            '[protobuf_unittest.optional_int32_extension]: 67')
898    text_format.Merge(text, message)
899    self.assertEqual(67,
900                     message.Extensions[unittest_pb2.optional_int32_extension])
901
902  def testParseDuplicateExtensionScalars(self):
903    message = unittest_pb2.TestAllExtensions()
904    text = ('[protobuf_unittest.optional_int32_extension]: 42 '
905            '[protobuf_unittest.optional_int32_extension]: 67')
906    six.assertRaisesRegex(self, text_format.ParseError, (
907        '1:96 : Message type "protobuf_unittest.TestAllExtensions" '
908        'should not have multiple '
909        '"protobuf_unittest.optional_int32_extension" extensions.'),
910                          text_format.Parse, text, message)
911
912  def testParseDuplicateNestedMessageScalars(self):
913    message = unittest_pb2.TestAllTypes()
914    text = ('optional_nested_message { bb: 1 } '
915            'optional_nested_message { bb: 2 }')
916    six.assertRaisesRegex(self, text_format.ParseError, (
917        '1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" '
918        'should not have multiple "bb" fields.'), text_format.Parse, text,
919                          message)
920
921  def testParseDuplicateScalars(self):
922    message = unittest_pb2.TestAllTypes()
923    text = ('optional_int32: 42 ' 'optional_int32: 67')
924    six.assertRaisesRegex(self, text_format.ParseError, (
925        '1:36 : Message type "protobuf_unittest.TestAllTypes" should not '
926        'have multiple "optional_int32" fields.'), text_format.Parse, text,
927                          message)
928
929  def testParseGroupNotClosed(self):
930    message = unittest_pb2.TestAllTypes()
931    text = 'RepeatedGroup: <'
932    six.assertRaisesRegex(self, text_format.ParseError, '1:16 : Expected ">".',
933                          text_format.Parse, text, message)
934    text = 'RepeatedGroup: {'
935    six.assertRaisesRegex(self, text_format.ParseError, '1:16 : Expected "}".',
936                          text_format.Parse, text, message)
937
938  def testParseEmptyGroup(self):
939    message = unittest_pb2.TestAllTypes()
940    text = 'OptionalGroup: {}'
941    text_format.Parse(text, message)
942    self.assertTrue(message.HasField('optionalgroup'))
943
944    message.Clear()
945
946    message = unittest_pb2.TestAllTypes()
947    text = 'OptionalGroup: <>'
948    text_format.Parse(text, message)
949    self.assertTrue(message.HasField('optionalgroup'))
950
951  # Maps aren't really proto2-only, but our test schema only has maps for
952  # proto2.
953  def testParseMap(self):
954    text = ('map_int32_int32 {\n'
955            '  key: -123\n'
956            '  value: -456\n'
957            '}\n'
958            'map_int64_int64 {\n'
959            '  key: -8589934592\n'
960            '  value: -17179869184\n'
961            '}\n'
962            'map_uint32_uint32 {\n'
963            '  key: 123\n'
964            '  value: 456\n'
965            '}\n'
966            'map_uint64_uint64 {\n'
967            '  key: 8589934592\n'
968            '  value: 17179869184\n'
969            '}\n'
970            'map_string_string {\n'
971            '  key: "abc"\n'
972            '  value: "123"\n'
973            '}\n'
974            'map_int32_foreign_message {\n'
975            '  key: 111\n'
976            '  value {\n'
977            '    c: 5\n'
978            '  }\n'
979            '}\n')
980    message = map_unittest_pb2.TestMap()
981    text_format.Parse(text, message)
982
983    self.assertEqual(-456, message.map_int32_int32[-123])
984    self.assertEqual(-2**34, message.map_int64_int64[-2**33])
985    self.assertEqual(456, message.map_uint32_uint32[123])
986    self.assertEqual(2**34, message.map_uint64_uint64[2**33])
987    self.assertEqual('123', message.map_string_string['abc'])
988    self.assertEqual(5, message.map_int32_foreign_message[111].c)
989
990
class Proto3Tests(unittest.TestCase):
  """Text-format printing and parsing of google.protobuf.Any fields."""

  # Text form of a TestAny whose any_value packs OneString(data='string'),
  # printed with the Any expanded into its packed message.
  _EXPANDED_ANY_TEXT = (
      'any_value {\n'
      '  [type.googleapis.com/protobuf_unittest.OneString] {\n'
      '    data: "string"\n'
      '  }\n'
      '}\n')

  # The same message, written as raw type_url / serialized value fields.
  _UNEXPANDED_ANY_TEXT = (
      'any_value {\n'
      '  type_url: "type.googleapis.com/protobuf_unittest.OneString"\n'
      '  value: "\\n\\006string"\n'
      '}\n')

  # Text form of a TestAny whose repeated_any_value packs OneString messages
  # with data 'string0' and 'string1', expanded.
  _EXPANDED_REPEATED_ANY_TEXT = (
      'repeated_any_value {\n'
      '  [type.googleapis.com/protobuf_unittest.OneString] {\n'
      '    data: "string0"\n'
      '  }\n'
      '}\n'
      'repeated_any_value {\n'
      '  [type.googleapis.com/protobuf_unittest.OneString] {\n'
      '    data: "string1"\n'
      '  }\n'
      '}\n')

  def _MakeTestAny(self):
    """Returns a TestAny whose any_value packs OneString(data='string')."""
    packed_message = unittest_pb2.OneString()
    packed_message.data = 'string'
    message = any_test_pb2.TestAny()
    message.any_value.Pack(packed_message)
    return message

  def _UnpackOneString(self, any_message):
    """Unpacks any_message into a new OneString and returns that OneString."""
    packed_message = unittest_pb2.OneString()
    any_message.Unpack(packed_message)
    return packed_message

  def testPrintMessageExpandAny(self):
    self.assertEqual(
        text_format.MessageToString(self._MakeTestAny(),
                                    descriptor_pool=descriptor_pool.Default()),
        self._EXPANDED_ANY_TEXT)

  def testPrintMessageExpandAnyRepeated(self):
    packed_message = unittest_pb2.OneString()
    message = any_test_pb2.TestAny()
    packed_message.data = 'string0'
    message.repeated_any_value.add().Pack(packed_message)
    packed_message.data = 'string1'
    message.repeated_any_value.add().Pack(packed_message)
    self.assertEqual(
        text_format.MessageToString(message,
                                    descriptor_pool=descriptor_pool.Default()),
        self._EXPANDED_REPEATED_ANY_TEXT)

  def testPrintMessageExpandAnyNoDescriptorPool(self):
    # With no pool, the Any is printed as its raw type_url / value fields.
    self.assertEqual(
        text_format.MessageToString(self._MakeTestAny(), descriptor_pool=None),
        self._UNEXPANDED_ANY_TEXT)

  def testPrintMessageExpandAnyDescriptorPoolMissingType(self):
    # A pool that lacks the packed type also yields the raw form.
    empty_pool = descriptor_pool.DescriptorPool()
    self.assertEqual(
        text_format.MessageToString(self._MakeTestAny(),
                                    descriptor_pool=empty_pool),
        self._UNEXPANDED_ANY_TEXT)

  def testPrintMessageExpandAnyPointyBrackets(self):
    self.assertEqual(
        text_format.MessageToString(self._MakeTestAny(),
                                    pointy_brackets=True,
                                    descriptor_pool=descriptor_pool.Default()),
        'any_value <\n'
        '  [type.googleapis.com/protobuf_unittest.OneString] <\n'
        '    data: "string"\n'
        '  >\n'
        '>\n')

  def testPrintMessageExpandAnyAsOneLine(self):
    self.assertEqual(
        text_format.MessageToString(self._MakeTestAny(),
                                    as_one_line=True,
                                    descriptor_pool=descriptor_pool.Default()),
        'any_value {'
        ' [type.googleapis.com/protobuf_unittest.OneString]'
        ' { data: "string" } '
        '}')

  def testPrintMessageExpandAnyAsOneLinePointyBrackets(self):
    self.assertEqual(
        text_format.MessageToString(self._MakeTestAny(),
                                    as_one_line=True,
                                    pointy_brackets=True,
                                    descriptor_pool=descriptor_pool.Default()),
        'any_value <'
        ' [type.googleapis.com/protobuf_unittest.OneString]'
        ' < data: "string" > '
        '>')

  def testMergeExpandedAny(self):
    message = any_test_pb2.TestAny()
    text_format.Merge(self._EXPANDED_ANY_TEXT, message,
                      descriptor_pool=descriptor_pool.Default())
    self.assertEqual('string', self._UnpackOneString(message.any_value).data)

  def testMergeExpandedAnyRepeated(self):
    message = any_test_pb2.TestAny()
    text_format.Merge(self._EXPANDED_REPEATED_ANY_TEXT, message,
                      descriptor_pool=descriptor_pool.Default())
    self.assertEqual(
        'string0', self._UnpackOneString(message.repeated_any_value[0]).data)
    self.assertEqual(
        'string1', self._UnpackOneString(message.repeated_any_value[1]).data)

  def testMergeExpandedAnyPointyBrackets(self):
    message = any_test_pb2.TestAny()
    text = ('any_value {\n'
            '  [type.googleapis.com/protobuf_unittest.OneString] <\n'
            '    data: "string"\n'
            '  >\n'
            '}\n')
    text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default())
    self.assertEqual('string', self._UnpackOneString(message.any_value).data)

  def testMergeExpandedAnyNoDescriptorPool(self):
    message = any_test_pb2.TestAny()
    with self.assertRaises(text_format.ParseError) as e:
      text_format.Merge(self._EXPANDED_ANY_TEXT, message, descriptor_pool=None)
    self.assertEqual(str(e.exception),
                     'Descriptor pool required to parse expanded Any field')

  def testMergeExpandedAnyDescriptorPoolMissingType(self):
    message = any_test_pb2.TestAny()
    # Build the pool outside assertRaises so only Merge can raise the error.
    empty_pool = descriptor_pool.DescriptorPool()
    with self.assertRaises(text_format.ParseError) as e:
      text_format.Merge(self._EXPANDED_ANY_TEXT, message,
                        descriptor_pool=empty_pool)
    self.assertEqual(
        str(e.exception),
        'Type protobuf_unittest.OneString not found in descriptor pool')

  def testMergeUnexpandedAny(self):
    message = any_test_pb2.TestAny()
    text_format.Merge(self._UNEXPANDED_ANY_TEXT, message)
    self.assertEqual('string', self._UnpackOneString(message.any_value).data)
1176
1177
class TokenizerTest(unittest.TestCase):
  """Tests for text_format.Tokenizer and the typed integer consumers."""

  def testSimpleTokenCases(self):
    text = ('identifier1:"string1"\n     \n\n'
            'identifier2 : \n \n123  \n  identifier3 :\'string\'\n'
            'identifiER_4 : 1.1e+2 ID5:-0.23 ID6:\'aaaa\\\'bbbb\'\n'
            'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n'
            'ID9: 22 ID10: -111111111111111111 ID11: -22\n'
            'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f '
            'false_bool:  0 true_BOOL:t \n true_bool1:  1 false_BOOL1:f ')
    tokenizer = text_format.Tokenizer(text.splitlines())
    # Each entry is either a literal token (str) that must be the current
    # token, or a (consume_method, expected_value) pair.
    methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), ':',
               (tokenizer.ConsumeString, 'string1'),
               (tokenizer.ConsumeIdentifier, 'identifier2'), ':',
               (tokenizer.ConsumeInteger, 123),
               (tokenizer.ConsumeIdentifier, 'identifier3'), ':',
               (tokenizer.ConsumeString, 'string'),
               (tokenizer.ConsumeIdentifier, 'identifiER_4'), ':',
               (tokenizer.ConsumeFloat, 1.1e+2),
               (tokenizer.ConsumeIdentifier, 'ID5'), ':',
               (tokenizer.ConsumeFloat, -0.23),
               (tokenizer.ConsumeIdentifier, 'ID6'), ':',
               (tokenizer.ConsumeString, 'aaaa\'bbbb'),
               (tokenizer.ConsumeIdentifier, 'ID7'), ':',
               (tokenizer.ConsumeString, 'aa\"bb'),
               (tokenizer.ConsumeIdentifier, 'ID8'), ':', '{',
               (tokenizer.ConsumeIdentifier, 'A'), ':',
               (tokenizer.ConsumeFloat, float('inf')),
               (tokenizer.ConsumeIdentifier, 'B'), ':',
               (tokenizer.ConsumeFloat, -float('inf')),
               (tokenizer.ConsumeIdentifier, 'C'), ':',
               (tokenizer.ConsumeBool, True),
               (tokenizer.ConsumeIdentifier, 'D'), ':',
               (tokenizer.ConsumeBool, False), '}',
               (tokenizer.ConsumeIdentifier, 'ID9'), ':',
               (tokenizer.ConsumeInteger, 22),
               (tokenizer.ConsumeIdentifier, 'ID10'), ':',
               (tokenizer.ConsumeInteger, -111111111111111111),
               (tokenizer.ConsumeIdentifier, 'ID11'), ':',
               (tokenizer.ConsumeInteger, -22),
               (tokenizer.ConsumeIdentifier, 'ID12'), ':',
               (tokenizer.ConsumeInteger, 2222222222222222222),
               (tokenizer.ConsumeIdentifier, 'ID13'), ':',
               (tokenizer.ConsumeFloat, 1.23456),
               (tokenizer.ConsumeIdentifier, 'ID14'), ':',
               (tokenizer.ConsumeFloat, 1.2e+2),
               (tokenizer.ConsumeIdentifier, 'false_bool'), ':',
               (tokenizer.ConsumeBool, False),
               (tokenizer.ConsumeIdentifier, 'true_BOOL'), ':',
               (tokenizer.ConsumeBool, True),
               (tokenizer.ConsumeIdentifier, 'true_bool1'), ':',
               (tokenizer.ConsumeBool, True),
               (tokenizer.ConsumeIdentifier, 'false_BOOL1'), ':',
               (tokenizer.ConsumeBool, False)]

    # Drive the loop from the expectations rather than from the tokenizer, so
    # a token stream that ends early cannot silently skip expectations.
    for m in methods:
      self.assertFalse(tokenizer.AtEnd())
      if isinstance(m, str):
        token = tokenizer.token
        self.assertEqual(token, m)
        tokenizer.NextToken()
      else:
        self.assertEqual(m[1], m[0]())
    # ...and no unexpected tokens may remain.
    self.assertTrue(tokenizer.AtEnd())

  def testConsumeAbstractIntegers(self):
    # The generic ConsumeInteger applies no int32/uint32/int64 range limits:
    # it accepts -1 and values just past the uint32 and int64 maxima, and
    # reads '-0' as 0.
    int64_max = (1 << 63) - 1
    uint32_max = (1 << 32) - 1
    text = '-1 %d %d' % (uint32_max + 1, int64_max + 1)
    tokenizer = text_format.Tokenizer(text.splitlines())
    self.assertEqual(-1, tokenizer.ConsumeInteger())

    self.assertEqual(uint32_max + 1, tokenizer.ConsumeInteger())

    self.assertEqual(int64_max + 1, tokenizer.ConsumeInteger())
    self.assertTrue(tokenizer.AtEnd())

    text = '-0 0'
    tokenizer = text_format.Tokenizer(text.splitlines())
    self.assertEqual(0, tokenizer.ConsumeInteger())
    self.assertEqual(0, tokenizer.ConsumeInteger())
    self.assertTrue(tokenizer.AtEnd())

  def testConsumeIntegers(self):
    # This test only tests the failures in the integer parsing methods as well
    # as the '0' special cases. A consumer that rejects a value must leave the
    # token in place, so a wider consumer can still read it afterwards.
    int64_max = (1 << 63) - 1
    uint32_max = (1 << 32) - 1
    text = '-1 %d %d' % (uint32_max + 1, int64_max + 1)
    tokenizer = text_format.Tokenizer(text.splitlines())
    self.assertRaises(text_format.ParseError,
                      text_format._ConsumeUint32, tokenizer)
    self.assertRaises(text_format.ParseError,
                      text_format._ConsumeUint64, tokenizer)
    self.assertEqual(-1, text_format._ConsumeInt32(tokenizer))

    self.assertRaises(text_format.ParseError,
                      text_format._ConsumeUint32, tokenizer)
    self.assertRaises(text_format.ParseError,
                      text_format._ConsumeInt32, tokenizer)
    self.assertEqual(uint32_max + 1, text_format._ConsumeInt64(tokenizer))

    self.assertRaises(text_format.ParseError,
                      text_format._ConsumeInt64, tokenizer)
    self.assertEqual(int64_max + 1, text_format._ConsumeUint64(tokenizer))
    self.assertTrue(tokenizer.AtEnd())

    # '-0' is accepted by the unsigned consumers.
    text = '-0 -0 0 0'
    tokenizer = text_format.Tokenizer(text.splitlines())
    self.assertEqual(0, text_format._ConsumeUint32(tokenizer))
    self.assertEqual(0, text_format._ConsumeUint64(tokenizer))
    self.assertEqual(0, text_format._ConsumeUint32(tokenizer))
    self.assertEqual(0, text_format._ConsumeUint64(tokenizer))
    self.assertTrue(tokenizer.AtEnd())

  def testConsumeByteString(self):
    # Mismatched quotes.
    text = '"string1\''
    tokenizer = text_format.Tokenizer(text.splitlines())
    self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)

    # Missing opening quote.
    text = 'string1"'
    tokenizer = text_format.Tokenizer(text.splitlines())
    self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)

    # Invalid hex escape.
    text = '\n"\\xt"'
    tokenizer = text_format.Tokenizer(text.splitlines())
    self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)

    # Escaped closing quote leaves the string unterminated.
    text = '\n"\\"'
    tokenizer = text_format.Tokenizer(text.splitlines())
    self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)

    # Hex escape with no digits.
    text = '\n"\\x"'
    tokenizer = text_format.Tokenizer(text.splitlines())
    self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)

  def testConsumeBool(self):
    text = 'not-a-bool'
    tokenizer = text_format.Tokenizer(text.splitlines())
    self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool)

  def testSkipComment(self):
    # By default comments are skipped, so a comment-only input is at its end.
    tokenizer = text_format.Tokenizer('# some comment'.splitlines())
    self.assertTrue(tokenizer.AtEnd())
    self.assertRaises(text_format.ParseError, tokenizer.ConsumeComment)

  def testConsumeComment(self):
    tokenizer = text_format.Tokenizer('# some comment'.splitlines(),
                                      skip_comments=False)
    self.assertFalse(tokenizer.AtEnd())
    self.assertEqual('# some comment', tokenizer.ConsumeComment())
    self.assertTrue(tokenizer.AtEnd())

  def testConsumeTwoComments(self):
    text = '# some comment\n# another comment'
    tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False)
    self.assertEqual('# some comment', tokenizer.ConsumeComment())
    self.assertFalse(tokenizer.AtEnd())
    self.assertEqual('# another comment', tokenizer.ConsumeComment())
    self.assertTrue(tokenizer.AtEnd())

  def testConsumeTrailingComment(self):
    text = 'some_number: 4\n# some comment'
    tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False)
    # ConsumeComment fails while the current token is not a comment.
    self.assertRaises(text_format.ParseError, tokenizer.ConsumeComment)

    self.assertEqual('some_number', tokenizer.ConsumeIdentifier())
    self.assertEqual(tokenizer.token, ':')
    tokenizer.NextToken()
    self.assertRaises(text_format.ParseError, tokenizer.ConsumeComment)
    self.assertEqual(4, tokenizer.ConsumeInteger())
    self.assertFalse(tokenizer.AtEnd())

    self.assertEqual('# some comment', tokenizer.ConsumeComment())
    self.assertTrue(tokenizer.AtEnd())
1356
1357
if __name__ == '__main__':
  # Run the test cases defined in this module when executed as a script.
  unittest.main()
1360