1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Code for decoding protocol buffer primitives.
32
33This code is very similar to encoder.py -- read the docs for that module first.
34
35A "decoder" is a function with the signature:
36  Decode(buffer, pos, end, message, field_dict)
37The arguments are:
38  buffer:     The string containing the encoded message.
39  pos:        The current position in the string.
40  end:        The position in the string where the current message ends.  May be
41              less than len(buffer) if we're reading a sub-message.
42  message:    The message object into which we're parsing.
43  field_dict: message._fields (avoids a hashtable lookup).
44The decoder reads the field and stores it into field_dict, returning the new
45buffer position.  A decoder for a repeated field may proactively decode all of
46the elements of that field, if they appear consecutively.
47
48Note that decoders may throw any of the following:
49  IndexError:  Indicates a truncated message.
50  struct.error:  Unpacking of a fixed-width field failed.
51  message.DecodeError:  Other errors.
52
53Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking:  it's fine to read past
55"end" as long as you are sure that someone else will notice and throw an
56exception later on.
57
58Something up the call stack is expected to catch IndexError and struct.error
59and convert them to message.DecodeError.
60
61Decoders are constructed using decoder constructors with the signature:
62  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
63The arguments are:
64  field_number:  The field number of the field we want to decode.
65  is_repeated:   Is the field a repeated field? (bool)
66  is_packed:     Is the field a packed field? (bool)
67  key:           The key to use when looking up the field within field_dict.
68                 (This is actually the FieldDescriptor but nothing in this
69                 file should depend on that.)
70  new_default:   A function which takes a message object as a parameter and
71                 returns a new instance of the default value for this field.
72                 (This is called for repeated fields and sub-messages, when an
73                 instance does not already exist.)
74
75As with encoders, we define a decoder constructor for every type of field.
76Then, for every field of every message class we construct an actual decoder.
77That decoder goes into a dict indexed by tag, so when we decode a message
78we repeatedly read a tag, look up the corresponding decoder, and invoke it.
79"""
80
81__author__ = 'kenton@google.com (Kenton Varda)'
82
83import struct
84
85import six
86
87if six.PY3:
88  long = int
89
90from google.protobuf.internal import encoder
91from google.protobuf.internal import wire_format
92from google.protobuf import message
93
94
# This will overflow and thus become IEEE-754 "infinity".  We would use
# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
_POS_INF = 1e10000
_NEG_INF = -_POS_INF
# Infinity multiplied by zero has no defined value, so this evaluates to
# IEEE-754 "not a number".
_NAN = _POS_INF * 0


# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".  Several decoders below take a parameter called
# "message", which would shadow the imported module inside their bodies.
_DecodeError = message.DecodeError
105
106
def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic unsigned varint value (tag not included).

  The returned function has the signature DecodeVarint(buffer, pos) and
  returns a (value, new_pos) pair.  The decoded integer is bitwise-anded with
  `mask` (e.g. to limit it to 32 bits) and then converted with `result_type`.

  The returned decoder takes no "end" parameter: the caller is expected to
  check bounds after the fact, and can often defer that check until later.

  Args:
    mask:  Bit mask applied to the decoded value.
    result_type:  Callable used to convert the masked value (e.g. int).

  The returned decoder raises _DecodeError if the varint is still unfinished
  after ten bytes.
  """

  def DecodeVarint(buffer, pos):
    accumulated = 0
    # Each byte carries 7 payload bits; shifts 0, 7, ..., 63 allow ten bytes.
    for shift in range(0, 64, 7):
      byte = six.indexbytes(buffer, pos)
      pos += 1
      accumulated |= (byte & 0x7f) << shift
      if not (byte & 0x80):
        # High bit clear: this was the last byte of the varint.
        return (result_type(accumulated & mask), pos)
    raise _DecodeError('Too many bytes when decoding varint.')

  return DecodeVarint
