1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Protocol buffer support for message types.
19
20For more details about protocol buffer encoding and decoding please see:
21
22  http://code.google.com/apis/protocolbuffers/docs/encoding.html
23
Exceptions Raised:
  messages.DecodeError: Raised by decode_message when a decode error occurs
    from incorrect protobuf format.

Public Functions:
  encode_message: Encodes a message into a protocol buffer string.
  decode_message: Decodes a protocol buffer string into a message.
30"""
31import six
32
33__author__ = 'rafek@google.com (Rafe Kaplan)'
34
35
36import array
37
38from . import message_types
39from . import messages
40from . import util
41from .google_imports import ProtocolBuffer
42
43
# Names exported by a star-import of this module.  The private _Encoder and
# _Decoder helpers and the lookup tables are intentionally left out.
__all__ = ['ALTERNATIVE_CONTENT_TYPES',
           'CONTENT_TYPE',
           'encode_message',
           'decode_message',
          ]
49
# Primary MIME type associated with this protocol.
# NOTE(review): presumably consumed by the RPC transport layer to label and
# recognize protocol buffer payloads -- the usage is outside this module;
# confirm against callers.
CONTENT_TYPE = 'application/octet-stream'

# Additional MIME types treated as equivalent to CONTENT_TYPE.
ALTERNATIVE_CONTENT_TYPES = [
    'application/x-google-protobuf',
]
53
54
class _Encoder(ProtocolBuffer.Encoder):
  """Protocol buffer encoder with the extra methods this module needs.

  The stock protocol buffer encoder lacks some of the methods required to
  encode every field variant.  This subclass supplies them.
  """

  # TODO(rafek): Implement the missing encoding types.
  def no_encoding(self, value):
    """Placeholder for variants that have no encoder.

    Args:
      value: Value to encode.

    Raises:
      NotImplementedError at all times.
    """
    raise NotImplementedError()

  def encode_enum(self, value):
    """Write an enum to the stream as its numeric value.

    Args:
      value: Enum to encode.
    """
    number = value.number
    self.putVarInt32(number)

  def encode_message(self, value):
    """Write a Message to the stream as an embedded message.

    The message is serialized with the module-level encode_message function
    and the result is written as a length-prefixed string.

    Args:
      value: Message instance to encode.
    """
    serialized = encode_message(value)
    self.putPrefixedString(serialized)


  def encode_unicode_string(self, value):
    """Write a string to the stream, converting unicode text to UTF-8 first.

    Values that are not unicode text are written unchanged.

    Args:
      value: String value to encode.
    """
    payload = value.encode('utf-8') if isinstance(value, six.text_type) else value
    self.putPrefixedString(payload)
100
101
class _Decoder(ProtocolBuffer.Decoder):
  """Protocol buffer decoder with the extra methods this module needs.

  The stock protocol buffer decoder lacks some of the methods required to
  decode every field variant.  This subclass supplies them.
  """

  # TODO(rafek): Implement the missing encoding types.
  def no_decoding(self):
    """Placeholder for variants that have no decoder.

    Raises:
      NotImplementedError at all times.
    """
    raise NotImplementedError()

  def decode_string(self):
    """Read the next length-prefixed string and decode it from UTF-8.

    Returns:
      Next value in stream as a unicode string.
    """
    raw = self.getPrefixedString()
    return raw.decode('UTF-8')

  def decode_boolean(self):
    """Read the next boolean and coerce it to a Python bool.

    Returns:
      Next value in stream as a boolean.
    """
    flag = self.getBoolean()
    return bool(flag)
133
134
135# Number of bits used to describe a protocol buffer bits used for the variant.
136_WIRE_TYPE_BITS = 3
137_WIRE_TYPE_MASK = 7
138
139
# Maps variant to underlying wire type.  Many variants map to same type.
# encode_message uses this to build the key written before each value, and
# decode_message uses it to check that the wire type found on the wire is the
# one expected for the field's declared variant.
_VARIANT_TO_WIRE_TYPE = {
    messages.Variant.DOUBLE: _Encoder.DOUBLE,
    messages.Variant.FLOAT: _Encoder.FLOAT,
    messages.Variant.INT64: _Encoder.NUMERIC,
    messages.Variant.UINT64: _Encoder.NUMERIC,
    messages.Variant.INT32:  _Encoder.NUMERIC,
    messages.Variant.BOOL: _Encoder.NUMERIC,
    messages.Variant.STRING: _Encoder.STRING,
    messages.Variant.MESSAGE: _Encoder.STRING,
    messages.Variant.BYTES: _Encoder.STRING,
    messages.Variant.UINT32: _Encoder.NUMERIC,
    messages.Variant.ENUM:  _Encoder.NUMERIC,
    messages.Variant.SINT32: _Encoder.NUMERIC,
    messages.Variant.SINT64: _Encoder.NUMERIC,
}
156
157
# Maps variant to encoder method.
# The values are looked up on the _Encoder class rather than an instance, so
# encode_message calls them as method(encoder, value).  UINT32, SINT32 and
# SINT64 map to no_encoding, which always raises NotImplementedError.
_VARIANT_TO_ENCODER_MAP = {
    messages.Variant.DOUBLE: _Encoder.putDouble,
    messages.Variant.FLOAT: _Encoder.putFloat,
    messages.Variant.INT64: _Encoder.putVarInt64,
    messages.Variant.UINT64: _Encoder.putVarUint64,
    messages.Variant.INT32: _Encoder.putVarInt32,
    messages.Variant.BOOL: _Encoder.putBoolean,
    messages.Variant.STRING: _Encoder.encode_unicode_string,
    messages.Variant.MESSAGE: _Encoder.encode_message,
    messages.Variant.BYTES: _Encoder.encode_unicode_string,
    messages.Variant.UINT32: _Encoder.no_encoding,
    messages.Variant.ENUM: _Encoder.encode_enum,
    messages.Variant.SINT32: _Encoder.no_encoding,
    messages.Variant.SINT64: _Encoder.no_encoding,
}
174
175
# Basic wire format decoders.  Used for reading unknown values.
# decode_message also uses membership in this table to validate the wire type
# of every key it reads: a wire type missing from here is reported as a
# DecodeError.  Values are called as method(decoder).
_WIRE_TYPE_TO_DECODER_MAP = {
  _Encoder.NUMERIC: _Decoder.getVarInt64,
  _Encoder.DOUBLE: _Decoder.getDouble,
  _Encoder.STRING: _Decoder.getPrefixedString,
  _Encoder.FLOAT: _Decoder.getFloat,
}
183
184
# Map wire type to variant.  Used to find a variant for unknown values.
# The variant chosen here is stored alongside the raw value of an unrecognized
# field so that encode_message can later pick an encoder and wire type for it.
_WIRE_TYPE_TO_VARIANT_MAP = {
  _Encoder.NUMERIC: messages.Variant.INT64,
  _Encoder.DOUBLE: messages.Variant.DOUBLE,
  _Encoder.STRING: messages.Variant.STRING,
  _Encoder.FLOAT: messages.Variant.FLOAT,
}
192
193
# Wire type to name mapping for error messages.
# Only consulted when decode_message reports a mismatch between the expected
# and the actual wire type of a known field.
_WIRE_TYPE_NAME = {
  _Encoder.NUMERIC: 'NUMERIC',
  _Encoder.DOUBLE: 'DOUBLE',
  _Encoder.STRING: 'STRING',
  _Encoder.FLOAT: 'FLOAT',
}
201
202
# Maps variant to decoder method.
# Values are looked up on the _Decoder class and called as method(decoder).
# MESSAGE and ENUM yield raw values (a length-prefixed string and an integer
# respectively) that decode_message converts afterwards.  UINT32, SINT32 and
# SINT64 map to no_decoding, which always raises NotImplementedError.
_VARIANT_TO_DECODER_MAP = {
    messages.Variant.DOUBLE: _Decoder.getDouble,
    messages.Variant.FLOAT: _Decoder.getFloat,
    messages.Variant.INT64: _Decoder.getVarInt64,
    messages.Variant.UINT64: _Decoder.getVarUint64,
    messages.Variant.INT32:  _Decoder.getVarInt32,
    messages.Variant.BOOL: _Decoder.decode_boolean,
    messages.Variant.STRING: _Decoder.decode_string,
    messages.Variant.MESSAGE: _Decoder.getPrefixedString,
    messages.Variant.BYTES: _Decoder.getPrefixedString,
    messages.Variant.UINT32: _Decoder.no_decoding,
    messages.Variant.ENUM:  _Decoder.getVarInt32,
    messages.Variant.SINT32: _Decoder.no_decoding,
    messages.Variant.SINT64: _Decoder.no_decoding,
}
219
220
def encode_message(message):
  """Encode Message instance to protocol buffer.

  Known fields and integer-keyed unrecognized fields are written together in
  ascending tag number order.  Known fields with no assigned value are
  skipped, as are unrecognized fields whose stored variant is not a
  messages.Variant.

  Args:
    message: Message instance to encode in to protocol buffer.

  Returns:
    String encoding of Message instance in protocol buffer format.

  Raises:
    messages.ValidationError if message is not initialized.
  """
  message.check_initialized()
  encoder = _Encoder()

  # Get all fields, from the known fields we parsed and the unknown fields
  # we saved.  Note which ones were known, so we can process them differently.
  all_fields = [(field.number, field) for field in message.all_fields()]
  all_fields.extend((key, None)
                    for key in message.all_unrecognized_fields()
                    if isinstance(key, six.integer_types))
  # Order by tag number only.  Sorting the raw tuples would fall back to
  # comparing a Field with None if two entries ever shared a number, which
  # raises TypeError on Python 3.
  all_fields.sort(key=lambda entry: entry[0])
  for field_num, field in all_fields:
    if field:
      # Known field.
      value = message.get_assigned_value(field.name)
      if value is None:
        continue
      variant = field.variant
      repeated = field.repeated
    else:
      # Unrecognized field.
      value, variant = message.get_unrecognized_field_info(field_num)
      if not isinstance(variant, messages.Variant):
        continue
      repeated = isinstance(value, (list, tuple))

    # Key written before every value: tag number in the high bits, wire type
    # in the low _WIRE_TYPE_BITS bits.
    tag = ((field_num << _WIRE_TYPE_BITS) | _VARIANT_TO_WIRE_TYPE[variant])
    field_encoder = _VARIANT_TO_ENCODER_MAP[variant]
    is_message_field = isinstance(field, messages.MessageField)

    # Write value to wire.  Repeated fields emit one key/value pair per item.
    values = value if repeated else [value]
    for item in values:
      encoder.putVarInt32(tag)
      if is_message_field:
        item = field.value_to_message(item)
      field_encoder(encoder, item)

  encoded = encoder.buffer()
  # array.tostring() was removed in Python 3.9; tobytes() is its replacement.
  # Fall back to tostring() for interpreters that predate tobytes().
  if hasattr(encoded, 'tobytes'):
    return encoded.tobytes()
  return encoded.tostring()
273
274
def decode_message(message_type, encoded_message):
  """Decode protocol buffer to Message instance.

  Tag numbers that message_type does not declare are not an error: their raw
  value is stored on the message as an unrecognized field, together with a
  variant derived from the wire type, so that encode_message can write it
  back out.

  Args:
    message_type: Message type to decode data to.
    encoded_message: Encoded version of message as a byte string.

  Returns:
    Decoded instance of message_type.

  Raises:
    messages.DecodeError if an error occurs during decoding, such as an
      unknown wire type, a tag number below 1, an incompatible wire format
      for a field or an invalid enum value.
    messages.ValidationError if merged message is not initialized.
  """
  message = message_type()
  message_array = array.array('B')
  # array.fromstring() was removed in Python 3.9; frombytes() is its
  # replacement.  Fall back to fromstring() for interpreters that predate it.
  if hasattr(message_array, 'frombytes'):
    message_array.frombytes(encoded_message)
  else:
    message_array.fromstring(encoded_message)
  try:
    decoder = _Decoder(message_array, 0, len(message_array))

    while decoder.avail() > 0:
      # Decode tag and variant information.
      encoded_tag = decoder.getVarInt32()
      tag = encoded_tag >> _WIRE_TYPE_BITS
      wire_type = encoded_tag & _WIRE_TYPE_MASK
      try:
        found_wire_type_decoder = _WIRE_TYPE_TO_DECODER_MAP[wire_type]
      except KeyError:
        # Only a missing table entry means "unknown wire type"; anything else
        # must propagate instead of being masked as a decode error.
        raise messages.DecodeError('No such wire type %d' % wire_type)

      if tag < 1:
        raise messages.DecodeError('Invalid tag value %d' % tag)

      try:
        field = message.field_by_number(tag)
      except KeyError:
        # Unexpected tags are ok.
        field = None
        wire_type_decoder = found_wire_type_decoder
      else:
        expected_wire_type = _VARIANT_TO_WIRE_TYPE[field.variant]
        if expected_wire_type != wire_type:
          raise messages.DecodeError('Expected wire type %s but found %s' % (
              _WIRE_TYPE_NAME[expected_wire_type],
              _WIRE_TYPE_NAME[wire_type]))

        wire_type_decoder = _VARIANT_TO_DECODER_MAP[field.variant]

      # Decoders are stored as class attributes, so the decoder instance is
      # passed explicitly.
      value = wire_type_decoder(decoder)

      # Save unknown fields and skip additional processing.
      if not field:
        # When saving this, save it under the tag number (which should
        # be unique), and set the variant and value so we know how to
        # interpret the value later.
        variant = _WIRE_TYPE_TO_VARIANT_MAP.get(wire_type)
        if variant:
          message.set_unrecognized_field(tag, value, variant)
        continue

      # Special case Enum and Message types.
      if isinstance(field, messages.EnumField):
        try:
          value = field.type(value)
        except TypeError:
          raise messages.DecodeError('Invalid enum value %s' % value)
      elif isinstance(field, messages.MessageField):
        # Embedded messages arrive as a length-prefixed string that is itself
        # an encoded message; decode it recursively.
        value = decode_message(field.message_type, value)
        value = field.value_from_message(value)

      # Merge value in to message.
      if field.repeated:
        values = getattr(message, field.name)
        if values is None:
          setattr(message, field.name, [value])
        else:
          values.append(value)
      else:
        setattr(message, field.name, value)
  except ProtocolBuffer.ProtocolBufferDecodeError as err:
    # Surface low-level decoder failures as this package's DecodeError.
    raise messages.DecodeError('Decoding error: %s' % str(err))

  message.check_initialized()
  return message
360