1#!/usr/bin/env python
2#
3# Copyright 2015 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Extra types understood by apitools."""
18
import collections
import datetime
import json
import numbers

import six

from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.protorpclite import protojson
from apitools.base.py import encoding
from apitools.base.py import exceptions
from apitools.base.py import util

try:
    # Python 3.3+. The old `collections.Iterable` alias was removed in 3.10.
    from collections import abc as collections_abc
except ImportError:  # Python 2: the ABCs live directly in `collections`.
    collections_abc = collections
32
# Public API of this module.
__all__ = [
    # Field and message types.
    'DateField',
    'DateTimeMessage',
    # JSON message types.
    'JsonArray',
    'JsonObject',
    'JsonValue',
    # Converters between JSON proto messages and JSON strings.
    'JsonProtoEncoder',
    'JsonProtoDecoder',
]

# Re-export protorpclite's DateTimeMessage from this module's namespace.
DateTimeMessage = message_types.DateTimeMessage  # pylint:disable=invalid-name
46
47
# We insert our own metaclass here to avoid letting ProtoRPC
# register this as the default field type for strings.
#  * since ProtoRPC does this via metaclasses, we don't have any
#    choice but to use one ourselves
#  * since a subclass's metaclass must inherit from its superclass's
#    metaclass, we're forced to have this hard-to-read inheritance.
#
# pylint: disable=protected-access
class _FieldMeta(messages._FieldMeta):

    """Metaclass for DateField that skips messages._FieldMeta.__init__.

    It subclasses messages._FieldMeta only because DateField derives from
    messages.Field, and a subclass's metaclass must derive from its base
    class's metaclass.
    """

    def __init__(cls, name, bases, dct):  # pylint: disable=no-self-argument
        """Initialize the new class using type.__init__ only.

        messages._FieldMeta.__init__ is deliberately not called. Per the
        comment above this class, that initializer is where ProtoRPC would
        register the new field class as the default field type for strings.
        """
        # pylint: disable=super-init-not-called,non-parent-init-called
        type.__init__(cls, name, bases, dct)
# pylint: enable=protected-access
62
63
class DateField(six.with_metaclass(_FieldMeta, messages.Field)):

    """Field definition for Date values.

    Values are datetime.date objects. STRING is the only permitted variant,
    and it is also the default.
    """

    DEFAULT_VARIANT = messages.Variant.STRING
    VARIANTS = frozenset((DEFAULT_VARIANT,))
    type = datetime.date
71
72
def _ValidateJsonValue(json_value):
    """Check that exactly one field of a JsonValue is assigned.

    A field counts as assigned when get_assigned_value returns something
    other than None.

    Args:
      json_value: JsonValue, the message to validate.

    Returns:
      The (field, value) tuple for the single assigned field. Callers that
      only need validation may ignore it.

    Raises:
      InvalidDataError: if zero fields, or more than one field, are assigned.
    """
    entries = [(f, json_value.get_assigned_value(f.name))
               for f in json_value.all_fields()]
    assigned_entries = [(f, value)
                        for f, value in entries if value is not None]
    if len(assigned_entries) != 1:
        raise exceptions.InvalidDataError(
            'Malformed JsonValue: %s' % json_value)
    return assigned_entries[0]


def _JsonValueToPythonValue(json_value):
    """Convert the given JsonValue to the equivalent plain Python value.

    Args:
      json_value: JsonValue, the message to convert.

    Returns:
      None when is_null is set. The raw field value (bool, string, int or
      float) for scalar fields. A dict for object_value and a list for
      array_value, converted recursively.

    Raises:
      InvalidDataError: if json_value does not have exactly one field set.
    """
    util.Typecheck(json_value, JsonValue)
    # Validation also yields the single assigned field, so the field list
    # does not have to be scanned a second time.
    field, value = _ValidateJsonValue(json_value)
    if json_value.is_null:
        return None
    if not isinstance(field, messages.MessageField):
        return value
    if field.message_type is JsonObject:
        return _JsonObjectToPythonValue(value)
    if field.message_type is JsonArray:
        return _JsonArrayToPythonValue(value)
    # Not reachable with the JsonValue schema; mirrors the original implicit
    # None.
    return None
100
101
def _JsonObjectToPythonValue(json_value):
    """Convert a JsonObject to a dict of property key -> Python value.

    Each property value is converted with _JsonValueToPythonValue. If a key
    appears more than once, the last occurrence wins.
    """
    util.Typecheck(json_value, JsonObject)
    return {prop.key: _JsonValueToPythonValue(prop.value)
            for prop in json_value.properties}
106
107
def _JsonArrayToPythonValue(json_value):
    """Convert a JsonArray to a list of Python values, preserving order."""
    util.Typecheck(json_value, JsonArray)
    return list(map(_JsonValueToPythonValue, json_value.entries))
111
112
# Inclusive bounds of a signed 64-bit integer. The parentheses matter: `<<`
# binds more loosely than `-`, so `2 << 63 - 1` would mean `2 << 62`.
_MAXINT64 = (1 << 63) - 1
_MININT64 = -(1 << 63)


