1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
"""Code for encoding protocol message primitives.

Contains the logic for encoding every logical protocol field type
into one of the physical wire types (varint, fixed32, fixed64,
length-delimited, and start/end group).

This code is designed to push the Python interpreter's performance to the
limits.

The basic idea is that at startup time, for every field (i.e. every
FieldDescriptor) we construct two functions:  a "sizer" and an "encoder".  The
sizer takes a value of this field's type and computes its byte size.  The
encoder takes a writer function, a value, and a `deterministic` flag.  It
encodes the value into byte strings and invokes the writer function to write
those strings.  Typically the writer function is the write() method of a
BytesIO.  The `deterministic` flag is forwarded to nested encoders; map
encoders use it to emit their entries in sorted key order.

We try to do as much work as possible when constructing the encoder and the
sizer rather than when calling them.  In particular:
* We copy any needed global functions to local variables, so that we do not need
  to do costly global table lookups at runtime.
* Similarly, we try to do any attribute lookups at startup time if possible.
* Every field's tag is encoded to bytes at startup, since it can't change at
  runtime.
* Whatever component of the field size we can compute at startup, we do.
* We *avoid* sharing code if doing so would make the code slower and not sharing
  does not burden us too much.  For example, encoders for repeated fields do
  not just call the encoders for singular fields in a loop because this would
  add an extra function call overhead for every loop iteration; instead, we
  manually inline the single-value encoder into the loop.
* If a Python function lacks a return statement, Python actually generates
  instructions to pop the result of the last statement off the stack, push
  None onto the stack, and then return that.  If we really don't care what
  value is returned, then we can save two instructions by returning the
  result of the last statement.  It looks funny but it helps.
* We assume that type and bounds checking has happened at a higher level.
"""

__author__ = 'kenton@google.com (Kenton Varda)'
68
69import struct
70
71import six
72
73from google.protobuf.internal import wire_format
74
75
76# This will overflow and thus become IEEE-754 "infinity".  We would use
77# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
78_POS_INF = 1e10000
79_NEG_INF = -_POS_INF
80
81
82def _VarintSize(value):
83  """Compute the size of a varint value."""
84  if value <= 0x7f: return 1
85  if value <= 0x3fff: return 2
86  if value <= 0x1fffff: return 3
87  if value <= 0xfffffff: return 4
88  if value <= 0x7ffffffff: return 5
89  if value <= 0x3ffffffffff: return 6
90  if value <= 0x1ffffffffffff: return 7
91  if value <= 0xffffffffffffff: return 8
92  if value <= 0x7fffffffffffffff: return 9
93  return 10
94
95
96def _SignedVarintSize(value):
97  """Compute the size of a signed varint value."""
98  if value < 0: return 10
99  if value <= 0x7f: return 1
100  if value <= 0x3fff: return 2
101  if value <= 0x1fffff: return 3
102  if value <= 0xfffffff: return 4
103  if value <= 0x7ffffffff: return 5
104  if value <= 0x3ffffffffff: return 6
105  if value <= 0x1ffffffffffff: return 7
106  if value <= 0xffffffffffffff: return 8
107  if value <= 0x7fffffffffffffff: return 9
108  return 10
109
110
def _TagSize(field_number):
  """Return how many bytes the tag for field_number occupies on the wire.

  The wire type is packed into the tag alongside the field number but does
  not change the tag's varint length, so wire type 0 stands in for all.
  """
  packed_tag = wire_format.PackTag(field_number, 0)
  return _VarintSize(packed_tag)
116
117
118# --------------------------------------------------------------------
119# In this section we define some generic sizers.  Each of these functions
120# takes parameters specific to a particular field type, e.g. int32 or fixed64.
121# It returns another function which in turn takes parameters specific to a
122# particular field, e.g. the field number and whether it is repeated or packed.
123# Look at the next section to see how these are used.
124
125
def _SimpleSizer(compute_value_size):
  """Build a sizer constructor for variable-width scalar fields.

  Args:
    compute_value_size: Callable returning the encoded size in bytes of one
      value, without its tag.  Typically _VarintSize or _SignedVarintSize.

  Returns:
    A function (field_number, is_repeated, is_packed) returning a sizer that
    maps the field's value (a sequence when repeated or packed) to its total
    encoded size, tags included.
  """

  def SpecificSizer(field_number, is_repeated, is_packed):
    tag_size = _TagSize(field_number)

    if is_packed:
      # One tag, a varint length prefix, then all payloads back to back.
      local_VarintSize = _VarintSize
      def PackedFieldSize(value):
        payload = sum(map(compute_value_size, value))
        return tag_size + local_VarintSize(payload) + payload
      return PackedFieldSize

    if is_repeated:
      # Every element carries its own tag.
      def RepeatedFieldSize(value):
        return tag_size * len(value) + sum(map(compute_value_size, value))
      return RepeatedFieldSize

    def FieldSize(value):
      return tag_size + compute_value_size(value)
    return FieldSize

  return SpecificSizer