132
133
def _SignedVarintDecoder(mask, result_type):
  """Like _VarintDecoder() but decodes signed (two's-complement) values.

  The decoded integer is truncated to the bits covered by `mask` and then
  interpreted as a two's-complement number of that width, so the result always
  lies in the signed range for that width.  In particular a 32-bit value that
  was not sign-extended on the wire (e.g. 0xFFFFFFFF sent as a five-byte
  varint) decodes to -1 rather than to 4294967295.

  Args:
    mask:  Bit mask of the form (1 << bits) - 1 selecting the value width.
    result_type:  Callable used to convert the final value (e.g. int).

  Returns:
    A function DecodeVarint(buffer, pos) returning a (value, new_pos) pair.
    Like the unsigned decoder it performs no "end" bounds check.  It raises
    _DecodeError if the varint is still unfinished after ten bytes.
  """

  # With mask == (1 << bits) - 1 this is 1 << (bits - 1), i.e. the sign bit.
  signbit = (mask + 1) >> 1

  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    while 1:
      b = six.indexbytes(buffer, pos)
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        # Truncate to the field width first, then fold the sign bit down so
        # that values with the top bit set become negative.
        result &= mask
        result = (result ^ signbit) - signbit
        result = result_type(result)
        return (result, pos)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint
156
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.

# 64-bit versions: the mask keeps the low 64 bits of the decoded value.
_DecodeVarint = _VarintDecoder((1 << 64) - 1, long)
_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long)

# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int)
167
168
def ReadTag(buffer, pos):
  """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.

  The tag is returned as the raw bytes that encode it, not as a decoded
  integer.  Those bytes can be used directly as a key to look up the proper
  decoder, which trades pure-Python work (decoding a varint) for work done in
  C (a hash-table lookup on a byte string).  In a low-level language decoding
  the varint would be cheaper, but not in Python.

  No bounds check against a message end is made here; reading runs until a
  byte without the continuation bit is found.
  """

  tag_start = pos
  while True:
    byte = six.indexbytes(buffer, pos)
    pos += 1
    if not (byte & 0x80):
      # Continuation bit clear: pos is now just past the tag.
      return (buffer[tag_start:pos], pos)