def _PythonValueToJsonValue(py_value):
    """Convert the given python value to a JsonValue.

    Args:
      py_value: None, a bool, a string, a number, a dict, or another iterable.

    Returns:
      A JsonValue with exactly one field set. Integers within the signed
      64-bit range go to integer_value, so they round-trip exactly. Every
      other number is stored as a float in double_value. Dicts become
      object_value and other iterables become array_value, both converted
      recursively.

    Raises:
      InvalidDataError: if py_value has no JSON equivalent.
    """
    if py_value is None:
        return JsonValue(is_null=True)
    # bool must be tested before numbers, because bool is a subclass of int.
    if isinstance(py_value, bool):
        return JsonValue(boolean_value=py_value)
    if isinstance(py_value, six.string_types):
        return JsonValue(string_value=py_value)
    if isinstance(py_value, numbers.Number):
        if (isinstance(py_value, six.integer_types) and
                _MININT64 <= py_value <= _MAXINT64):
            return JsonValue(integer_value=py_value)
        return JsonValue(double_value=float(py_value))
    if isinstance(py_value, dict):
        return JsonValue(object_value=_PythonValueToJsonObject(py_value))
    if isinstance(py_value, collections_abc.Iterable):
        return JsonValue(array_value=_PythonValueToJsonArray(py_value))
    raise exceptions.InvalidDataError(
        'Cannot convert "%s" to JsonValue' % py_value)
136
137
def _PythonValueToJsonObject(py_value):
    """Convert a dict to a JsonObject with one Property per key/value pair.

    Values are converted recursively with _PythonValueToJsonValue, in the
    dict's iteration order.
    """
    util.Typecheck(py_value, dict)
    properties = []
    for key, value in py_value.items():
        converted = _PythonValueToJsonValue(value)
        properties.append(JsonObject.Property(key=key, value=converted))
    return JsonObject(properties=properties)
144
145
def _PythonValueToJsonArray(py_value):
    """Convert an iterable of Python values to a JsonArray, preserving order."""
    converted = [_PythonValueToJsonValue(item) for item in py_value]
    return JsonArray(entries=converted)
148
149
class JsonValue(messages.Message):

    """Any valid JSON value.

    Exactly one field should be set: is_null=True for JSON `null`, or one of
    the value fields. _JsonValueToPythonValue in this module rejects a
    JsonValue with any other number of assigned fields.

    Fields:
      is_null: True if this value is JSON `null`.
      boolean_value: A JSON boolean.
      string_value: A JSON string.
      double_value: A JSON number held as a double-precision float.
      integer_value: A JSON number held as a signed 64-bit integer.
      object_value: A JSON object.
      array_value: A JSON array.
    """
    # Is this JSON value `null`?
    is_null = messages.BooleanField(1, default=False)

    # Exactly one of the following is provided if is_null is False; none
    # should be provided if is_null is True.
    boolean_value = messages.BooleanField(2)
    string_value = messages.StringField(3)
    # We keep two numeric fields to keep int64 round-trips exact.
    double_value = messages.FloatField(4, variant=messages.Variant.DOUBLE)
    integer_value = messages.IntegerField(5, variant=messages.Variant.INT64)
    # Compound types. These are referenced by name because JsonObject and
    # JsonArray are defined after this class.
    object_value = messages.MessageField('JsonObject', 6)
    array_value = messages.MessageField('JsonArray', 7)
166
167
class JsonObject(messages.Message):

    """A JSON object value.

    Messages:
      Property: A property of a JsonObject.

    Fields:
      properties: A list of properties of a JsonObject, as key/value pairs.
    """

    class Property(messages.Message):

        """A property of a JSON object.

        Fields:
          key: Name of the property.
          value: The property's value, as a JsonValue.
        """
        key = messages.StringField(1)
        value = messages.MessageField(JsonValue, 2)

    properties = messages.MessageField(Property, 1, repeated=True)
191
192
class JsonArray(messages.Message):

    """A JSON array value.

    Fields:
      entries: The array's elements, in order, each as a JsonValue.
    """
    entries = messages.MessageField(JsonValue, 1, repeated=True)
197
198
# Dispatch table from each JSON proto message class to the function that
# converts an instance of exactly that class into a plain Python value.
_JSON_PROTO_TO_PYTHON_MAP = {
    JsonArray: _JsonArrayToPythonValue,
    JsonObject: _JsonObjectToPythonValue,
    JsonValue: _JsonValueToPythonValue,
}
_JSON_PROTO_TYPES = tuple(_JSON_PROTO_TO_PYTHON_MAP)


def _JsonProtoToPythonValue(json_proto):
    """Convert a JsonArray, JsonObject or JsonValue to a Python value.

    The converter is looked up by exact type(json_proto).
    """
    util.Typecheck(json_proto, _JSON_PROTO_TYPES)
    converter = _JSON_PROTO_TO_PYTHON_MAP[type(json_proto)]
    return converter(json_proto)