153
154
def _ModifiedSizer(compute_value_size, modify_value):
  """Build a sizer constructor whose values are transformed before sizing.

  Identical to _SimpleSizer except that modify_value (typically
  wire_format.ZigZagEncode) is applied to every value before it is handed to
  compute_value_size.

  Args:
    compute_value_size: Callable returning the encoded size of one
      (already modified) value, without its tag.
    modify_value: Callable mapping a field value to the value actually
      encoded on the wire.

  Returns:
    A function (field_number, is_repeated, is_packed) returning a sizer.
  """

  def SpecificSizer(field_number, is_repeated, is_packed):
    tag_size = _TagSize(field_number)

    if is_packed:
      # One tag, a varint length prefix, then all payloads back to back.
      local_VarintSize = _VarintSize
      def PackedFieldSize(value):
        payload = sum(map(compute_value_size, map(modify_value, value)))
        return tag_size + local_VarintSize(payload) + payload
      return PackedFieldSize

    if is_repeated:
      # Every element carries its own tag.
      def RepeatedFieldSize(value):
        return (tag_size * len(value) +
                sum(map(compute_value_size, map(modify_value, value))))
      return RepeatedFieldSize

    def FieldSize(value):
      return tag_size + compute_value_size(modify_value(value))
    return FieldSize

  return SpecificSizer
182
183
def _FixedSizer(value_size):
  """Build a sizer constructor for fixed-width fields.

  Args:
    value_size: Encoded size in bytes of a single value (e.g. 4 for fixed32).

  Returns:
    A function (field_number, is_repeated, is_packed) returning a sizer.
    Because every value has the same width, the sizers only look at the
    element count and never at the values themselves.
  """

  def SpecificSizer(field_number, is_repeated, is_packed):
    tag_size = _TagSize(field_number)

    if is_packed:
      local_VarintSize = _VarintSize
      def PackedFieldSize(value):
        payload = value_size * len(value)
        return tag_size + local_VarintSize(payload) + payload
      return PackedFieldSize

    # Unpacked occurrences each cost one tag plus one fixed-width value.
    per_entry = tag_size + value_size
    if is_repeated:
      def RepeatedFieldSize(value):
        return per_entry * len(value)
      return RepeatedFieldSize

    def FieldSize(value):
      return per_entry
    return FieldSize

  return SpecificSizer
208
209
210# ====================================================================
211# Here we declare a sizer constructor for each field type.  Each "sizer
212# constructor" is a function that takes (field_number, is_repeated, is_packed)
213# as parameters and returns a sizer, which in turn takes a field value as
214# a parameter and returns its encoded size.
215
216
# int32/int64/enum values may be negative, so they use the signed varint size.
Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize)

UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize)

# sint32/sint64 values are ZigZag-mapped before being sized.  The matching
# encoders (SInt32Encoder/SInt64Encoder) already pass ZigZagEncode's output
# to the unsigned varint encoder, which requires it to be non-negative.  Sizing
# it with the unsigned _VarintSize keeps sizer and encoder consistent and
# skips a redundant sign test per value.
SInt32Sizer = SInt64Sizer = _ModifiedSizer(
    _VarintSize, wire_format.ZigZagEncode)

Fixed32Sizer = SFixed32Sizer = FloatSizer = _FixedSizer(4)
Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8)

BoolSizer = _FixedSizer(1)
228
229
def StringSizer(field_number, is_repeated, is_packed):
  """Return a sizer for a string field.

  Each string is measured by its UTF-8 encoded length and costs a tag, a
  varint length prefix, and the encoded bytes.  Packing is not supported.
  """

  tag_size = _TagSize(field_number)
  local_VarintSize = _VarintSize
  local_len = len
  assert not is_packed

  if not is_repeated:
    def FieldSize(value):
      byte_count = local_len(value.encode('utf-8'))
      return tag_size + local_VarintSize(byte_count) + byte_count
    return FieldSize

  def RepeatedFieldSize(value):
    total = tag_size * local_len(value)
    for text in value:
      byte_count = local_len(text.encode('utf-8'))
      total += local_VarintSize(byte_count) + byte_count
    return total
  return RepeatedFieldSize
