1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""JSON support for message types.
19
Public classes:
  MessageJSONEncoder: JSON encoder for message objects.
  ProtoJson: JSON protocol implementation used to encode and decode messages.

Public functions:
  encode_message: Encodes a message into a JSON string.
  decode_message: Decodes a JSON string into a new message instance.
26"""
27import base64
28import binascii
29import logging
30
31import six
32
33from apitools.base.protorpclite import message_types
34from apitools.base.protorpclite import messages
35from apitools.base.protorpclite import util
36
37__all__ = [
38    'ALTERNATIVE_CONTENT_TYPES',
39    'CONTENT_TYPE',
40    'MessageJSONEncoder',
41    'encode_message',
42    'decode_message',
43    'ProtoJson',
44]
45
46
def _load_json_module():
    """Locate and return a usable json implementation.

    Candidate modules are tried in order of preference ('json' first, then
    'simplejson').  A candidate is accepted as soon as it imports
    successfully and exposes a ``JSONEncoder`` attribute.

    Returns:
      The first candidate module that passes the compatibility check.

    Raises:
      ImportError: If no candidate could be imported or none of them looked
        compatible.  The error raised is the one recorded for the first
        candidate that failed.
    """
    earliest_failure = None
    for candidate in ('json', 'simplejson'):
        try:
            loaded = __import__(candidate, {}, {}, 'json')
            if hasattr(loaded, 'JSONEncoder'):
                return loaded
            # Importable but missing the encoder class: report it and let
            # the handler below record it like any other import failure.
            problem = (
                'json library "%s" is not compatible with ProtoRPC' %
                candidate)
            logging.warning(problem)
            raise ImportError(problem)
        except ImportError as err:
            # Only the first failure is kept; it is what gets re-raised.
            earliest_failure = earliest_failure or err

    logging.error('Must use valid json library (json or simplejson)')
    raise earliest_failure  # pylint:disable=raising-bad-type


json = _load_json_module()
82
83
84# TODO: Rename this to MessageJsonEncoder.
class MessageJSONEncoder(json.JSONEncoder):
    """JSON encoder that knows how to serialize message objects.

    Subclass of JSONEncoder whose default() hook turns messages, enums and
    (on Python 3) byte strings into JSON-serializable values.
    """

    def __init__(self, protojson_protocol=None, **kwargs):
        """Constructor.

        Args:
          protojson_protocol: ProtoJson instance used to encode individual
            field values.  When falsy, ProtoJson.get_default() is used.
          **kwargs: Forwarded unchanged to the JSONEncoder constructor.
        """
        super(MessageJSONEncoder, self).__init__(**kwargs)
        if not protojson_protocol:
            protojson_protocol = ProtoJson.get_default()
        self.__protojson_protocol = protojson_protocol

    def default(self, value):
        """Convert a value the base encoder cannot handle.

        Args:
          value: Object to convert.  Enums become their str() form, bytes
            (Python 3 only) are decoded as UTF-8, and messages become a
            dictionary keyed by field name.  Anything else is delegated to
            the superclass default method.

        Returns:
          A JSON-serializable replacement for value.
        """
        if isinstance(value, messages.Enum):
            return str(value)

        if six.PY3 and isinstance(value, bytes):
            return value.decode('utf8')

        if not isinstance(value, messages.Message):
            return super(MessageJSONEncoder, self).default(value)

        # Values treated as "not set"; such fields are left out of the output.
        empty_values = (None, [], ())
        encoded = {}
        for field in value.all_fields():
            assigned = value.get_assigned_value(field.name)
            if assigned in empty_values:
                continue
            encoded[field.name] = (
                self.__protojson_protocol.encode_field(field, assigned))

        # Carry unrecognized fields over as well, so they survive a
        # decode-then-encode round trip.
        for name in value.all_unrecognized_fields():
            raw_value, _ = value.get_unrecognized_field_info(name)
            encoded[name] = raw_value
        return encoded
130
131
class ProtoJson(object):
    """ProtoRPC JSON implementation class.

    Implementation of JSON based protocol used for serializing and
    deserializing message objects.  Instances are meant to be passed to the
    remote.ProtocolConfig constructor or used with
    remote.Protocols.add_protocol.  See the remote.py module for more
    details.
    """

    CONTENT_TYPE = 'application/json'
    ALTERNATIVE_CONTENT_TYPES = [
        'application/x-javascript',
        'text/javascript',
        'text/x-javascript',
        'text/x-json',
        'text/json',
    ]

    def encode_field(self, field, value):
        """Encode a python field value to a JSON value.

        Bytes fields are base64 encoded and date-time fields are rendered
        with isoformat(); repeated fields are converted element by element.
        Values of any other field type are returned unchanged.

        Args:
          field: A ProtoRPC field instance.
          value: A python value supported by field.

        Returns:
          A JSON serializable value appropriate for field.
        """
        if isinstance(field, messages.BytesField):
            if field.repeated:
                value = [base64.b64encode(byte) for byte in value]
            else:
                value = base64.b64encode(value)
        elif isinstance(field, message_types.DateTimeField):
            # DateTimeField stores its data as a RFC 3339 compliant string.
            if field.repeated:
                value = [i.isoformat() for i in value]
            else:
                value = value.isoformat()
        return value

    def encode_message(self, message):
        """Encode Message instance to JSON string.

        Args:
          message: Message instance to encode in to JSON string.

        Returns:
          String encoding of Message instance in protocol JSON format.

        Raises:
          messages.ValidationError if message is not initialized.
        """
        message.check_initialized()

        return json.dumps(message, cls=MessageJSONEncoder,
                          protojson_protocol=self)

    def decode_message(self, message_type, encoded_message):
        """Decode a JSON string into a new Message instance.

        A blank encoded_message (empty or whitespace only) yields a default
        constructed message_type instance without any JSON parsing.

        Args:
          message_type: Message class to decode data to.
          encoded_message: JSON encoded version of message.

        Returns:
          Decoded instance of message_type.

        Raises:
          ValueError: If encoded_message is not valid JSON.
          messages.ValidationError if merged message is not initialized.
        """
        if not encoded_message.strip():
            return message_type()

        dictionary = json.loads(encoded_message)
        message = self.__decode_dictionary(message_type, dictionary)
        message.check_initialized()
        return message

    def __find_variant(self, value):
        """Find the messages.Variant type that describes this value.

        Args:
          value: The value whose variant type is being determined.

        Returns:
          The messages.Variant value that best describes value's type,
          or None if it's a type we don't know how to handle.  For lists
          and tuples this is the highest-priority variant among INT64,
          DOUBLE and STRING found in the elements, or None if no element
          has one of those variants.
        """
        # bool must be tested before the integer types: bool is an int
        # subclass and would otherwise be reported as INT64.
        if isinstance(value, bool):
            return messages.Variant.BOOL
        elif isinstance(value, six.integer_types):
            return messages.Variant.INT64
        elif isinstance(value, float):
            return messages.Variant.DOUBLE
        elif isinstance(value, six.string_types):
            return messages.Variant.STRING
        elif isinstance(value, (list, tuple)):
            # Find the most specific variant that covers all elements.
            variant_priority = [None,
                                messages.Variant.INT64,
                                messages.Variant.DOUBLE,
                                messages.Variant.STRING]
            chosen_priority = 0
            for v in value:
                variant = self.__find_variant(v)
                try:
                    priority = variant_priority.index(variant)
                except ValueError:
                    # list.index() raises ValueError (not IndexError) when
                    # the variant is not ranked, e.g. BOOL.  Such elements
                    # never influence the chosen variant.
                    priority = -1
                if priority > chosen_priority:
                    chosen_priority = priority
            return variant_priority[chosen_priority]
        # Unrecognized type.
        return None

    def __decode_dictionary(self, message_type, dictionary):
        """Build a message of message_type from a parsed JSON dictionary.

        Args:
          message_type: Message class to instantiate and populate.
          dictionary: Dictionary to extract information from.  Dictionary
            is as parsed from JSON.  Nested objects will also be dictionaries.

        Returns:
          A new message_type instance populated from dictionary.  Keys that
          do not name a field are stored as unrecognized fields when a
          variant can be determined for their value, and dropped otherwise.
        """
        message = message_type()
        for key, value in six.iteritems(dictionary):
            if value is None:
                # An explicit JSON null resets the field.
                try:
                    message.reset(key)
                except AttributeError:
                    pass  # This is an unrecognized field, skip it.
                continue

            try:
                field = message.field_by_name(key)
            except KeyError:
                # Save unknown values.
                variant = self.__find_variant(value)
                if variant:
                    message.set_unrecognized_field(key, value, variant)
                continue

            if field.repeated:
                # This should be unnecessary? Or in fact become an error.
                if not isinstance(value, list):
                    value = [value]
                valid_value = [self.decode_field(field, item)
                               for item in value]
                setattr(message, field.name, valid_value)
            else:
                # This is just for consistency with the old behavior.
                if value == []:
                    continue
                setattr(message, field.name, self.decode_field(field, value))

        return message

    def decode_field(self, field, value):
        """Decode a JSON value to a python value.

        Args:
          field: A ProtoRPC field instance.
          value: A serialized JSON value.

        Returns:
          A Python value compatible with field.  Values for float and
          integer fields that cannot be converted are returned unchanged.

        Raises:
          messages.DecodeError: If an enum, bytes or date-time value cannot
            be decoded.
        """
        if isinstance(field, messages.EnumField):
            try:
                return field.type(value)
            except TypeError:
                raise messages.DecodeError(
                    'Invalid enum value "%s"' % (value or ''))

        elif isinstance(field, messages.BytesField):
            try:
                return base64.b64decode(value)
            except (binascii.Error, TypeError) as err:
                raise messages.DecodeError('Base64 decoding error: %s' % err)

        elif isinstance(field, message_types.DateTimeField):
            try:
                return util.decode_datetime(value)
            except ValueError as err:
                raise messages.DecodeError(err)

        elif (isinstance(field, messages.MessageField) and
              issubclass(field.type, messages.Message)):
            return self.__decode_dictionary(field.type, value)

        elif (isinstance(field, messages.FloatField) and
              isinstance(value, (six.integer_types, six.string_types))):
            # Best-effort coercion: on failure fall through and return the
            # raw value.  Only conversion errors are swallowed.
            try:
                return float(value)
            except (TypeError, ValueError, OverflowError):
                pass

        elif (isinstance(field, messages.IntegerField) and
              isinstance(value, six.string_types)):
            # Best-effort coercion, same policy as for float fields above.
            try:
                return int(value)
            except (TypeError, ValueError, OverflowError):
                pass

        return value

    @staticmethod
    def get_default():
        """Get default instance of ProtoJson.

        The default is created on first use and then reused until it is
        replaced through set_default.

        Returns:
          The current default ProtoJson instance.
        """
        try:
            return ProtoJson.__default
        except AttributeError:
            ProtoJson.__default = ProtoJson()
            return ProtoJson.__default

    @staticmethod
    def set_default(protocol):
        """Set the default instance of ProtoJson.

        Args:
          protocol: A ProtoJson instance.

        Raises:
          TypeError: If protocol is not a ProtoJson instance.
        """
        if not isinstance(protocol, ProtoJson):
            raise TypeError('Expected protocol of type ProtoJson')
        ProtoJson.__default = protocol
360
# Module-level aliases for the content types declared on ProtoJson.
CONTENT_TYPE = ProtoJson.CONTENT_TYPE
ALTERNATIVE_CONTENT_TYPES = ProtoJson.ALTERNATIVE_CONTENT_TYPES

# Convenience functions bound to the default ProtoJson instance as returned
# by get_default() when this module is imported.
encode_message, decode_message = (
    ProtoJson.get_default().encode_message,
    ProtoJson.get_default().decode_message,
)
368