1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Code for decoding protocol buffer primitives.
32
33This code is very similar to encoder.py -- read the docs for that module first.
34
35A "decoder" is a function with the signature:
36  Decode(buffer, pos, end, message, field_dict)
37The arguments are:
  buffer:     The memoryview holding the encoded message.
39  pos:        The current position in the string.
40  end:        The position in the string where the current message ends.  May be
41              less than len(buffer) if we're reading a sub-message.
42  message:    The message object into which we're parsing.
43  field_dict: message._fields (avoids a hashtable lookup).
44The decoder reads the field and stores it into field_dict, returning the new
45buffer position.  A decoder for a repeated field may proactively decode all of
46the elements of that field, if they appear consecutively.
47
48Note that decoders may throw any of the following:
49  IndexError:  Indicates a truncated message.
50  struct.error:  Unpacking of a fixed-width field failed.
51  message.DecodeError:  Other errors.
52
53Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking:  it's fine to read past
55"end" as long as you are sure that someone else will notice and throw an
56exception later on.
57
58Something up the call stack is expected to catch IndexError and struct.error
59and convert them to message.DecodeError.
60
61Decoders are constructed using decoder constructors with the signature:
62  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
63The arguments are:
64  field_number:  The field number of the field we want to decode.
65  is_repeated:   Is the field a repeated field? (bool)
66  is_packed:     Is the field a packed field? (bool)
67  key:           The key to use when looking up the field within field_dict.
68                 (This is actually the FieldDescriptor but nothing in this
69                 file should depend on that.)
70  new_default:   A function which takes a message object as a parameter and
71                 returns a new instance of the default value for this field.
72                 (This is called for repeated fields and sub-messages, when an
73                 instance does not already exist.)
74
75As with encoders, we define a decoder constructor for every type of field.
76Then, for every field of every message class we construct an actual decoder.
77That decoder goes into a dict indexed by tag, so when we decode a message
78we repeatedly read a tag, look up the corresponding decoder, and invoke it.
79"""
80
__author__ = 'kenton@google.com (Kenton Varda)'

import struct
import sys
import six

# sys.maxunicode on a narrow (UCS-2) Python build.  StringDecoder compares
# sys.maxunicode against this to detect wide (UCS-4) Python 2 builds.
_UCS2_MAXUNICODE = 65535
if six.PY3:
  # Python 3 has no separate long type; alias it so the 64-bit varint
  # decoders below can use `long` as their result type on both versions.
  long = int
else:
  import re    # pylint: disable=g-import-not-at-top
  # Matches any UTF-16 surrogate code point (U+D800..U+DFFF).  Only needed
  # on Python 2, by StringDecoder's is_strict_utf8 check.
  _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]'))
93
94from google.protobuf.internal import containers
95from google.protobuf.internal import encoder
96from google.protobuf.internal import wire_format
97from google.protobuf import message
98
99
# This will overflow and thus become IEEE-754 "infinity".  We would use
# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
# These values are returned for non-finite bit patterns by _FloatDecoder
# (all three) and _DoubleDecoder (_NAN only).
_POS_INF = 1e10000
_NEG_INF = -_POS_INF
# inf * 0 produces NaN.
_NAN = _POS_INF * 0


# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".  (Many decoders below take a parameter called
# `message`, which shadows the imported module inside their bodies.)
_DecodeError = message.DecodeError
110
111
def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic varint value (does not include tag).

  Decoded values will be bitwise-anded with the given mask before being
  returned, e.g. to limit them to 32 bits.  The returned decoder does not
  take the usual "end" parameter -- the caller is expected to do bounds checking
  after the fact (often the caller can defer such checking until later).  The
  decoder returns a (value, new_pos) pair.

  Args:
    mask: int mask applied to the accumulated value, e.g. (1 << 32) - 1.
    result_type: Callable (e.g. int or long) applied to the masked value.

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos).  It raises
    _DecodeError if the varint runs past 10 bytes, and lets IndexError
    propagate if the buffer ends before a terminating byte (high bit clear)
    is found.
  """

  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    while 1:
      b = six.indexbytes(buffer, pos)
      # Each byte contributes its low 7 bits, least-significant group first.
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        # High bit clear: this was the final byte of the varint.
        result &= mask
        result = result_type(result)
        return (result, pos)
      shift += 7
      # Ten 7-bit groups already cover 64 bits; a continuation bit on the
      # tenth byte means the encoding is over-long.
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint
137
138
def _SignedVarintDecoder(bits, result_type):
  """Like _VarintDecoder() but decodes signed values.

  The accumulated varint is truncated to its low `bits` bits and then
  reinterpreted as a two's-complement number of that width.

  Args:
    bits: Width in bits of the signed result (32 or 64 in this module).
    result_type: Callable applied to the final value (int or long).

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos).
  """

  width_mask = (1 << bits) - 1
  sign_bit = 1 << (bits - 1)

  def DecodeVarint(buffer, pos):
    accumulated = 0
    bit_offset = 0
    while 1:
      byte = six.indexbytes(buffer, pos)
      accumulated |= (byte & 0x7f) << bit_offset
      pos += 1
      if (byte & 0x80) == 0:
        break
      bit_offset += 7
      if bit_offset >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
    # Keep the low `bits` bits, then sign-extend with the xor/subtract trick.
    signed = ((accumulated & width_mask) ^ sign_bit) - sign_bit
    return (result_type(signed), pos)
  return DecodeVarint