250
251
def BytesSizer(field_number, is_repeated, is_packed):
  """Return a sizer for a bytes field.

  Each value costs a tag, a varint length prefix, and its raw bytes.
  Packing is not supported.
  """

  tag_size = _TagSize(field_number)
  local_VarintSize = _VarintSize
  local_len = len
  assert not is_packed

  if not is_repeated:
    def FieldSize(value):
      byte_count = local_len(value)
      return tag_size + local_VarintSize(byte_count) + byte_count
    return FieldSize

  def RepeatedFieldSize(value):
    total = tag_size * local_len(value)
    for blob in value:
      byte_count = local_len(blob)
      total += local_VarintSize(byte_count) + byte_count
    return total
  return RepeatedFieldSize
272
273
def GroupSizer(field_number, is_repeated, is_packed):
  """Return a sizer for a group field.

  A group is bracketed by a start tag and an end tag carrying the same field
  number, so each occurrence costs two tags plus the group's ByteSize().
  Packing is not supported.
  """

  both_tags = 2 * _TagSize(field_number)
  assert not is_packed

  if not is_repeated:
    def FieldSize(value):
      return both_tags + value.ByteSize()
    return FieldSize

  def RepeatedFieldSize(value):
    total = both_tags * len(value)
    for group in value:
      total += group.ByteSize()
    return total
  return RepeatedFieldSize
290
291
def MessageSizer(field_number, is_repeated, is_packed):
  """Return a sizer for an embedded message field.

  Each message costs its tag, a varint length prefix, and its ByteSize().
  Packing is not supported.
  """

  tag_size = _TagSize(field_number)
  local_VarintSize = _VarintSize
  assert not is_packed

  if not is_repeated:
    def FieldSize(value):
      body_size = value.ByteSize()
      return tag_size + local_VarintSize(body_size) + body_size
    return FieldSize

  def RepeatedFieldSize(value):
    total = tag_size * len(value)
    for message in value:
      body_size = message.ByteSize()
      total += local_VarintSize(body_size) + body_size
    return total
  return RepeatedFieldSize
311
312
313# --------------------------------------------------------------------
314# MessageSet is special: it needs custom logic to compute its size properly.
315
316
def MessageSetItemSizer(field_number):
  """Return a sizer for an extension of MessageSet.

  MessageSet is declared as:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  Each extension value becomes one Item whose type_id is field_number.
  """
  # Everything except the message's length prefix and body is fixed per
  # extension: the Item start/end tags, the type_id tag and its varint value,
  # and the tag of the message field.
  group_tags = 2 * _TagSize(1)
  type_id_size = _TagSize(2) + _VarintSize(field_number)
  message_tag = _TagSize(3)
  static_size = group_tags + type_id_size + message_tag
  local_VarintSize = _VarintSize

  def FieldSize(value):
    body_size = value.ByteSize()
    return static_size + local_VarintSize(body_size) + body_size

  return FieldSize
337
338
339# --------------------------------------------------------------------
340# Map is special: it needs custom logic to compute its size properly.
341
342
def MapSizer(field_descriptor, is_message_map):
  """Return a sizer for a map field.

  The map is sized as a repeated message field: for every key a temporary
  entry message (key, value) is built and sized with MessageSizer.

  Args:
    field_descriptor: Descriptor of the map field; its message_type is the
      map entry type.
    is_message_map: True when the map's values are messages.  Their own
      ByteSize() is then also called so their size state is updated.
  """

  # message_type._concrete_class is looked up lazily inside FieldSize because
  # it may not have been initialized yet when this sizer is built.
  entry_type = field_descriptor.message_type
  entry_sizer = MessageSizer(field_descriptor.number, False, False)

  def FieldSize(map_value):
    total = 0
    for key in map_value:
      entry_value = map_value[key]
      # Building entry messages here only to discard them duplicates work done
      # again during encoding, but avoiding it would require duplicating a lot
      # of code under the current design.
      entry_msg = entry_type._concrete_class(key=key, value=entry_value)
      total += entry_sizer(entry_msg)
      if is_message_map:
        entry_value.ByteSize()
    return total

  return FieldSize
