1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
"""Contains routines for printing and parsing protocol messages in text format.
32
33Simple usage example:
34
35  # Create a proto object and serialize it to a text proto string.
36  message = my_proto_pb2.MyMessage(foo='bar')
37  text_proto = text_format.MessageToString(message)
38
39  # Parse a text proto string.
40  message = text_format.Parse(text_proto, my_proto_pb2.MyMessage())
41"""
42
__author__ = 'kenton@google.com (Kenton Varda)'

import io
import re

import six

from google.protobuf.internal import type_checkers
from google.protobuf import descriptor
from google.protobuf import text_encoding

# Python 3 has no separate `long` type; alias it so code that refers to `long`
# keeps working on both major versions.
if six.PY3:
  long = int

# Public API of this module.  Parse/ParseLines/MergeLines are documented public
# entry points alongside Merge, so they are exported as well.
__all__ = ['MessageToString', 'Parse', 'PrintMessage', 'PrintField',
           'PrintFieldValue', 'Merge', 'ParseLines', 'MergeLines']


# Value checkers for the four protobuf integer kinds (unsigned/signed, 32/64).
_INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(),
                     type_checkers.Int32ValueChecker(),
                     type_checkers.Uint64ValueChecker(),
                     type_checkers.Int64ValueChecker())
# Accepted spellings of infinity and NaN: "inf", "-infinity", "inff", "nanf"...
_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE)
_FLOAT_NAN = re.compile('nanf?', re.IGNORECASE)
# C++ types whose values honor the printer's float_format option.
_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT,
                          descriptor.FieldDescriptor.CPPTYPE_DOUBLE])
# Characters that may delimit a quoted string literal.
_QUOTES = frozenset(("'", '"'))
70
71
class Error(Exception):
  """Base class for the exception types defined by text_format."""


class ParseError(Error):
  """Raised when text cannot be parsed into a protocol message."""
78
79
class TextWriter(object):
  """In-memory sink that collects the printed text format.

  Under Python 2 the text is accumulated in a BytesIO, and unicode chunks are
  encoded to UTF-8 before being stored. Under Python 3 it is accumulated in a
  StringIO. The `as_utf8` argument is accepted but not consulted here.
  """

  def __init__(self, as_utf8):
    buffer_factory = io.BytesIO if six.PY2 else io.StringIO
    self._writer = buffer_factory()

  def write(self, val):
    """Appends `val` to the buffer and returns the buffer's write() result."""
    needs_encoding = six.PY2 and isinstance(val, six.text_type)
    chunk = val.encode('utf-8') if needs_encoding else val
    return self._writer.write(chunk)

  def close(self):
    """Closes the underlying buffer."""
    return self._writer.close()

  def getvalue(self):
    """Returns everything written so far."""
    return self._writer.getvalue()
98
99
def MessageToString(message, as_utf8=False, as_one_line=False,
                    pointy_brackets=False, use_index_order=False,
                    float_format=None, use_field_number=False):
  """Convert protobuf message to text format.

  Floating point values can be formatted compactly with 15 digits of
  precision (which is the most that IEEE 754 "double" can guarantee)
  using float_format='.15g'. To ensure that converting to text and back to a
  proto will result in an identical value, float_format='.17g' should be used.

  Args:
    message: The protocol buffers message.
    as_utf8: Produce text output in UTF8 format.
    as_one_line: Don't introduce newlines between fields.
    pointy_brackets: If True, use angle brackets instead of curly braces for
      nesting.
    use_index_order: If True, print fields of a proto message using the order
      defined in source code instead of the field number. By default, use the
      field number order.
    float_format: If set, use this to specify floating point number formatting
      (per the "Format Specification Mini-Language"); otherwise, str() is used.
    use_field_number: If True, print field numbers instead of names.

  Returns:
    A string of the text formatted protocol buffer message. In one-line mode
    trailing whitespace is stripped.
  """
  out = TextWriter(as_utf8)
  try:
    printer = _Printer(out, 0, as_utf8, as_one_line,
                       pointy_brackets, use_index_order, float_format,
                       use_field_number)
    printer.PrintMessage(message)
    result = out.getvalue()
  finally:
    # Release the buffer even when printing fails part-way through.
    out.close()
  if as_one_line:
    # Every field is followed by a space in one-line mode; drop the last one.
    return result.rstrip()
  return result
136
137
def _IsMapEntry(field):
  """Returns whether `field` holds map-entry messages (map_entry option set)."""
  if field.type != descriptor.FieldDescriptor.TYPE_MESSAGE:
    return False
  entry_type = field.message_type
  return entry_type.has_options and entry_type.GetOptions().map_entry
142
143
def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False,
                 pointy_brackets=False, use_index_order=False,
                 float_format=None, use_field_number=False):
  """Writes `message` in text format to `out`.

  Args:
    message: The protocol buffers message.
    out: Object with a write() method that receives the text.
    indent: Number of spaces to prefix each top-level field line with.
    as_utf8: Produce text output in UTF8 format.
    as_one_line: Don't introduce newlines between fields.
    pointy_brackets: If True, use angle brackets instead of curly braces for
      nesting.
    use_index_order: If True, print fields in the order defined in source
      code instead of field number order.
    float_format: If set, format spec used for floating point values;
      otherwise, str() is used.
    use_field_number: If True, print field numbers instead of names.
  """
  _Printer(out,
           indent=indent,
           as_utf8=as_utf8,
           as_one_line=as_one_line,
           pointy_brackets=pointy_brackets,
           use_index_order=use_index_order,
           float_format=float_format,
           use_field_number=use_field_number).PrintMessage(message)