161
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.
# (On Python 3, `long` is aliased to `int` near the top of this module.)

# Unsigned 64-bit varint, and the same reinterpreted as signed 64-bit.
_DecodeVarint = _VarintDecoder((1 << 64) - 1, long)
_DecodeSignedVarint = _SignedVarintDecoder(64, long)

# Use these versions for values which must be limited to 32 bits.
# Both keep only the low 32 bits of the decoded varint; the signed one then
# sign-extends from bit 31.
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
_DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
172
173
def ReadTag(buffer, pos):
  """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.

  The tag is returned as its raw, still-encoded bytes rather than as a decoded
  integer.  Callers use those bytes directly as a hash-table key to find the
  right decoder, which swaps a pure-Python varint decode for a C-level dict
  lookup.  (In a lower-level language decoding the varint would be cheaper,
  but not in Python.)

  Args:
    buffer: memoryview object of the encoded bytes
    pos: int of the current position to start from

  Returns:
    Tuple[bytes, int] of the tag data and new position.
  """
  cursor = pos
  # Step over every continuation byte (high bit set).
  while six.indexbytes(buffer, cursor) & 0x80:
    cursor += 1
  # Include the terminating byte, whose high bit is clear.
  cursor += 1
  return (buffer[pos:cursor].tobytes(), cursor)
198
199
200# --------------------------------------------------------------------
201
202
def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  The returned constructor has the MakeDecoder signature described in the
  module docstring.  Depending on is_packed / is_repeated it builds a packed
  decoder, a non-packed repeated decoder, or a singular-field decoder.

  Args:
      wire_type:  The field's wire type.
      decode_value:  A function which decodes an individual value, e.g.
        _DecodeVarint().  It is called as decode_value(buffer, pos) and must
        return a (value, new_pos) pair.

  Returns:
      The decoder constructor SpecificDecoder.
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
    if is_packed:
      local_DecodeVarint = _DecodeVarint
      def DecodePackedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # A packed field is a varint byte length followed by that many bytes
        # of back-to-back values.
        (endpoint, pos) = local_DecodeVarint(buffer, pos)
        endpoint += pos
        if endpoint > end:
          raise _DecodeError('Truncated message.')
        while pos < endpoint:
          (element, pos) = decode_value(buffer, pos)
          value.append(element)
        # decode_value does not check bounds, so the last element may have
        # run past the packed region.
        if pos > endpoint:
          del value[-1]   # Discard corrupt value.
          raise _DecodeError('Packed element was truncated.')
        return pos
      return DecodePackedField
    elif is_repeated:
      tag_bytes = encoder.TagBytes(field_number, wire_type)
      tag_len = len(tag_bytes)
      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        while 1:
          (element, new_pos) = decode_value(buffer, pos)
          value.append(element)
          # Predict that the next tag is another copy of the same repeated
          # field.
          pos = new_pos + tag_len
          # Stop at the end of the message even if the bytes beyond it happen
          # to match this field's tag.
          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
            # Prediction failed.  Return.
            # Unlike the packed and singular paths, the element appended just
            # above is not removed before this error is raised.
            if new_pos > end:
              raise _DecodeError('Truncated message.')
            return new_pos
      return DecodeRepeatedField
    else:
      def DecodeField(buffer, pos, end, message, field_dict):
        (field_dict[key], pos) = decode_value(buffer, pos)
        if pos > end:
          del field_dict[key]  # Discard corrupt value.
          raise _DecodeError('Truncated message.')
        return pos
      return DecodeField

  return SpecificDecoder