367
368# ====================================================================
369# Encoders!
370
371
def _VarintEncoder():
  """Return an encoder for a basic (unsigned) varint value; no tag included.

  The returned EncodeVarint(write, value, unused_deterministic=None) calls
  write() once per output byte, lowest 7-bit group first, with the
  continuation bit 0x80 set on every byte except the last.  It returns the
  result of the final write() call.

  Raises (from the returned function):
    ValueError: if value is negative.  Right-shifting a negative int never
      reaches zero (-1 >> 7 == -1), so without this check the loop would keep
      writing bytes forever.
  """
  local_int2byte = six.int2byte

  def EncodeVarint(write, value, unused_deterministic=None):
    if value < 0:
      raise ValueError('Value out of range for unsigned varint: %d' % value)
    # Emit full 7-bit groups with the continuation bit while more remain.
    while value > 0x7f:
      write(local_int2byte(0x80 | (value & 0x7f)))
      value >>= 7
    return write(local_int2byte(value))

  return EncodeVarint
385
386
def _SignedVarintEncoder():
  """Return an encoder for a basic signed varint value (does not include tag).

  Negative values are first mapped to value + 2**64 (64-bit two's
  complement) and then written like unsigned varints.  Like the unsigned
  encoder, write() is called once per output byte and the result of the
  final call is returned.

  Raises (from the returned function):
    ValueError: if value is still negative after adding 2**64 (i.e. below
      -2**64).  Such a value would otherwise never shift down to zero and the
      loop would write bytes forever.
  """
  local_int2byte = six.int2byte

  def EncodeSignedVarint(write, value, unused_deterministic=None):
    if value < 0:
      value += (1 << 64)
      if value < 0:
        raise ValueError('Value out of range for signed varint: %d' %
                         (value - (1 << 64)))
    # Emit full 7-bit groups with the continuation bit while more remain.
    while value > 0x7f:
      write(local_int2byte(0x80 | (value & 0x7f)))
      value >>= 7
    return write(local_int2byte(value))

  return EncodeSignedVarint
403
404
# Shared encoder instances used by every field encoder defined below.
_EncodeVarint, _EncodeSignedVarint = _VarintEncoder(), _SignedVarintEncoder()
407
408
def _VarintBytes(value):
  """Return the varint encoding of value as a single bytes object.

  Only used while encoders are being constructed (e.g. to precompute tag
  bytes), so simplicity matters more than speed here.
  """
  chunks = []
  collect = chunks.append
  _EncodeVarint(collect, value, True)
  return b"".join(chunks)
416
417
def TagBytes(field_number, wire_type):
  """Return the encoded tag bytes for (field_number, wire_type).

  Called once per field while building encoders, never per value.
  """
  packed_tag = wire_format.PackTag(field_number, wire_type)
  encoded_tag = _VarintBytes(packed_tag)
  return six.binary_type(encoded_tag)
423
424# --------------------------------------------------------------------
425# As with sizers (see above), we have a number of common encoder
426# implementations.
427
428
def _SimpleEncoder(wire_type, encode_value, compute_value_size):
  """Return a constructor for an encoder for fields of a particular type.

  Args:
      wire_type:  The field's wire type, used for the tag of unpacked values.
      encode_value:  A function which encodes an individual value, e.g.
        _EncodeVarint.
      compute_value_size:  A function which computes the size of an individual
        value, e.g. _VarintSize.  Only needed to size packed payloads.

  Returns:
    A function (field_number, is_repeated, is_packed) -> encoder, where the
    encoder takes (write, value, deterministic).
  """

  def SpecificEncoder(field_number, is_repeated, is_packed):
    if not is_packed:
      tag_bytes = TagBytes(field_number, wire_type)
      if is_repeated:
        # Unpacked repeated: a tag in front of every element.
        def EncodeRepeatedField(write, value, deterministic):
          for element in value:
            write(tag_bytes)
            encode_value(write, element, deterministic)
        return EncodeRepeatedField

      def EncodeField(write, value, deterministic):
        write(tag_bytes)
        return encode_value(write, value, deterministic)
      return EncodeField

    # Packed: one length-delimited tag, the payload size as a varint, then
    # every value back to back.
    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    local_EncodeVarint = _EncodeVarint
    def EncodePackedField(write, value, deterministic):
      write(tag_bytes)
      payload_size = sum(map(compute_value_size, value))
      local_EncodeVarint(write, payload_size, deterministic)
      for element in value:
        encode_value(write, element, deterministic)
    return EncodePackedField

  return SpecificEncoder