151
152
def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False,
               pointy_brackets=False, use_index_order=False, float_format=None,
               use_field_number=False):
  """Print a single field name/value pair.

  Args:
    field: The descriptor of the field to be printed.
    value: The value of the field (a single element for repeated fields).
    out: Object with a write() method that receives the text.
    indent: Number of spaces to prefix the field line with.
    as_utf8: Produce text output in UTF8 format.
    as_one_line: Don't introduce newlines between fields.
    pointy_brackets: If True, use angle brackets instead of curly braces for
      nesting.
    use_index_order: If True, print nested message fields in the order
      defined in source code instead of field number order.
    float_format: If set, format spec used for floating point values;
      otherwise, str() is used.
    use_field_number: If True, print field numbers instead of names. Matches
      the option of the same name on PrintMessage/MessageToString.
  """
  printer = _Printer(out, indent, as_utf8, as_one_line,
                     pointy_brackets, use_index_order, float_format,
                     use_field_number)
  printer.PrintField(field, value)
159
160
def PrintFieldValue(field, value, out, indent=0, as_utf8=False,
                    as_one_line=False, pointy_brackets=False,
                    use_index_order=False,
                    float_format=None, use_field_number=False):
  """Print a single field value (not including name).

  Args:
    field: The descriptor of the field to be printed.
    value: The value of the field (a single element for repeated fields).
    out: Object with a write() method that receives the text.
    indent: Indent level used when the value is a nested message.
    as_utf8: Produce text output in UTF8 format.
    as_one_line: Don't introduce newlines between fields.
    pointy_brackets: If True, use angle brackets instead of curly braces for
      nesting.
    use_index_order: If True, print nested message fields in the order
      defined in source code instead of field number order.
    float_format: If set, format spec used for floating point values;
      otherwise, str() is used.
    use_field_number: If True, fields of nested messages are printed by
      number instead of by name. Matches the option on PrintMessage.
  """
  printer = _Printer(out, indent, as_utf8, as_one_line,
                     pointy_brackets, use_index_order, float_format,
                     use_field_number)
  printer.PrintFieldValue(field, value)
169
170
class _Printer(object):
  """Writes protocol messages to a sink in protobuf text format.

  The printer holds the formatting options supplied at construction and the
  current indentation level, which PrintFieldValue raises by two while it
  prints a nested message in multi-line mode.
  """

  def __init__(self, out, indent=0, as_utf8=False, as_one_line=False,
               pointy_brackets=False, use_index_order=False, float_format=None,
               use_field_number=False):
    """Stores the output sink and formatting options.

    Floating point values can be formatted compactly with 15 digits of
    precision (which is the most that IEEE 754 "double" can guarantee)
    using float_format='.15g'. To ensure that converting to text and back to a
    proto will result in an identical value, float_format='.17g' should be used.

    Args:
      out: Object with a write() method that receives the text output.
      indent: Number of spaces to prefix each field line with.
      as_utf8: Flag forwarded to text_encoding.CEscape for string fields
        (bytes fields are always escaped with it set to False).
      as_one_line: If True, separate fields with spaces instead of newlines.
      pointy_brackets: If True, delimit nested messages with '<' and '>'
        instead of '{' and '}'.
      use_index_order: If True, print fields in the order defined in source
        code (descriptor index) instead of field number order.
      float_format: Optional format spec (per the "Format Specification
        Mini-Language") for float and double values; str() is used when None.
      use_field_number: If True, print field numbers instead of names.
    """
    self.out = out
    self.indent = indent
    self.as_utf8 = as_utf8
    self.as_one_line = as_one_line
    self.pointy_brackets = pointy_brackets
    self.use_index_order = use_index_order
    self.float_format = float_format
    self.use_field_number = use_field_number

  def PrintMessage(self, message):
    """Writes every populated field of `message`.

    Map fields are printed as one entry per key, in sorted key order; other
    repeated fields are printed as one entry per element.

    Args:
      message: The protocol buffers message.
    """
    listed = message.ListFields()
    if self.use_index_order:
      listed.sort(key=lambda pair: pair[0].index)
    for field_desc, field_value in listed:
      if _IsMapEntry(field_desc):
        for map_key in sorted(field_value):
          # Building an entry message copies the value, which is slow for maps
          # with submessage entries because the whole subtree is copied.
          # Avoiding that would take significant refactoring of this file.
          #
          # TODO(haberman): refactor and optimize if this becomes an issue.
          self.PrintField(
              field_desc,
              field_desc.message_type._concrete_class(
                  key=map_key, value=field_value[map_key]))
        continue
      if field_desc.label == descriptor.FieldDescriptor.LABEL_REPEATED:
        elements = field_value
      else:
        elements = (field_value,)
      for element in elements:
        self.PrintField(field_desc, element)

  def PrintField(self, field, value):
    """Writes one field entry (name, value and trailing separator).

    The entry is prefixed by the current indentation and terminated by a
    space in one-line mode or by a newline otherwise.
    """
    write = self.out.write
    fd = descriptor.FieldDescriptor
    write(' ' * self.indent)
    if self.use_field_number:
      write(str(field.number))
    elif field.is_extension:
      write('[')
      # Optional message extensions of a MessageSet are named after the
      # extension's message type rather than the extension field.
      is_message_set_item = (
          field.containing_type.GetOptions().message_set_wire_format and
          field.type == fd.TYPE_MESSAGE and
          field.label == fd.LABEL_OPTIONAL)
      if is_message_set_item:
        write(field.message_type.full_name)
      else:
        write(field.full_name)
      write(']')
    elif field.type == fd.TYPE_GROUP:
      # Groups are printed under their capitalized type name.
      write(field.message_type.name)
    else:
      write(field.name)

    if field.cpp_type != fd.CPPTYPE_MESSAGE:
      # A colon before a message value would be optional; the cross-language
      # golden files leave it out there, so only scalars get one.
      write(': ')

    self.PrintFieldValue(field, value)
    write(' ' if self.as_one_line else '\n')

  def PrintFieldValue(self, field, value):
    """Writes a single field value, without the field name.

    For repeated fields, `value` must be one element, not the container.

    Args:
      field: The descriptor of the field to be printed.
      value: The value of the field.
    """
    write = self.out.write
    fd = descriptor.FieldDescriptor
    cpp_type = field.cpp_type

    if cpp_type == fd.CPPTYPE_MESSAGE:
      opening, closing = ('<', '>') if self.pointy_brackets else ('{', '}')
      if self.as_one_line:
        write(' ' + opening + ' ')
        self.PrintMessage(value)
        write(closing)
      else:
        write(' ' + opening + '\n')
        self.indent += 2
        self.PrintMessage(value)
        self.indent -= 2
        write(' ' * self.indent + closing)
      return

    if cpp_type == fd.CPPTYPE_ENUM:
      # Numbers without a matching enum value are printed numerically.
      enum_value = field.enum_type.values_by_number.get(value, None)
      write(str(value) if enum_value is None else enum_value.name)
    elif cpp_type == fd.CPPTYPE_STRING:
      write('\"')
      if isinstance(value, six.text_type):
        raw = value.encode('utf-8')
      else:
        raw = value
      # Bytes fields may contain non-UTF8 data, so they are always escaped.
      keep_utf8 = False if field.type == fd.TYPE_BYTES else self.as_utf8
      write(text_encoding.CEscape(raw, keep_utf8))
      write('\"')
    elif cpp_type == fd.CPPTYPE_BOOL:
      write('true' if value else 'false')
    elif cpp_type in _FLOAT_TYPES and self.float_format is not None:
      write('{1:{0}}'.format(self.float_format, value))
    else:
      write(str(value))