260
261
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Build a decoder constructor that post-processes every decoded value.

  Equivalent to _SimpleDecoder(wire_type, decode_value), except that each
  value produced by decode_value is passed through modify_value before it is
  stored.  Usually modify_value is ZigZagDecode.
  """

  # Wrapping decode_value and delegating to _SimpleDecoder costs a little
  # speed compared with duplicating its code, but not enough to matter.
  def InnerDecode(buffer, pos):
    raw, next_pos = decode_value(buffer, pos)
    return (modify_value(raw), next_pos)

  return _SimpleDecoder(wire_type, InnerDecode)
274
275
def _StructPackDecoder(wire_type, format):
  """Build a decoder constructor for a fixed-width, struct-packed field.

  Args:
      wire_type:  The field's wire type.
      format:  The format string to pass to struct.unpack().

  Returns:
      A decoder constructor, as produced by _SimpleDecoder.
  """

  width = struct.calcsize(format)
  unpack = struct.unpack

  # Delegating to _SimpleDecoder is slightly slower than inlining its code,
  # but not significantly so.
  #
  # A struct.error from a short slice is deliberately left uncaught: code
  # further up the stack converts it to _DecodeError, which avoids setting up
  # an exception handler for every single value parsed.
  def InnerDecode(buffer, pos):
    stop = pos + width
    decoded = unpack(format, buffer[pos:stop])[0]
    return (decoded, stop)

  return _SimpleDecoder(wire_type, InnerDecode)
299
300
def _FloatDecoder():
  """Returns a decoder for a float field.

  This code works around a bug in struct.unpack for non-finite 32-bit
  floating-point values.  Non-finite bit patterns are mapped to the
  module-level _NAN, _POS_INF and _NEG_INF constants instead of being
  unpacked.

  Returns:
    A decoder constructor (from _SimpleDecoder) for WIRETYPE_FIXED32 fields.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized float to a float and new position.

    Args:
      buffer: memoryview of the serialized bytes
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the deserialized float value and new position
      in the serialized data.
    """
    # We expect a 32-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    new_pos = pos + 4
    float_bytes = buffer[pos:new_pos].tobytes()

    # If this value has all its exponent bits set, then it's non-finite.
    # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
    # To avoid that, we parse it specially.
    # Byte 3 (most significant) holds the sign bit plus the top 7 exponent
    # bits, so it is 0x7F or 0xFF when those are all set.  The high bit of
    # byte 2 is the last exponent bit.
    # NOTE(review): for a truncated slice float_bytes[3:4] is b'', and
    # `b'' in b'\x7F\xFF'` is True, so a 3-byte slice can return a value here
    # instead of raising struct.error.  The `pos > end` checks in
    # _SimpleDecoder then report the truncation, assuming end <= len(buffer);
    # confirm.
    if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
      # If at least one significand bit is set...
      if float_bytes[0:3] != b'\x00\x00\x80':
        return (_NAN, new_pos)
      # If sign bit is set...
      if float_bytes[3:4] == b'\xFF':
        return (_NEG_INF, new_pos)
      return (_POS_INF, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<f', float_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
344
345
def _DoubleDecoder():
  """Returns a decoder for a double field.

  This code works around a bug in struct.unpack for not-a-number.  NaN bit
  patterns are mapped to the module-level _NAN constant.  Infinities are
  left to struct.unpack.

  Returns:
    A decoder constructor (from _SimpleDecoder) for WIRETYPE_FIXED64 fields.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    """Decode serialized double to a double and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.

    Returns:
      Tuple[float, int] of the decoded double value and new position
      in the serialized data.
    """
    # We expect a 64-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
    new_pos = pos + 8
    double_bytes = buffer[pos:new_pos].tobytes()

    # If this value has all its exponent bits set and at least one significand
    # bit set, it's not a number.  In Python 2.4, struct.unpack will treat it
    # as inf or -inf.  To avoid that, we treat it specially.
    # Byte 7 holds the sign bit plus the top 7 exponent bits.  The high nibble
    # of byte 6 holds the remaining 4 exponent bits.  Comparing bytes 0-6
    # against the infinity pattern detects any set significand bit.
    # NOTE(review): as in _FloatDecoder, a truncated slice makes
    # double_bytes[7:8] == b'', which is `in b'\x7F\xFF'`.  Truncation is then
    # reported by the caller's `pos > end` checks rather than struct.error;
    # confirm end <= len(buffer).
    if ((double_bytes[7:8] in b'\x7F\xFF')
        and (double_bytes[6:7] >= b'\xF0')
        and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
      return (_NAN, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<d', double_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
384
385
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for an enum field.

  Values whose number appears in key.enum_type.values_by_number are stored in
  field_dict.  Any other value is not stored there.  Instead, its raw varint
  bytes are appended to message._unknown_fields, paired with a
  varint-wire-type tag for this field number (also for packed fields).

  Args:
    field_number: The field number of the enum field.
    is_repeated: Whether the field is repeated.
    is_packed: Whether the field is packed.
    key: The field's descriptor; its enum_type is consulted for known values.
    new_default: Callable returning a new container for a repeated field.

  Returns:
    A decoder with the signature Decode(buffer, pos, end, message,
    field_dict) described in the module docstring.
  """
  enum_type = key.enum_type
  # Tag recorded next to unrecognized values in _unknown_fields.  For the
  # non-packed repeated decoder it is also the tag predicted for the next
  # element.  It depends only on field_number, so build it once here rather
  # than once per unknown value.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      """Decode serialized packed enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos].tobytes()))
          # pylint: enable=protected-access
      if pos > endpoint:
        # The last element overran the packed region; undo whichever list it
        # was appended to.
        if element in enum_type.values_by_number:
          del value[-1]   # Discard corrupt value.
        else:
          del message._unknown_fields[-1]  # pylint: disable=protected-access
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      """Decode serialized repeated enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        # pylint: disable=protected-access
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos].tobytes()))
          # pylint: enable=protected-access
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      """Decode a serialized singular enum to its value and a new position.

      Args:
        buffer: memoryview of the serialized bytes.
        pos: int, position in the memory view to start at.
        end: int, end position of serialized data
        message: Message object to store unknown fields in
        field_dict: Map[Descriptor, Any] to store decoded values in.

      Returns:
        int, new position in serialized data.
      """
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      # pylint: disable=protected-access
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos].tobytes()))
        # pylint: enable=protected-access
      return pos
    return DecodeField
503
504
505# --------------------------------------------------------------------
506
507
# Varint-encoded integer types.  The *32 variants keep only the low 32 bits
# of the decoded varint (see _DecodeVarint32 / _DecodeSignedVarint32).
Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)

# sint32/sint64 are read as unsigned varints and then ZigZag-decoded.
SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
# float/double use dedicated decoders that special-case non-finite values.
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# A bool is read as a full 64-bit unsigned varint; any non-zero value
# becomes True.
BoolDecoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
535
536
def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
                  is_strict_utf8=False):
  """Returns a decoder for a string field.

  Each value is a varint length followed by that many UTF-8 bytes, which are
  decoded to text (six.text_type).

  Args:
    field_number: The field number of the field to decode.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False; string fields are never packed.
    key: The key used to store values in field_dict (the field's
      descriptor; its full_name is used in error messages).
    new_default: Callable returning a new container for a repeated field.
    is_strict_utf8: If true, on wide (UCS-4) Python 2 builds also reject
      decoded text containing surrogate code points.

  Returns:
    A decoder with the signature described in the module docstring.  It
    raises UnicodeDecodeError (with the field name appended to its reason)
    for invalid UTF-8.  It raises DecodeError for truncated data and, under
    is_strict_utf8, for surrogates.
  """

  local_DecodeVarint = _DecodeVarint
  local_unicode = six.text_type

  def _ConvertToUnicode(memview):
    """Convert a memoryview of UTF-8 bytes to text."""
    byte_str = memview.tobytes()
    try:
      value = local_unicode(byte_str, 'utf-8')
    except UnicodeDecodeError as e:
      # Add the field name to the error message and re-raise it.  Extend
      # e.reason itself rather than str(e): str(e) already embeds the reason,
      # so formatting it in would repeat the codec's description.
      e.reason = '%s in field: %s' % (e.reason, key.full_name)
      raise

    if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE:
      # Only do the check for python2 ucs4 when is_strict_utf8 enabled
      if _SURROGATE_PATTERN.search(value):
        # Trailing spaces keep the implicitly concatenated pieces apart.
        reason = ('String field %s contains invalid UTF-8 data when parsing '
                  'a protocol buffer: surrogates not allowed. Use '
                  'the bytes type if you intend to send raw bytes.') % (
                      key.full_name)
        # _DecodeError is message.DecodeError; use the alias like the rest of
        # this module does.
        raise _DecodeError(reason)

    return value

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
      return new_pos
    return DecodeField
595
596
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a bytes field.

  Each value on the wire is a varint length followed by that many raw bytes,
  which are stored as a bytes object.  Packed encoding is not allowed.
  """

  decode_length = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (length, start) = decode_length(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = buffer[start:stop].tobytes()
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    items = field_dict.get(key)
    if items is None:
      items = field_dict.setdefault(key, new_default(message))
    while 1:
      (length, start) = decode_length(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      items.append(buffer[start:stop].tobytes())
      # Guess that another element of this same field follows; bail out at
      # the message end or as soon as the next bytes are a different tag.
      pos = stop + tag_len
      if stop == end or buffer[stop:pos] != tag_bytes:
        return stop
  return DecodeRepeatedField
632
633
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group occurrence is the START_GROUP tag (already read by the caller), the
  group's fields, and an END_GROUP tag with the same field number.  The
  sub-message is parsed with _InternalParse(), which returns the position at
  which it stopped.  The decoder then checks for and consumes the end tag.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Fetch (or create) the container once; it stays the same object for
      # every element parsed in the loop below.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos + end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos + end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField
679
680
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each occurrence on the wire is a varint length followed by that many bytes
  of serialized sub-message.  The sub-message is parsed in place with its
  _InternalParse() method, bounded by the declared length.
  """

  decode_length = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      # Length prefix, then the bounded sub-message body.
      (length, body_start) = decode_length(buffer, pos)
      body_end = body_start + length
      if body_end > end:
        raise _DecodeError('Truncated message.')
      stopped_at = submessage._InternalParse(buffer, body_start, body_end)
      if stopped_at != body_end:
        # _InternalParse only stops short of body_end when it meets an
        # end-group tag, which cannot legally appear here.
        raise _DecodeError('Unexpected end-group tag.')
      return body_end
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    container = field_dict.get(key)
    if container is None:
      container = field_dict.setdefault(key, new_default(message))
    while 1:
      (length, body_start) = decode_length(buffer, pos)
      body_end = body_start + length
      if body_end > end:
        raise _DecodeError('Truncated message.')
      element = container.add()
      if element._InternalParse(buffer, body_start, body_end) != body_end:
        # See DecodeField: an early stop means a stray end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      # Guess that another element of this same field follows.
      pos = body_end + tag_len
      if buffer[body_end:pos] != tag_bytes or body_end == end:
        return body_end
  return DecodeRepeatedField
729
730
731# --------------------------------------------------------------------
732
# Start-group tag of the MessageSet `Item` group (field number 1).  Used as
# the tag when an item whose type_id matches no known extension is kept in
# message._unknown_fields.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
734
def MessageSetItemDecoder(descriptor):
  """Returns a decoder for a MessageSet item.

  The parameter is the message Descriptor.
  NOTE(review): `descriptor` is not referenced by the returned decoder.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    """Decode serialized message set to its value and new position.

    Args:
      buffer: memoryview of the serialized bytes.
      pos: int, position in the memory view to start at.
      end: int, end position of serialized data
      message: Message object to store unknown fields in
      field_dict: Map[Descriptor, Any] to store decoded values in.

    Returns:
      int, new position in serialized data.
    """
    message_set_item_start = pos
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only record where the payload lives; it is parsed after the loop,
        # once type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        # Any other field inside the item is skipped.  A -1 result from
        # SkipField is treated as a malformed item.
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    # Bounds are checked once here rather than on every loop iteration.
    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = message.Extensions._FindExtensionByNumber(type_id)
    # pylint: disable=protected-access
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        value = field_dict.setdefault(
            extension, extension.message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      # Unknown extension: keep the whole item, from just after its start tag
      # through its end tag, as an unknown field.
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append(
          (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
      # pylint: enable=protected-access

    return pos

  return DecodeItem
820
821# --------------------------------------------------------------------
822
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Args:
    field_descriptor: descriptor of the map field.  It is used as the key in
      field_dict and supplies the field number for the expected tag bytes.
    new_default: callable invoked with the parent message to create the map
      container when field_dict has none yet.
    is_message_map: truthy when map values are merged via MergeFrom; falsy
      when they are assigned directly.

  Returns:
    A decoder with signature (buffer, pos, end, message, field_dict).
  """
  map_field = field_descriptor
  expected_tag = encoder.TagBytes(field_descriptor.number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
  expected_tag_len = len(expected_tag)
  decode_varint = _DecodeVarint
  # The entry type's _concrete_class may not be initialized yet, so it is only
  # read when the decoder actually runs.
  entry_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    entry = entry_type._concrete_class()
    container = field_dict.get(map_field)
    if container is None:
      container = field_dict.setdefault(map_field, new_default(message))

    while True:
      # Each map entry is a length-prefixed submessage.
      (entry_size, pos) = decode_varint(buffer, pos)
      entry_end = pos + entry_size
      if entry_end > end:
        raise _DecodeError('Truncated message.')

      entry.Clear()
      if entry._InternalParse(buffer, pos, entry_end) != entry_end:
        # _InternalParse only stops short of entry_end when it meets an
        # end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if not is_message_map:
        container[entry.key] = entry.value
      else:
        container[entry.key].MergeFrom(entry.value)

      # Guess that the following field is another entry of this same map.
      pos = entry_end + expected_tag_len
      if entry_end == end or buffer[entry_end:pos] != expected_tag:
        # Guess was wrong (or input is exhausted): hand control back.
        return entry_end

  return DecodeMap
864
865# --------------------------------------------------------------------
866# Optimization is not as heavy here because calls to SkipField() are rare,
867# except for handling end-group tags.
868
869def _SkipVarint(buffer, pos, end):
870  """Skip a varint value.  Returns the new position."""
871  # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
872  # With this code, ord(b'') raises TypeError.  Both are handled in
873  # python_message.py to generate a 'Truncated message' error.
874  while ord(buffer[pos:pos+1].tobytes()) & 0x80:
875    pos += 1
876  pos += 1
877  if pos > end:
878    raise _DecodeError('Truncated message.')
879  return pos
880
881def _SkipFixed64(buffer, pos, end):
882  """Skip a fixed64 value.  Returns the new position."""
883
884  pos += 8
885  if pos > end:
886    raise _DecodeError('Truncated message.')
887  return pos
888
889
890def _DecodeFixed64(buffer, pos):
891  """Decode a fixed64."""
892  new_pos = pos + 8
893  return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
894
895
def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value.  Returns the new position.

  Reads the varint length prefix at |pos| and advances past that many bytes.
  Raises _DecodeError if the result would run past |end|.
  """
  size, value_start = _DecodeVarint(buffer, pos)
  value_end = value_start + size
  if value_end > end:
    raise _DecodeError('Truncated message.')
  return value_end
904
905
def _SkipGroup(buffer, pos, end):
  """Skip sub-group.  Returns the new position.

  Skips one field at a time until SkipField signals an end-group tag by
  returning -1; the returned position is just past that tag.
  """
  next_pos = pos
  while True:
    tag_bytes, pos = ReadTag(buffer, next_pos)
    next_pos = SkipField(buffer, pos, end, tag_bytes)
    if next_pos == -1:
      return pos
915
916
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
  """Decode UnknownFieldSet.  Returns the UnknownFieldSet and new position.

  Decoding stops after consuming an end-group tag, or, when |end_pos| is
  given, once |pos| reaches |end_pos|.
  """
  field_set = containers.UnknownFieldSet()
  while True:
    if end_pos is not None and pos >= end_pos:
      break
    tag_bytes, pos = ReadTag(buffer, pos)
    tag, _ = _DecodeVarint(tag_bytes, 0)
    field_number, wire_type = wire_format.UnpackTag(tag)
    if wire_type == wire_format.WIRETYPE_END_GROUP:
      break
    data, pos = _DecodeUnknownField(buffer, pos, wire_type)
    field_set._add(field_number, wire_type, data)  # pylint: disable=protected-access
  return (field_set, pos)
932
933
def _DecodeUnknownField(buffer, pos, wire_type):
  """Decode an unknown field.  Returns the field data and new position.

  Args:
    buffer: memoryview of the serialized bytes.
    pos: int, position just after the field's tag.
    wire_type: int, the wire type taken from the field's tag.

  Returns:
    (data, new_pos).  data is the decoded varint value, the fixed32/fixed64
    integer, a bytes copy of a length-delimited payload, or an UnknownFieldSet
    for a group.  For an end-group wire type, (0, -1) is returned.

  Raises:
    _DecodeError: if wire_type is not a recognized wire type.
  """

  if wire_type == wire_format.WIRETYPE_VARINT:
    (data, pos) = _DecodeVarint(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    (data, pos) = _DecodeFixed64(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    (data, pos) = _DecodeFixed32(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    (size, pos) = _DecodeVarint(buffer, pos)
    # Copy out of the memoryview: a bare slice would give callers a
    # memoryview rather than bytes, and would keep the whole input alive.
    data = buffer[pos:pos+size].tobytes()
    pos += size
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    (data, pos) = _DecodeUnknownFieldSet(buffer, pos)
  elif wire_type == wire_format.WIRETYPE_END_GROUP:
    return (0, -1)
  else:
    raise _DecodeError('Wrong wire type in tag.')

  return (data, pos)
955
956
957def _EndGroup(buffer, pos, end):
958  """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
959
960  return -1
961
962
963def _SkipFixed32(buffer, pos, end):
964  """Skip a fixed32 value.  Returns the new position."""
965
966  pos += 4
967  if pos > end:
968    raise _DecodeError('Truncated message.')
969  return pos
970
971
972def _DecodeFixed32(buffer, pos):
973  """Decode a fixed32."""
974
975  new_pos = pos + 4
976  return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
977
978
def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for wire types that are never valid.

  Always raises _DecodeError.  The parameters exist only so the signature
  matches the other skip functions.
  """
  del buffer, pos, end  # Unused.
  raise _DecodeError('Tag had invalid wire type.')
983
def _FieldSkipper():
  """Constructs the SkipField function.

  Returns:
    A SkipField(buffer, pos, end, tag_bytes) function that picks the skip
    helper for the wire type encoded in tag_bytes and delegates to it.
  """

  # Position in this tuple is the wire type; 6 and 7 have no valid meaning.
  skippers_by_wire_type = (
      _SkipVarint,            # 0
      _SkipFixed64,           # 1
      _SkipLengthDelimited,   # 2
      _SkipGroup,             # 3 (start group)
      _EndGroup,              # 4 (end group)
      _SkipFixed32,           # 5
      _RaiseInvalidWireType,  # 6
      _RaiseInvalidWireType,  # 7
  )

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
        The new position (after the tag value), or -1 if the tag is an end-group
        tag (in which case the calling loop should break).
    """
    # Varints are little-endian, so the wire-type bits are in the first byte.
    first_byte = ord(tag_bytes[0:1])
    skipper = skippers_by_wire_type[first_byte & type_mask]
    return skipper(buffer, pos, end)

  return SkipField

SkipField = _FieldSkipper()
1017