468
469
def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
  """Like _SimpleEncoder, but every value is transformed before encoding.

  modify_value (usually wire_format.ZigZagEncode) is applied to each value
  before it is passed to compute_value_size or encode_value.

  Returns:
    A function (field_number, is_repeated, is_packed) -> encoder, where the
    encoder takes (write, value, deterministic).
  """

  def SpecificEncoder(field_number, is_repeated, is_packed):
    if not is_packed:
      tag_bytes = TagBytes(field_number, wire_type)
      if is_repeated:
        # Unpacked repeated: a tag in front of every element.
        def EncodeRepeatedField(write, value, deterministic):
          for element in value:
            write(tag_bytes)
            encode_value(write, modify_value(element), deterministic)
        return EncodeRepeatedField

      def EncodeField(write, value, deterministic):
        write(tag_bytes)
        return encode_value(write, modify_value(value), deterministic)
      return EncodeField

    # Packed: one length-delimited tag, the payload size as a varint, then
    # every modified value back to back.
    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    local_EncodeVarint = _EncodeVarint
    def EncodePackedField(write, value, deterministic):
      write(tag_bytes)
      payload_size = sum(map(compute_value_size, map(modify_value, value)))
      local_EncodeVarint(write, payload_size, deterministic)
      for element in value:
        encode_value(write, modify_value(element), deterministic)
    return EncodePackedField

  return SpecificEncoder
502
503
def _StructPackEncoder(wire_type, format):
  """Return a constructor for an encoder for a fixed-width field.

  Args:
      wire_type:  The field's wire type, for encoding tags.
      format:  The struct format string describing one value (e.g. '<I').

  Returns:
    A function (field_number, is_repeated, is_packed) -> encoder, where the
    encoder takes (write, value, deterministic).
  """

  value_size = struct.calcsize(format)

  def SpecificEncoder(field_number, is_repeated, is_packed):
    # A precompiled Struct produces exactly what struct.pack(format, x) does,
    # without re-resolving the format for every value.
    pack = struct.Struct(format).pack

    if is_packed:
      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
      local_EncodeVarint = _EncodeVarint
      def EncodePackedField(write, value, deterministic):
        write(tag_bytes)
        local_EncodeVarint(write, value_size * len(value), deterministic)
        for element in value:
          write(pack(element))
      return EncodePackedField

    tag_bytes = TagBytes(field_number, wire_type)
    if is_repeated:
      def EncodeRepeatedField(write, value, unused_deterministic=None):
        for element in value:
          write(tag_bytes)
          write(pack(element))
      return EncodeRepeatedField

    def EncodeField(write, value, unused_deterministic=None):
      write(tag_bytes)
      return write(pack(value))
    return EncodeField

  return SpecificEncoder