323
324
def Parse(text, message,
          allow_unknown_extension=False, allow_field_number=False):
  """Parses a text representation of a protocol message into a message.

  Args:
    text: Message text representation, as a native str, a unicode object
      (Python 2) or UTF-8 encoded bytes (Python 3).
    message: A protocol buffer message to merge into.
    allow_unknown_extension: if True, skip over missing extensions and keep
      parsing
    allow_field_number: if True, both field number and field name are allowed.

  Returns:
    The same message passed as argument.

  Raises:
    ParseError: On text parsing problems.
  """
  if not isinstance(text, str):
    # Convert to the native str type via UTF-8. On Python 2 the non-str case
    # is unicode, which must be encoded: calling .decode() on it would first
    # encode it as ASCII and fail on any non-ASCII character.
    if six.PY3:
      text = text.decode('utf-8')
    else:
      text = text.encode('utf-8')
  return ParseLines(text.split('\n'), message, allow_unknown_extension,
                    allow_field_number)
346
347
def Merge(text, message, allow_unknown_extension=False,
          allow_field_number=False):
  """Parses a text representation of a protocol message into a message.

  Like Parse(), but allows repeated values for a non-repeated field, and uses
  the last one.

  Args:
    text: Message text representation, as a native str, a unicode object
      (Python 2) or UTF-8 encoded bytes (Python 3).
    message: A protocol buffer message to merge into.
    allow_unknown_extension: if True, skip over missing extensions and keep
      parsing
    allow_field_number: if True, both field number and field name are allowed.

  Returns:
    The same message passed as argument.

  Raises:
    ParseError: On text parsing problems.
  """
  if not isinstance(text, str):
    # Normalize to the native str type exactly as Parse() does; without this,
    # bytes input on Python 3 fails in text.split('\n').
    if six.PY3:
      text = text.decode('utf-8')
    else:
      text = text.encode('utf-8')
  return MergeLines(text.split('\n'), message, allow_unknown_extension,
                    allow_field_number)
370
371
def ParseLines(lines, message, allow_unknown_extension=False,
               allow_field_number=False):
  """Parses text-format lines into `message`.

  Args:
    lines: An iterable of lines of a message's text representation.
    message: A protocol buffer message to merge into.
    allow_unknown_extension: If True, skip over extensions that are not
      registered instead of raising.
    allow_field_number: If True, fields may be referred to by number as well
      as by name.

  Returns:
    The same message passed as argument.

  Raises:
    ParseError: On text parsing problems.
  """
  return _Parser(
      allow_unknown_extension=allow_unknown_extension,
      allow_field_number=allow_field_number).ParseLines(lines, message)