185
186
187# --------------------------------------------------------------------
188
189
190def _SimpleDecoder(wire_type, decode_value):
191  """Return a constructor for a decoder for fields of a particular type.
192
193  Args:
194      wire_type:  The field's wire type.
195      decode_value:  A function which decodes an individual value, e.g.
196        _DecodeVarint()
197  """
198
199  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
200    if is_packed:
201      local_DecodeVarint = _DecodeVarint
202      def DecodePackedField(buffer, pos, end, message, field_dict):
203        value = field_dict.get(key)
204        if value is None:
205          value = field_dict.setdefault(key, new_default(message))
206        (endpoint, pos) = local_DecodeVarint(buffer, pos)
207        endpoint += pos
208        if endpoint > end:
209          raise _DecodeError('Truncated message.')
210        while pos < endpoint:
211          (element, pos) = decode_value(buffer, pos)
212          value.append(element)
213        if pos > endpoint:
214          del value[-1]   # Discard corrupt value.
215          raise _DecodeError('Packed element was truncated.')
216        return pos
217      return DecodePackedField
218    elif is_repeated:
219      tag_bytes = encoder.TagBytes(field_number, wire_type)
220      tag_len = len(tag_bytes)
221      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
222        value = field_dict.get(key)
223        if value is None:
224          value = field_dict.setdefault(key, new_default(message))
225        while 1:
226          (element, new_pos) = decode_value(buffer, pos)
227          value.append(element)
228          # Predict that the next tag is another copy of the same repeated
229          # field.
230          pos = new_pos + tag_len
231          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
232            # Prediction failed.  Return.
233            if new_pos > end:
234              raise _DecodeError('Truncated message.')
235            return new_pos
236      return DecodeRepeatedField
237    else:
238      def DecodeField(buffer, pos, end, message, field_dict):
239        (field_dict[key], pos) = decode_value(buffer, pos)
240        if pos > end:
241          del field_dict[key]  # Discard corrupt value.
242          raise _DecodeError('Truncated message.')
243        return pos
244      return DecodeField
245
246  return SpecificDecoder
247
248
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Like _SimpleDecoder, but passes every value through modify_value first.

  Args:
      wire_type:  The field's wire type.
      decode_value:  Function decoding one raw value; returns (value, new_pos).
      modify_value:  Function applied to each decoded value before it is
        stored.  Usually this is ZigZagDecode.
  """

  # Delegating to _SimpleDecoder costs a little speed compared with copying
  # its code here, but not enough to make a significant difference.

  def DecodeAndModify(buffer, pos):
    raw_value, next_pos = decode_value(buffer, pos)
    return (modify_value(raw_value), next_pos)

  return _SimpleDecoder(wire_type, DecodeAndModify)
261
262
def _StructPackDecoder(wire_type, format):
  """Return a constructor for a decoder for a fixed-width field.

  Args:
      wire_type:  The field's wire type.
      format:  The format string to pass to struct.unpack().
  """

  unpack = struct.unpack
  width = struct.calcsize(format)

  # Delegating to _SimpleDecoder costs a little speed compared with copying
  # its code here, but not enough to make a significant difference.

  def InnerDecode(buffer, pos):
    # struct.error is deliberately not handled here: someone up-stack is
    # expected to catch it and convert it to _DecodeError, which saves setting
    # up an exception-handling block for every single value.
    stop = pos + width
    fields = unpack(format, buffer[pos:stop])
    return (fields[0], stop)

  return _SimpleDecoder(wire_type, InnerDecode)
286
287
def _FloatDecoder():
  """Returns a decoder for a float field.

  Non-finite 32-bit values are recognized by inspecting the raw bytes instead
  of being handed to struct.unpack, which works around a struct.unpack bug for
  such values.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    # The value is 32 bits in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 are the exponent, and bits 10-32 are the significand.
    stop = pos + 4
    raw = buffer[pos:stop]
    top_byte = raw[3:4]

    # All exponent bits set means the value is non-finite.  In Python 2.4,
    # struct.unpack would turn it into a finite 64-bit value, so such values
    # are handled by hand below.
    exponent_all_ones = top_byte in b'\x7F\xFF' and raw[2:3] >= b'\x80'
    if not exponent_all_ones:
      # struct.error is left for someone up-stack to catch and convert to
      # _DecodeError, so no exception-handling block is needed per value.
      return (local_unpack('<f', raw)[0], stop)

    # Any significand bit set means not-a-number.
    if raw[0:3] != b'\x00\x00\x80':
      return (_NAN, stop)
    # Otherwise it is an infinity; the sign bit picks which one.
    if top_byte == b'\xFF':
      return (_NEG_INF, stop)
    return (_POS_INF, stop)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
321
322
def _DoubleDecoder():
  """Returns a decoder for a double field.

  Not-a-number values are recognized by inspecting the raw bytes instead of
  being handed to struct.unpack, which works around a struct.unpack bug for
  not-a-number.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    # The value is 64 bits in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 are the exponent, and bits 13-64 are the significand.
    stop = pos + 8
    raw = buffer[pos:stop]

    # All exponent bits set plus at least one significand bit set means
    # not-a-number.  In Python 2.4, struct.unpack would report inf or -inf for
    # it, so it is handled by hand.
    exponent_all_ones = raw[7:8] in b'\x7F\xFF' and raw[6:7] >= b'\xF0'
    if exponent_all_ones and raw[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0':
      return (_NAN, stop)

    # struct.error is left for someone up-stack to catch and convert to
    # _DecodeError, so no exception-handling block is needed per value.
    return (local_unpack('<d', raw)[0], stop)

  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
351
352
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for an enum field.

  Every value is read as a signed 32-bit varint.  A number present in
  key.enum_type.values_by_number is stored in the field.  Any other number is
  kept out of the field; instead its raw varint bytes are appended to
  message._unknown_fields, paired with this field's varint tag.

  Args:
    field_number:  The field number of the field to decode.
    is_repeated:   Whether the field is repeated.
    is_packed:     Whether the field is packed.
    key:           The key used to look the field up in field_dict.  It must
                   expose an enum_type attribute.
    new_default:   Called with the message to create the container for a
                   repeated or packed field that does not exist yet.

  Returns:
    A decoder with the signature Decode(buffer, pos, end, message,
    field_dict) that returns the new buffer position.  It raises _DecodeError
    when a value or the packed block extends past its end.
  """
  enum_type = key.enum_type
  # The tag depends only on the field number, so it is encoded once here
  # rather than again for every unrecognized value.
  tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # A packed field is a byte length followed by back-to-back values.
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos]))
      if pos > endpoint:
        # The last element ran past the declared length.  Undo whichever
        # append it caused before reporting the error.
        if element in enum_type.values_by_number:
          del value[-1]   # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      if pos > end:
        raise _DecodeError('Truncated message.')
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos]))
      return pos
    return DecodeField