210
211
def _PythonValueToJsonProto(py_value):
    """Convert a Python value to the most specific JSON proto message.

    Returns:
      A JsonObject for a dict, a JsonArray for any other non-string iterable,
      and a JsonValue for everything else.

    Raises:
      InvalidDataError: if the value cannot be represented as a JsonValue.
    """
    if isinstance(py_value, dict):
        return _PythonValueToJsonObject(py_value)
    # Strings are iterable, but they must stay scalar JSON strings.
    if (isinstance(py_value, collections_abc.Iterable) and
            not isinstance(py_value, six.string_types)):
        return _PythonValueToJsonArray(py_value)
    return _PythonValueToJsonValue(py_value)
219
220
def _JsonProtoToJson(json_proto, unused_encoder=None):
    """Serialize a JsonArray, JsonObject or JsonValue to a JSON string."""
    python_value = _JsonProtoToPythonValue(json_proto)
    return json.dumps(python_value)
223
224
def _JsonToJsonProto(json_data, unused_decoder=None):
    """Parse a JSON string into a JsonObject, JsonArray or JsonValue."""
    parsed = json.loads(json_data)
    return _PythonValueToJsonProto(parsed)
227
228
def _JsonToJsonValue(json_data, unused_decoder=None):
    """Parse a JSON string into a JsonValue.

    Top-level objects and arrays are wrapped in a JsonValue via
    object_value / array_value.

    Raises:
      InvalidDataError: if the parsed result is not a JSON proto message.
    """
    result = _PythonValueToJsonProto(json.loads(json_data))
    if isinstance(result, JsonObject):
        return JsonValue(object_value=result)
    if isinstance(result, JsonArray):
        return JsonValue(array_value=result)
    if not isinstance(result, JsonValue):
        raise exceptions.InvalidDataError(
            'Malformed JsonValue: %s' % json_data)
    return result
240
241
# Public aliases for the generic JSON proto <-> JSON string converters.
JsonProtoEncoder = _JsonProtoToJson  # pylint:disable=invalid-name
JsonProtoDecoder = _JsonToJsonProto  # pylint:disable=invalid-name

# All three message types share the encoder. JsonValue gets its own decoder
# (_JsonToJsonValue), which wraps top-level objects and arrays in a
# JsonValue.
for _message_type, _decoder in (
        (JsonValue, _JsonToJsonValue),
        (JsonObject, JsonProtoDecoder),
        (JsonArray, JsonProtoDecoder)):
    encoding.RegisterCustomMessageCodec(
        encoder=JsonProtoEncoder, decoder=_decoder)(_message_type)
del _message_type, _decoder
252
253
def _EncodeDateTimeField(field, value):
    """Encode a DateTimeField value with protorpclite's ProtoJson encoder."""
    encoded = protojson.ProtoJson().encode_field(field, value)
    return encoding.CodecResult(value=encoded, complete=True)


def _DecodeDateTimeField(unused_field, value):
    """Decode a DateTimeField value with protorpclite's ProtoJson decoder.

    The caller's field is ignored. A fresh DateTimeField(1) is passed to the
    decoder instead.
    """
    decoded = protojson.ProtoJson().decode_field(
        message_types.DateTimeField(1), value)
    return encoding.CodecResult(value=decoded, complete=True)


encoding.RegisterFieldTypeCodec(_EncodeDateTimeField, _DecodeDateTimeField)(
    message_types.DateTimeField)
267
268
def _EncodeInt64Field(field, value):
    """Handle the special case of int64 as a string.

    For INT64 and UINT64 variants, the value (or each element of a repeated
    field) is converted with str() and the result is marked complete. Any
    other variant is returned unchanged with complete=False.
    """
    if field.variant in (messages.Variant.INT64, messages.Variant.UINT64):
        if field.repeated:
            as_text = list(map(str, value))
        else:
            as_text = str(value)
        return encoding.CodecResult(value=as_text, complete=True)
    return encoding.CodecResult(value=value, complete=False)
283
284
def _DecodeInt64Field(unused_field, value):
    """Pass int64 values through unchanged (complete=False).

    No special handling is needed on decode; they're decoded just fine.
    """
    return encoding.CodecResult(value=value, complete=False)


encoding.RegisterFieldTypeCodec(_EncodeInt64Field, _DecodeInt64Field)(
    messages.IntegerField)
291
292
def _EncodeDateField(field, value):
    """Encoder for datetime.date objects, using date.isoformat()."""
    if not field.repeated:
        return encoding.CodecResult(value=value.isoformat(), complete=True)
    return encoding.CodecResult(
        value=[day.isoformat() for day in value], complete=True)


def _DecodeDateField(unused_field, value):
    """Decoder for '%Y-%m-%d' strings into datetime.date objects."""
    parsed = datetime.datetime.strptime(value, '%Y-%m-%d')
    return encoding.CodecResult(value=parsed.date(), complete=True)


encoding.RegisterFieldTypeCodec(_EncodeDateField, _DecodeDateField)(DateField)
307