391
392
def MergeLines(lines, message, allow_unknown_extension=False,
               allow_field_number=False):
  """Merges text-format lines into `message`, permitting repeated scalars.

  Args:
    lines: An iterable of lines of a message's text representation.
    message: A protocol buffer message to merge into.
    allow_unknown_extension: If True, skip over extensions that are not
      registered instead of raising.
    allow_field_number: If True, fields may be referred to by number as well
      as by name.

  Returns:
    The same message passed as argument.

  Raises:
    ParseError: On text parsing problems.
  """
  return _Parser(
      allow_unknown_extension=allow_unknown_extension,
      allow_field_number=allow_field_number).MergeLines(lines, message)
412
413
class _Parser(object):
  """Text format parser for protocol message."""

  def __init__(self, allow_unknown_extension=False, allow_field_number=False):
    """Stores parsing options.

    Args:
      allow_unknown_extension: If True, unregistered extensions are skipped
        instead of raising ParseError.
      allow_field_number: If True, fields may be named by their number.
    """
    self.allow_unknown_extension = allow_unknown_extension
    self.allow_field_number = allow_field_number

  @staticmethod
  def _ToNativeStr(text):
    """Returns `text` as the native str type, converting through UTF-8.

    Python 3 bytes are decoded and Python 2 unicode objects are encoded; values
    that already are native str are returned unchanged.
    """
    if isinstance(text, str):
      return text
    if six.PY3:
      return text.decode('utf-8')
    return text.encode('utf-8')

  def ParseFromString(self, text, message):
    """Parses a text representation of a protocol message into a message.

    Args:
      text: Message text as native str, unicode or UTF-8 encoded bytes.
      message: A protocol buffer message to merge into.

    Returns:
      The same message passed as argument.
    """
    return self.ParseLines(self._ToNativeStr(text).split('\n'), message)

  def ParseLines(self, lines, message):
    """Parses text lines into `message`, rejecting repeated scalar values."""
    self._allow_multiple_scalars = False
    self._ParseOrMerge(lines, message)
    return message

  def MergeFromString(self, text, message):
    """Merges a text representation of a protocol message into a message.

    Args:
      text: Message text as native str, unicode or UTF-8 encoded bytes.
      message: A protocol buffer message to merge into.

    Returns:
      The same message passed as argument.
    """
    return self.MergeLines(self._ToNativeStr(text).split('\n'), message)

  def MergeLines(self, lines, message):
    """Merges text lines into `message`; later scalar values win."""
    self._allow_multiple_scalars = True
    self._ParseOrMerge(lines, message)
    return message

  def _ParseOrMerge(self, lines, message):
    """Converts a text representation of a protocol message into a message.

    Args:
      lines: Lines of a message's text representation. Each line may be a
        native str, unicode or UTF-8 encoded bytes.
      message: A protocol buffer message to merge into.

    Raises:
      ParseError: On text parsing problems.
    """
    # The tokenizer works on native str lines; normalize lazily per line.
    tokenizer = _Tokenizer(self._ToNativeStr(line) for line in lines)
    while not tokenizer.AtEnd():
      self._MergeField(tokenizer, message)

  def _MergeField(self, tokenizer, message):
    """Merges a single protocol message field into a message.

    Args:
      tokenizer: A tokenizer to parse the field name and values.
      message: A protocol message to record the data.

    Raises:
      ParseError: In case of text parsing problems.
    """
    message_descriptor = message.DESCRIPTOR
    if (hasattr(message_descriptor, 'syntax') and
        message_descriptor.syntax == 'proto3'):
      # Proto3 doesn't represent presence so we can't test if multiple
      # scalars have occurred.  We have to allow them.
      # NOTE(review): this flag is never reset, so it stays set for the rest
      # of the parse once any proto3 message is seen.
      self._allow_multiple_scalars = True
    if tokenizer.TryConsume('['):
      # Extension field: "[" dotted.name "]".
      name = [tokenizer.ConsumeIdentifier()]
      while tokenizer.TryConsume('.'):
        name.append(tokenizer.ConsumeIdentifier())
      name = '.'.join(name)

      if not message_descriptor.is_extendable:
        raise tokenizer.ParseErrorPreviousToken(
            'Message type "%s" does not have extensions.' %
            message_descriptor.full_name)
      # pylint: disable=protected-access
      field = message.Extensions._FindExtensionByName(name)
      # pylint: enable=protected-access
      if not field:
        if self.allow_unknown_extension:
          field = None
        else:
          raise tokenizer.ParseErrorPreviousToken(
              'Extension "%s" not registered.' % name)
      elif message_descriptor != field.containing_type:
        raise tokenizer.ParseErrorPreviousToken(
            'Extension "%s" does not extend message type "%s".' % (
                name, message_descriptor.full_name))

      tokenizer.Consume(']')

    else:
      name = tokenizer.ConsumeIdentifier()
      if self.allow_field_number and name.isdigit():
        number = ParseInteger(name, True, True)
        field = message_descriptor.fields_by_number.get(number, None)
        if not field and message_descriptor.is_extendable:
          # pylint: disable=protected-access
          field = message.Extensions._FindExtensionByNumber(number)
          # pylint: enable=protected-access
      else:
        field = message_descriptor.fields_by_name.get(name, None)

        # Group names are expected to be capitalized as they appear in the
        # .proto file, which actually matches their type names, not their field
        # names.
        if not field:
          field = message_descriptor.fields_by_name.get(name.lower(), None)
          if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP:
            field = None

        if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and
            field.message_type.name != name):
          field = None

      if not field:
        raise tokenizer.ParseErrorPreviousToken(
            'Message type "%s" has no field named "%s".' % (
                message_descriptor.full_name, name))

    if field:
      if not self._allow_multiple_scalars and field.containing_oneof:
        # Check if there's a different field set in this oneof.
        # Note that we ignore the case if the same field was set before, and we
        # apply _allow_multiple_scalars to non-scalar fields as well.
        which_oneof = message.WhichOneof(field.containing_oneof.name)
        if which_oneof is not None and which_oneof != field.name:
          raise tokenizer.ParseErrorPreviousToken(
              'Field "%s" is specified along with field "%s", another member '
              'of oneof "%s" for message type "%s".' % (
                  field.name, which_oneof, field.containing_oneof.name,
                  message_descriptor.full_name))

      # The colon is optional before a message value but required otherwise.
      if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
        tokenizer.TryConsume(':')
        merger = self._MergeMessageField
      else:
        tokenizer.Consume(':')
        merger = self._MergeScalarField

      if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED
          and tokenizer.TryConsume('[')):
        # Short repeated format, e.g. "foo: [1, 2, 3]"
        while True:
          merger(tokenizer, message, field)
          if tokenizer.TryConsume(']'): break
          tokenizer.Consume(',')

      else:
        merger(tokenizer, message, field)

    else:  # Proto field is unknown.
      assert self.allow_unknown_extension
      _SkipFieldContents(tokenizer)

    # For historical reasons, fields may optionally be separated by commas or
    # semicolons.
    if not tokenizer.TryConsume(','):
      tokenizer.TryConsume(';')

  def _MergeMessageField(self, tokenizer, message, field):
    """Merges a single message field into a message.

    Args:
      tokenizer: A tokenizer to parse the field value.
      message: The message of which field is a member.
      field: The descriptor of the field to be merged.

    Raises:
      ParseError: In case of text parsing problems.
    """
    is_map_entry = _IsMapEntry(field)

    if tokenizer.TryConsume('<'):
      end_token = '>'
    else:
      tokenizer.Consume('{')
      end_token = '}'

    if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
      if field.is_extension:
        sub_message = message.Extensions[field].add()
      elif is_map_entry:
        # Map entries are parsed into a standalone entry message and copied
        # into the map once the closing delimiter has been consumed.
        # pylint: disable=protected-access
        sub_message = field.message_type._concrete_class()
      else:
        sub_message = getattr(message, field.name).add()
    else:
      if field.is_extension:
        sub_message = message.Extensions[field]
      else:
        sub_message = getattr(message, field.name)
      # Mark the submessage present even if its body turns out to be empty.
      sub_message.SetInParent()

    while not tokenizer.TryConsume(end_token):
      if tokenizer.AtEnd():
        raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token,))
      self._MergeField(tokenizer, sub_message)

    if is_map_entry:
      value_cpptype = field.message_type.fields_by_name['value'].cpp_type
      if value_cpptype == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
        value = getattr(message, field.name)[sub_message.key]
        value.MergeFrom(sub_message.value)
      else:
        getattr(message, field.name)[sub_message.key] = sub_message.value

  def _MergeScalarField(self, tokenizer, message, field):
    """Merges a single scalar field into a message.

    Args:
      tokenizer: A tokenizer to parse the field value.
      message: A protocol message to record the data.
      field: The descriptor of the field to be merged.

    Raises:
      ParseError: In case of text parsing problems.
      RuntimeError: On runtime errors.
    """
    if field.type in (descriptor.FieldDescriptor.TYPE_INT32,
                      descriptor.FieldDescriptor.TYPE_SINT32,
                      descriptor.FieldDescriptor.TYPE_SFIXED32):
      value = tokenizer.ConsumeInt32()
    elif field.type in (descriptor.FieldDescriptor.TYPE_INT64,
                        descriptor.FieldDescriptor.TYPE_SINT64,
                        descriptor.FieldDescriptor.TYPE_SFIXED64):
      value = tokenizer.ConsumeInt64()
    elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32,
                        descriptor.FieldDescriptor.TYPE_FIXED32):
      value = tokenizer.ConsumeUint32()
    elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64,
                        descriptor.FieldDescriptor.TYPE_FIXED64):
      value = tokenizer.ConsumeUint64()
    elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT,
                        descriptor.FieldDescriptor.TYPE_DOUBLE):
      value = tokenizer.ConsumeFloat()
    elif field.type == descriptor.FieldDescriptor.TYPE_BOOL:
      value = tokenizer.ConsumeBool()
    elif field.type == descriptor.FieldDescriptor.TYPE_STRING:
      value = tokenizer.ConsumeString()
    elif field.type == descriptor.FieldDescriptor.TYPE_BYTES:
      value = tokenizer.ConsumeByteString()
    elif field.type == descriptor.FieldDescriptor.TYPE_ENUM:
      value = tokenizer.ConsumeEnum(field)
    else:
      raise RuntimeError('Unknown field type %d' % field.type)

    if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
      if field.is_extension:
        message.Extensions[field].append(value)
      else:
        getattr(message, field.name).append(value)
    else:
      # In Parse mode a singular scalar may only be set once.
      if field.is_extension:
        if not self._allow_multiple_scalars and message.HasExtension(field):
          raise tokenizer.ParseErrorPreviousToken(
              'Message type "%s" should not have multiple "%s" extensions.' %
              (message.DESCRIPTOR.full_name, field.full_name))
        else:
          message.Extensions[field] = value
      else:
        if not self._allow_multiple_scalars and message.HasField(field.name):
          raise tokenizer.ParseErrorPreviousToken(
              'Message type "%s" should not have multiple "%s" fields.' %
              (message.DESCRIPTOR.full_name, field.name))
        else:
          setattr(message, field.name, value)