427
428
429# --------------------------------------------------------------------
430
431
# Signed integers: decoded as two's-complement varints.
Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)

# Unsigned integers: decoded as plain varints.
UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)

# sint32/sint64: decoded as plain varints, then passed through ZigZagDecode.
SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
# float and double use dedicated decoders that special-case non-finite values.
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# bool: the decoded varint is converted with bool(), so any non-zero value
# becomes True.
BoolDecoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
459
460
def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a string field.

  Each value is a varint byte length followed by that many bytes, which are
  decoded as UTF-8.  Packed string fields are not supported (asserted).

  The returned decoder raises _DecodeError if a value extends past `end`.  A
  UnicodeDecodeError from invalid UTF-8 is re-raised after the field's full
  name has been added to its reason.
  """

  read_length = _DecodeVarint
  to_text = six.text_type

  def _ConvertToUnicode(byte_str):
    try:
      return to_text(byte_str, 'utf-8')
    except UnicodeDecodeError as e:
      # Say which field held the bad bytes, then let the error propagate.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      length, start = read_length(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = _ConvertToUnicode(buffer[start:stop])
      return stop

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    strings = field_dict.get(key)
    if strings is None:
      strings = field_dict.setdefault(key, new_default(message))
    while True:
      length, start = read_length(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      strings.append(_ConvertToUnicode(buffer[start:stop]))
      # Optimistically assume the next tag belongs to this same field; if so,
      # continue decoding directly after it.
      pos = stop + tag_len
      if stop == end or buffer[stop:pos] != tag_bytes:
        # The guess was wrong (or the message is exhausted).
        return stop

  return DecodeRepeatedField
505
506
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a bytes field.

  Each value is a varint byte length followed by that many bytes, which are
  stored as a slice of the buffer without conversion.  Packed bytes fields
  are not supported (asserted).

  The returned decoder raises _DecodeError if a value extends past `end`.
  """

  read_length = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      length, start = read_length(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = buffer[start:stop]
      return stop

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    chunks = field_dict.get(key)
    if chunks is None:
      chunks = field_dict.setdefault(key, new_default(message))
    while True:
      length, start = read_length(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      chunks.append(buffer[start:stop])
      # Optimistically assume the next tag belongs to this same field; if so,
      # continue decoding directly after it.
      pos = stop + tag_len
      if stop == end or buffer[stop:pos] != tag_bytes:
        # The guess was wrong (or the message is exhausted).
        return stop

  return DecodeRepeatedField
542
543
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group has no length prefix.  The sub-message is parsed with
  _InternalParse, and the position it returns must be followed immediately by
  this field's end-group tag.  Packed group fields are not supported
  (asserted).

  Args:
    field_number:  The field number of the group.
    is_repeated:   Whether the field is repeated.
    is_packed:     Must be false.
    key:           The key used to look the field up in field_dict.
    new_default:   Called with the message to create the field's value when
                   it does not exist yet.

  Returns:
    A decoder with the signature Decode(buffer, pos, end, message,
    field_dict) that returns the position just past the end-group tag.  It
    raises _DecodeError if the end tag is missing or lies past `end`.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Look the container up once; it is reused for every element decoded
      # by this call.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos+end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos+end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField
589
590
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each value is a varint byte length followed by the encoded sub-message,
  which is parsed with _InternalParse over exactly that byte range.  Packed
  message fields are not supported (asserted).

  The returned decoder raises _DecodeError if the sub-message extends past
  `end`, or if parsing stops before the end of the sub-message's byte range.
  """

  read_length = _DecodeVarint

  assert not is_packed
  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      submessage = field_dict.get(key)
      if submessage is None:
        submessage = field_dict.setdefault(key, new_default(message))
      # Length prefix first, then the sub-message body.
      length, start = read_length(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated message.')
      if submessage._InternalParse(buffer, start, stop) != stop:
        # _InternalParse only returns early when it hits an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      return stop

    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    submessages = field_dict.get(key)
    if submessages is None:
      submessages = field_dict.setdefault(key, new_default(message))
    while True:
      # Length prefix first, then the sub-message body.
      length, start = read_length(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated message.')
      if submessages.add()._InternalParse(buffer, start, stop) != stop:
        # _InternalParse only returns early when it hits an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      # Optimistically assume the next tag belongs to this same field; if so,
      # continue decoding directly after it.
      pos = stop + tag_len
      if stop == end or buffer[stop:pos] != tag_bytes:
        # The guess was wrong (or the message is exhausted).
        return stop

  return DecodeRepeatedField
639
640
641# --------------------------------------------------------------------
642
# Encoded start-group tag for field number 1.  MessageSetItemDecoder pairs it
# with the raw bytes of any item it has to keep as an unknown field.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
644
def MessageSetItemDecoder(extensions_by_number):
  """Returns a decoder for a MessageSet item.

  The parameter is the _extensions_by_number map for the message class.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  The returned decoder has the usual signature Decode(buffer, pos, end,
  message, field_dict) and returns the position just past the item's end-group
  tag.  If type_id maps to a known extension, the item's message bytes are
  parsed into that extension's value in field_dict.  Otherwise the raw item
  bytes are appended to message._unknown_fields together with
  MESSAGE_SET_ITEM_TAG.

  The decoder raises _DecodeError if the item runs past `end`, lacks a type_id
  or a message, contains a nested field with a stray end-group tag, or if
  parsing the extension message stops early.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  # Bound to locals so the closure below avoids global lookups.
  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    message_set_item_start = pos
    # -1 marks "not seen yet" for each of the required parts.
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only record where the message bytes are; they are parsed once the
        # type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    # Bounds are checked once, after the whole item has been scanned.
    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = extensions_by_number.get(type_id)
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        value = field_dict.setdefault(
            extension, extension.message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      # Unknown extension: keep the raw item bytes as an unknown field.
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append((MESSAGE_SET_ITEM_TAG,
                                      buffer[message_set_item_start:pos]))

    return pos

  return DecodeItem
716
717# --------------------------------------------------------------------
718
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Each map entry is a length-delimited sub-message exposing `key` and `value`
  attributes.  For every entry the decoder parses the sub-message and then
  either merges its value into the map slot for that key (when is_message_map
  is true) or assigns it.

  Args:
    field_descriptor:  Descriptor of the map field.  It is also the key used
                       to look the field up in field_dict.
    new_default:       Called with the message to create the map container
                       when it does not exist yet.
    is_message_map:    Whether entry values are merged with MergeFrom rather
                       than assigned.
  """

  key = field_descriptor
  entry_tag = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  entry_tag_len = len(entry_tag)
  read_length = _DecodeVarint
  # _concrete_class is read lazily inside DecodeMap because it might not be
  # initialized yet at this point.
  message_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    # One scratch entry message is reused for every entry decoded by this call.
    entry = message_type._concrete_class()
    mapping = field_dict.get(key)
    if mapping is None:
      mapping = field_dict.setdefault(key, new_default(message))
    while True:
      # Length prefix first, then the entry body.
      length, start = read_length(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated message.')
      entry.Clear()
      if entry._InternalParse(buffer, start, stop) != stop:
        # _InternalParse only returns early when it hits an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        mapping[entry.key].MergeFrom(entry.value)
      else:
        mapping[entry.key] = entry.value

      # Optimistically assume the next tag belongs to this same field; if so,
      # continue decoding directly after it.
      pos = stop + entry_tag_len
      if stop == end or buffer[stop:pos] != entry_tag:
        # The guess was wrong (or the message is exhausted).
        return stop

  return DecodeMap
760
761# --------------------------------------------------------------------
762# Optimization is not as heavy here because calls to SkipField() are rare,
763# except for handling end-group tags.
764
765def _SkipVarint(buffer, pos, end):
766  """Skip a varint value.  Returns the new position."""
767  # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
768  # With this code, ord(b'') raises TypeError.  Both are handled in
769  # python_message.py to generate a 'Truncated message' error.
770  while ord(buffer[pos:pos+1]) & 0x80:
771    pos += 1
772  pos += 1
773  if pos > end:
774    raise _DecodeError('Truncated message.')
775  return pos
776
777def _SkipFixed64(buffer, pos, end):
778  """Skip a fixed64 value.  Returns the new position."""
779
780  pos += 8
781  if pos > end:
782    raise _DecodeError('Truncated message.')
783  return pos
784
def _SkipLengthDelimited(buffer, pos, end):
  """Skip a length-delimited value.  Returns the new position.

  Reads the varint length prefix and steps over that many bytes.  Raises
  _DecodeError if the payload ends past `end`.
  """

  length, payload_start = _DecodeVarint(buffer, pos)
  payload_end = payload_start + length
  if payload_end <= end:
    return payload_end
  raise _DecodeError('Truncated message.')
793
def _SkipGroup(buffer, pos, end):
  """Skip sub-group.  Returns the new position.

  Fields are read and skipped one after another until SkipField reports an
  end-group tag by returning -1.  The position just past that tag is
  returned.
  """

  tag_bytes, pos = ReadTag(buffer, pos)
  after_field = SkipField(buffer, pos, end, tag_bytes)
  while after_field != -1:
    tag_bytes, pos = ReadTag(buffer, after_field)
    after_field = SkipField(buffer, pos, end, tag_bytes)
  return pos
803
def _EndGroup(buffer, pos, end):
  """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.

  The arguments are accepted only so that this function can be called like the
  other skippers in WIRETYPE_TO_SKIPPER; none of them is used.
  """

  return -1
808
809def _SkipFixed32(buffer, pos, end):
810  """Skip a fixed32 value.  Returns the new position."""
811
812  pos += 4
813  if pos > end:
814    raise _DecodeError('Truncated message.')
815  return pos
816
def _RaiseInvalidWireType(buffer, pos, end):
  """Skip function for unknown wire types.  Raises an exception.

  The arguments are accepted only so that this function can be called like the
  other skippers in WIRETYPE_TO_SKIPPER; none of them is used.

  Raises:
    _DecodeError: always.
  """

  raise _DecodeError('Tag had invalid wire type.')
821
def _FieldSkipper():
  """Constructs the SkipField function."""

  # Indexed by wire type: the entry at position N skips a value of type N.
  skippers = (
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
  )

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
        The new position (after the tag value), or -1 if the tag is an end-group
        tag (in which case the calling loop should break).
    """

    # Varints are little-endian, so the wire type always sits in the first
    # byte of the tag.
    first_byte = ord(tag_bytes[0:1])
    skip = skippers[first_byte & type_mask]
    return skip(buffer, pos, end)

  return SkipField
853
# Built once at import time.  _SkipGroup and MessageSetItemDecoder in this
# module look this name up when they run.
SkipField = _FieldSkipper()
855