540
541
def _FloatingPointEncoder(wire_type, format):
  """Return a constructor for an encoder for float fields.

  This is like _StructPackEncoder, but catches errors that may be due to
  passing non-finite floating-point values to struct.pack, and makes a
  second attempt to encode those values.

  Args:
      wire_type:  The field's wire type, for encoding tags.
      format:  The format string to pass to struct.pack().  It must describe
        a 4-byte or an 8-byte value.

  Returns:
    A function (field_number, is_repeated, is_packed) -> encoder, where the
    encoder takes (write, value, deterministic).  The deterministic flag is
    only forwarded to the varint length prefix of packed fields.

  Raises:
    ValueError: if struct.calcsize(format) is neither 4 nor 8.
  """

  value_size = struct.calcsize(format)
  if value_size == 4:
    # Fallback writer for a value that struct.pack rejected.  It must only be
    # called from inside an `except` block: for a finite value the bare
    # `raise` re-raises the exception currently being handled.
    def EncodeNonFiniteOrRaise(write, value):
      # Remember that the serialized form uses little-endian byte order.
      # The literals are the single-precision bit patterns of +inf, -inf
      # and NaN, least-significant byte first.
      if value == _POS_INF:
        write(b'\x00\x00\x80\x7F')
      elif value == _NEG_INF:
        write(b'\x00\x00\x80\xFF')
      elif value != value:           # NaN
        write(b'\x00\x00\xC0\x7F')
      else:
        raise
  elif value_size == 8:
    # Double-precision counterpart of the 4-byte fallback, with the same
    # "only call from inside an except block" requirement.
    def EncodeNonFiniteOrRaise(write, value):
      if value == _POS_INF:
        write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
      elif value == _NEG_INF:
        write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
      elif value != value:                         # NaN
        write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
      else:
        raise
  else:
    raise ValueError('Can\'t encode floating-point values that are '
                     '%d bytes long (only 4 or 8)' % value_size)

  # NOTE(review): only SystemError triggers the non-finite fallback below.
  # This appears to target older interpreters where struct.pack failed on
  # inf/NaN; any other exception type from struct.pack propagates unchanged.
  # Confirm whether the fallback is still reachable on supported versions.
  def SpecificEncoder(field_number, is_repeated, is_packed):
    local_struct_pack = struct.pack
    if is_packed:
      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
      local_EncodeVarint = _EncodeVarint
      def EncodePackedField(write, value, deterministic):
        write(tag_bytes)
        # Every element has the same width, so the payload length is known
        # up front.
        local_EncodeVarint(write, len(value) * value_size, deterministic)
        for element in value:
          # This try/except block is going to be faster than any code that
          # we could write to check whether element is finite.
          try:
            write(local_struct_pack(format, element))
          except SystemError:
            EncodeNonFiniteOrRaise(write, element)
      return EncodePackedField
    elif is_repeated:
      tag_bytes = TagBytes(field_number, wire_type)
      def EncodeRepeatedField(write, value, unused_deterministic=None):
        for element in value:
          write(tag_bytes)
          try:
            write(local_struct_pack(format, element))
          except SystemError:
            EncodeNonFiniteOrRaise(write, element)
      return EncodeRepeatedField
    else:
      tag_bytes = TagBytes(field_number, wire_type)
      # Unlike the other singular encoders, this one returns None rather than
      # the result of the last write().
      def EncodeField(write, value, unused_deterministic=None):
        write(tag_bytes)
        try:
          write(local_struct_pack(format, value))
        except SystemError:
          EncodeNonFiniteOrRaise(write, value)
      return EncodeField

  return SpecificEncoder
617
618
619# ====================================================================
620# Here we declare an encoder constructor for each field type.  These work
621# very similarly to sizer constructors, described earlier.
622
623
# Varint-encoded integer types.  int32/int64/enum values may be negative and
# go through the signed varint encoder; uint32/uint64 use the unsigned one;
# sint32/sint64 values are passed through wire_format.ZigZagEncode first and
# then written as unsigned varints.
Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)

UInt32Encoder = UInt64Encoder = _SimpleEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)

SInt32Encoder = SInt64Encoder = _ModifiedEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize,
    wire_format.ZigZagEncode)

# Fixed-width types.  The '<' prefix selects little-endian byte order with
# standard sizes, so these formats pack to the same width on every platform
# (without it, sizes would follow the C compiler's native types).
Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f')
DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d')
644
645
def BoolEncoder(field_number, is_repeated, is_packed):
  """Return an encoder for a boolean field.

  Every value occupies exactly one payload byte: 0x01 when it is truthy and
  0x00 otherwise.  Consequently a packed field's length prefix is simply the
  number of elements.
  """

  false_byte = b'\x00'
  true_byte = b'\x01'

  if not is_packed:
    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    if not is_repeated:
      def EncodeField(write, value, unused_deterministic=None):
        write(tag_bytes)
        return write(true_byte if value else false_byte)
      return EncodeField

    def EncodeRepeatedField(write, value, unused_deterministic=None):
      for element in value:
        write(tag_bytes)
        write(true_byte if element else false_byte)
    return EncodeRepeatedField

  tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  def EncodePackedField(write, value, deterministic):
    write(tag_bytes)
    local_EncodeVarint(write, len(value), deterministic)
    for element in value:
      write(true_byte if element else false_byte)
  return EncodePackedField
681
682
def StringEncoder(field_number, is_repeated, is_packed):
  """Return an encoder for a string field.

  Each string is UTF-8 encoded (before anything is written for it) and then
  emitted as tag, varint byte length, and the encoded bytes.  Packing is not
  supported.
  """

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  local_len = len
  assert not is_packed

  if not is_repeated:
    def EncodeField(write, value, deterministic):
      utf8 = value.encode('utf-8')
      write(tag)
      local_EncodeVarint(write, local_len(utf8), deterministic)
      return write(utf8)
    return EncodeField

  def EncodeRepeatedField(write, value, deterministic):
    for text in value:
      utf8 = text.encode('utf-8')
      write(tag)
      local_EncodeVarint(write, local_len(utf8), deterministic)
      write(utf8)
  return EncodeRepeatedField