676
677
def _SkipFieldContents(tokenizer):
  """Skips the value or message body that follows an unknown field's name.

  A scalar value is introduced by ":" and does not start with "{" or "<".
  Anything else (no colon, or a colon followed by an opening delimiter) is
  treated as a message body; ill-formed input surfaces as a parse error there.

  Args:
    tokenizer: A tokenizer to parse the field name and values.
  """
  is_scalar = (tokenizer.TryConsume(':') and
               not (tokenizer.LookingAt('{') or tokenizer.LookingAt('<')))
  if is_scalar:
    _SkipFieldValue(tokenizer)
  else:
    _SkipFieldMessage(tokenizer)
695
696
def _SkipField(tokenizer):
  """Skips a whole field: its name, its contents and an optional separator.

  Args:
    tokenizer: A tokenizer to parse the field name and values.
  """
  if not tokenizer.TryConsume('['):
    tokenizer.ConsumeIdentifier()
  else:
    # Extension name: one or more identifiers joined by ".", then "]".
    while True:
      tokenizer.ConsumeIdentifier()
      if not tokenizer.TryConsume('.'):
        break
    tokenizer.Consume(']')

  _SkipFieldContents(tokenizer)

  # Fields may optionally be followed by "," or ";" for historical reasons.
  if not tokenizer.TryConsume(','):
    tokenizer.TryConsume(';')
718
719
def _SkipFieldMessage(tokenizer):
  """Skips a delimited message body ("{...}" or "<...>").

  Args:
    tokenizer: A tokenizer to parse the field name and values.
  """
  if not tokenizer.TryConsume('<'):
    tokenizer.Consume('{')
    closing = '}'
  else:
    closing = '>'

  # Skip fields until either closing delimiter shows up; a mismatched one is
  # then rejected by the Consume below.
  while not (tokenizer.LookingAt('>') or tokenizer.LookingAt('}')):
    _SkipField(tokenizer)

  tokenizer.Consume(closing)
737
738
def _SkipFieldValue(tokenizer):
  """Skips a single scalar field value.

  Args:
    tokenizer: A tokenizer to parse the field name and values.

  Raises:
    ParseError: In case an invalid field value is found.
  """
  # Adjacent string literals form one value, so skip the whole run of them.
  if tokenizer.TryConsumeByteString():
    while tokenizer.TryConsumeByteString():
      pass
    return

  skipped = (tokenizer.TryConsumeIdentifier() or
             tokenizer.TryConsumeInt64() or
             tokenizer.TryConsumeUint64() or
             tokenizer.TryConsumeFloat())
  if not skipped:
    raise ParseError('Invalid field value: ' + tokenizer.token)