705
706
def BytesEncoder(field_number, is_repeated, is_packed):
  """Return an encoder for a bytes field.

  Each value is emitted as tag, varint byte length, then the raw bytes.
  Packing is not supported.
  """

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  local_len = len
  assert not is_packed

  if not is_repeated:
    def EncodeField(write, value, deterministic):
      write(tag)
      local_EncodeVarint(write, local_len(value), deterministic)
      return write(value)
    return EncodeField

  def EncodeRepeatedField(write, value, deterministic):
    for blob in value:
      write(tag)
      local_EncodeVarint(write, local_len(blob), deterministic)
      write(blob)
  return EncodeRepeatedField
727
728
def GroupEncoder(field_number, is_repeated, is_packed):
  """Return an encoder for a group field.

  Each group is written between a start-group tag and an end-group tag that
  share the field number; its contents come from _InternalSerialize.
  Packing is not supported.
  """

  start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP)
  end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
  assert not is_packed

  if not is_repeated:
    def EncodeField(write, value, deterministic):
      write(start_tag)
      value._InternalSerialize(write, deterministic)
      return write(end_tag)
    return EncodeField

  def EncodeRepeatedField(write, value, deterministic):
    for group in value:
      write(start_tag)
      group._InternalSerialize(write, deterministic)
      write(end_tag)
  return EncodeRepeatedField
748
749
def MessageEncoder(field_number, is_repeated, is_packed):
  """Return an encoder for an embedded message field.

  Each message is written as tag, varint ByteSize(), then the output of its
  _InternalSerialize.  Packing is not supported.
  """

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  assert not is_packed

  if not is_repeated:
    def EncodeField(write, value, deterministic):
      write(tag)
      local_EncodeVarint(write, value.ByteSize(), deterministic)
      return value._InternalSerialize(write, deterministic)
    return EncodeField

  def EncodeRepeatedField(write, value, deterministic):
    for message in value:
      write(tag)
      local_EncodeVarint(write, message.ByteSize(), deterministic)
      message._InternalSerialize(write, deterministic)
  return EncodeRepeatedField
769
770
771# --------------------------------------------------------------------
772# As before, MessageSet is special.
773
774
def MessageSetItemEncoder(field_number):
  """Return an encoder for an extension of MessageSet.

  MessageSet is declared as:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }
  Each extension value is written as one Item group whose type_id is
  field_number and whose message field holds the serialized value.
  """
  # Everything before the message's length prefix is constant per extension:
  # the Item start tag, the type_id tag and value, and the message field tag.
  start_bytes = (TagBytes(1, wire_format.WIRETYPE_START_GROUP) +
                 TagBytes(2, wire_format.WIRETYPE_VARINT) +
                 _VarintBytes(field_number) +
                 TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED))
  end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
  local_EncodeVarint = _EncodeVarint

  def EncodeField(write, value, deterministic):
    write(start_bytes)
    local_EncodeVarint(write, value.ByteSize(), deterministic)
    value._InternalSerialize(write, deterministic)
    return write(end_bytes)

  return EncodeField
801
802
803# --------------------------------------------------------------------
804# As before, Map is special.
805
806
def MapEncoder(field_descriptor):
  """Return an encoder for a map field.

  Maps always have a wire format like this:
    message MapEntry {
      key_type key = 1;
      value_type value = 2;
    }
    repeated MapEntry map = N;

  A temporary entry message is built per key and encoded as an embedded
  message.  When deterministic is true, entries are written in sorted key
  order; otherwise they follow the map's own iteration order.
  """
  # message_type._concrete_class is looked up lazily inside EncodeField
  # because it may not have been initialized yet when this encoder is built.
  entry_type = field_descriptor.message_type
  encode_entry = MessageEncoder(field_descriptor.number, False, False)

  def EncodeField(write, value, deterministic):
    if deterministic:
      ordered_keys = sorted(value.keys())
    else:
      ordered_keys = value
    for key in ordered_keys:
      entry_msg = entry_type._concrete_class(key=key, value=value[key])
      encode_entry(write, entry_msg, deterministic)

  return EncodeField
829