760
761
762class _Tokenizer(object):
763  """Protocol buffer text representation tokenizer.
764
765  This class handles the lower level string parsing by splitting it into
766  meaningful tokens.
767
768  It was directly ported from the Java protocol buffer API.
769  """
770
771  _WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE)
772  _TOKEN = re.compile('|'.join([
773      r'[a-zA-Z_][0-9a-zA-Z_+-]*',             # an identifier
774      r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*',  # a number
775  ] + [                                        # quoted str for each quote mark
776      r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES
777  ]))
778
779  _IDENTIFIER = re.compile(r'\w+')
780
781  def __init__(self, lines):
782    self._position = 0
783    self._line = -1
784    self._column = 0
785    self._token_start = None
786    self.token = ''
787    self._lines = iter(lines)
788    self._current_line = ''
789    self._previous_line = 0
790    self._previous_column = 0
791    self._more_lines = True
792    self._SkipWhitespace()
793    self.NextToken()
794
795  def LookingAt(self, token):
796    return self.token == token
797
798  def AtEnd(self):
799    """Checks the end of the text was reached.
800
801    Returns:
802      True iff the end was reached.
803    """
804    return not self.token
805
806  def _PopLine(self):
807    while len(self._current_line) <= self._column:
808      try:
809        self._current_line = next(self._lines)
810      except StopIteration:
811        self._current_line = ''
812        self._more_lines = False
813        return
814      else:
815        self._line += 1
816        self._column = 0
817
818  def _SkipWhitespace(self):
819    while True:
820      self._PopLine()
821      match = self._WHITESPACE.match(self._current_line, self._column)
822      if not match:
823        break
824      length = len(match.group(0))
825      self._column += length
826
827  def TryConsume(self, token):
828    """Tries to consume a given piece of text.
829
830    Args:
831      token: Text to consume.
832
833    Returns:
834      True iff the text was consumed.
835    """
836    if self.token == token:
837      self.NextToken()
838      return True
839    return False
840
841  def Consume(self, token):
842    """Consumes a piece of text.
843
844    Args:
845      token: Text to consume.
846
847    Raises:
848      ParseError: If the text couldn't be consumed.
849    """
850    if not self.TryConsume(token):
851      raise self._ParseError('Expected "%s".' % token)
852
853  def TryConsumeIdentifier(self):
854    try:
855      self.ConsumeIdentifier()
856      return True
857    except ParseError:
858      return False
859
860  def ConsumeIdentifier(self):
861    """Consumes protocol message field identifier.
862
863    Returns:
864      Identifier string.
865
866    Raises:
867      ParseError: If an identifier couldn't be consumed.
868    """
869    result = self.token
870    if not self._IDENTIFIER.match(result):
871      raise self._ParseError('Expected identifier.')
872    self.NextToken()
873    return result
874
875  def ConsumeInt32(self):
876    """Consumes a signed 32bit integer number.
877
878    Returns:
879      The integer parsed.
880
881    Raises:
882      ParseError: If a signed 32bit integer couldn't be consumed.
883    """
884    try:
885      result = ParseInteger(self.token, is_signed=True, is_long=False)
886    except ValueError as e:
887      raise self._ParseError(str(e))
888    self.NextToken()
889    return result
890
891  def ConsumeUint32(self):
892    """Consumes an unsigned 32bit integer number.
893
894    Returns:
895      The integer parsed.
896
897    Raises:
898      ParseError: If an unsigned 32bit integer couldn't be consumed.
899    """
900    try:
901      result = ParseInteger(self.token, is_signed=False, is_long=False)
902    except ValueError as e:
903      raise self._ParseError(str(e))
904    self.NextToken()
905    return result
906
907  def TryConsumeInt64(self):
908    try:
909      self.ConsumeInt64()
910      return True
911    except ParseError:
912      return False
913
914  def ConsumeInt64(self):
915    """Consumes a signed 64bit integer number.
916
917    Returns:
918      The integer parsed.
919
920    Raises:
921      ParseError: If a signed 64bit integer couldn't be consumed.
922    """
923    try:
924      result = ParseInteger(self.token, is_signed=True, is_long=True)
925    except ValueError as e:
926      raise self._ParseError(str(e))
927    self.NextToken()
928    return result
929
930  def TryConsumeUint64(self):
931    try:
932      self.ConsumeUint64()
933      return True
934    except ParseError:
935      return False
936
937  def ConsumeUint64(self):
938    """Consumes an unsigned 64bit integer number.
939
940    Returns:
941      The integer parsed.
942
943    Raises:
944      ParseError: If an unsigned 64bit integer couldn't be consumed.
945    """
946    try:
947      result = ParseInteger(self.token, is_signed=False, is_long=True)
948    except ValueError as e:
949      raise self._ParseError(str(e))
950    self.NextToken()
951    return result
952
953  def TryConsumeFloat(self):
954    try:
955      self.ConsumeFloat()
956      return True
957    except ParseError:
958      return False
959
960  def ConsumeFloat(self):
961    """Consumes an floating point number.
962
963    Returns:
964      The number parsed.
965
966    Raises:
967      ParseError: If a floating point number couldn't be consumed.
968    """
969    try:
970      result = ParseFloat(self.token)
971    except ValueError as e:
972      raise self._ParseError(str(e))
973    self.NextToken()
974    return result
975
976  def ConsumeBool(self):
977    """Consumes a boolean value.
978
979    Returns:
980      The bool parsed.
981
982    Raises:
983      ParseError: If a boolean value couldn't be consumed.
984    """
985    try:
986      result = ParseBool(self.token)
987    except ValueError as e:
988      raise self._ParseError(str(e))
989    self.NextToken()
990    return result
991
992  def TryConsumeByteString(self):
993    try:
994      self.ConsumeByteString()
995      return True
996    except ParseError:
997      return False
998
999  def ConsumeString(self):
1000    """Consumes a string value.
1001
1002    Returns:
1003      The string parsed.
1004
1005    Raises:
1006      ParseError: If a string value couldn't be consumed.
1007    """
1008    the_bytes = self.ConsumeByteString()
1009    try:
1010      return six.text_type(the_bytes, 'utf-8')
1011    except UnicodeDecodeError as e:
1012      raise self._StringParseError(e)
1013
1014  def ConsumeByteString(self):
1015    """Consumes a byte array value.
1016
1017    Returns:
1018      The array parsed (as a string).
1019
1020    Raises:
1021      ParseError: If a byte array value couldn't be consumed.
1022    """
1023    the_list = [self._ConsumeSingleByteString()]
1024    while self.token and self.token[0] in _QUOTES:
1025      the_list.append(self._ConsumeSingleByteString())
1026    return b''.join(the_list)
1027
1028  def _ConsumeSingleByteString(self):
1029    """Consume one token of a string literal.
1030
1031    String literals (whether bytes or text) can come in multiple adjacent
1032    tokens which are automatically concatenated, like in C or Python.  This
1033    method only consumes one token.
1034
1035    Returns:
1036      The token parsed.
1037    Raises:
1038      ParseError: When the wrong format data is found.
1039    """
1040    text = self.token
1041    if len(text) < 1 or text[0] not in _QUOTES:
1042      raise self._ParseError('Expected string but found: %r' % (text,))
1043
1044    if len(text) < 2 or text[-1] != text[0]:
1045      raise self._ParseError('String missing ending quote: %r' % (text,))
1046
1047    try:
1048      result = text_encoding.CUnescape(text[1:-1])
1049    except ValueError as e:
1050      raise self._ParseError(str(e))
1051    self.NextToken()
1052    return result
1053
1054  def ConsumeEnum(self, field):
1055    try:
1056      result = ParseEnum(field, self.token)
1057    except ValueError as e:
1058      raise self._ParseError(str(e))
1059    self.NextToken()
1060    return result
1061
1062  def ParseErrorPreviousToken(self, message):
1063    """Creates and *returns* a ParseError for the previously read token.
1064
1065    Args:
1066      message: A message to set for the exception.
1067
1068    Returns:
1069      A ParseError instance.
1070    """
1071    return ParseError('%d:%d : %s' % (
1072        self._previous_line + 1, self._previous_column + 1, message))
1073
1074  def _ParseError(self, message):
1075    """Creates and *returns* a ParseError for the current token."""
1076    return ParseError('%d:%d : %s' % (
1077        self._line + 1, self._column + 1, message))
1078
1079  def _StringParseError(self, e):
1080    return self._ParseError('Couldn\'t parse string: ' + str(e))
1081
1082  def NextToken(self):
1083    """Reads the next meaningful token."""
1084    self._previous_line = self._line
1085    self._previous_column = self._column
1086
1087    self._column += len(self.token)
1088    self._SkipWhitespace()
1089
1090    if not self._more_lines:
1091      self.token = ''
1092      return
1093
1094    match = self._TOKEN.match(self._current_line, self._column)
1095    if match:
1096      token = match.group(0)
1097      self.token = token
1098    else:
1099      self.token = self._current_line[self._column]
1100
1101
def ParseInteger(text, is_signed=False, is_long=False):
  """Parses an integer.

  Accepts decimal, hexadecimal ('0x' prefix) and C-style octal (a leading
  '0', e.g. '0755') literals, optionally preceded by a sign.

  Args:
    text: The text to parse.
    is_signed: True if a signed integer must be parsed.
    is_long: True if a long integer must be parsed.

  Returns:
    The integer value.

  Raises:
    ValueError: Thrown Iff the text is not a valid integer.
  """
  # Python 3's int(..., 0) rejects C-style octal such as '0755' (it requires
  # '0o755'), whereas Python 2 accepted it. Rewrite the prefix so leading-zero
  # literals parse as octal on both interpreters; '0o' is accepted by both.
  c_octal_match = re.match(r'([-+]?)0(\d+)$', text)
  if c_octal_match:
    parse_text = c_octal_match.group(1) + '0o' + c_octal_match.group(2)
  else:
    parse_text = text

  # Do the actual parsing. Exception handling is propagated to caller.
  try:
    # We force 32-bit values to int and 64-bit values to long to make
    # alternate implementations where the distinction is more significant
    # (e.g. the C++ implementation) simpler.
    if is_long:
      result = long(parse_text, 0)
    else:
      result = int(parse_text, 0)
  except ValueError:
    # Report the text as the user wrote it, not the rewritten form.
    raise ValueError('Couldn\'t parse integer: %s' % text)

  # Check if the integer is sane. Exceptions handled by callers.
  checker = _INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)]
  checker.CheckValue(result)
  return result
1132
1133
def ParseFloat(text):
  """Parse a floating point number.

  Accepts Python float syntax, the module's alternative spellings for
  infinity and NaN, and a single C-style 'f' or 'F' suffix (e.g. '1.5f').

  Args:
    text: Text to parse.

  Returns:
    The number parsed.

  Raises:
    ValueError: If a floating point number couldn't be parsed.
  """
  try:
    # Assume Python compatible syntax.
    return float(text)
  except ValueError:
    # Check alternative spellings.
    if _FLOAT_INFINITY.match(text):
      if text[0] == '-':
        return float('-inf')
      else:
        return float('inf')
    elif _FLOAT_NAN.match(text):
      return float('nan')
    else:
      # Assume '1.0f' format. Strip exactly one 'f'/'F' suffix; the former
      # rstrip('f') accepted any number of them ('1.0ff') and rejected 'F'.
      number = text[:-1] if text[-1:] in ('f', 'F') else text
      try:
        return float(number)
      except ValueError:
        raise ValueError('Couldn\'t parse float: %s' % text)
1164
1165
def ParseBool(text):
  """Parse a boolean value.

  Accepts the same spellings as the C++ text-format parser:
  'true', 'True', 't' or '1' for True and 'false', 'False', 'f' or '0' for
  False.

  Args:
    text: Text to parse.

  Returns:
    Boolean values parsed

  Raises:
    ValueError: If text is not a valid boolean.
  """
  if text in ('true', 't', '1', 'True'):
    return True
  if text in ('false', 'f', '0', 'False'):
    return False
  raise ValueError('Expected "true" or "false".')
1184
1185
def ParseEnum(field, value):
  """Resolves an enum value written either as a number or as a name.

  Args:
    field: Enum field descriptor; its `enum_type` is consulted.
    value: String value -- a numeric literal (any base int() accepts with
      base 0) or the symbolic name of an enum value.

  Returns:
    Enum value number.

  Raises:
    ValueError: If `value` names no value of the enum, or its number is not
      defined by the enum.
  """
  enum_descriptor = field.enum_type
  try:
    number = int(value, 0)
  except ValueError:
    # Not numeric, so look it up as an identifier.
    named = enum_descriptor.values_by_name.get(value)
    if named is None:
      raise ValueError(
          'Enum type "%s" has no value named %s.' % (
              enum_descriptor.full_name, value))
    return named.number

  # Numeric value.
  numbered = enum_descriptor.values_by_number.get(number)
  if numbered is None:
    raise ValueError(
        'Enum type "%s" has no value with number %d.' % (
            enum_descriptor.full_name, number))
  return numbered.number
1220