1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module that encodes (decodes) nested structures into (from) protos.
16
17The intended use is to serialize everything needed to restore a `Function` that
18was saved into a SavedModel. This may include concrete function inputs and
19outputs, signatures, function specs, etc.
20
21Example use:
22coder = nested_structure_coder.StructureCoder()
23# Encode into proto.
24signature_proto = coder.encode_structure(function.input_signature)
25# Decode into a Python object.
26restored_signature = coder.decode_proto(signature_proto)
27"""
28
29from __future__ import absolute_import
30from __future__ import division
31from __future__ import print_function
32
33import collections
34import functools
35import six
36
37from tensorflow.core.protobuf import struct_pb2
38from tensorflow.python.framework import dtypes
39from tensorflow.python.framework import tensor_shape
40from tensorflow.python.framework import tensor_spec
41from tensorflow.python.util import compat
42
43
class NotEncodableError(Exception):
  """Raised when no registered codec accepts a value (encode or decode)."""
46
47
class StructureCoder(object):
  """Encodes nested Python structures into protos and decodes them back.

  Codecs are registered on the class with `register_codec`, so every instance
  shares the same registry. A codec provides `can_encode`/`do_encode` and
  `can_decode`/`do_decode`; for each value, the first registered codec whose
  predicate accepts it is used.
  """

  # Shared codec registry, kept in registration order.
  _codecs = []

  @classmethod
  def register_codec(cls, x):
    """Appends codec `x` to the shared registry."""
    cls._codecs.append(x)

  @classmethod
  def _get_encoders(cls):
    """Returns `(can_encode, do_encode)` pairs in registration order."""
    return [(codec.can_encode, codec.do_encode) for codec in cls._codecs]

  @classmethod
  def _get_decoders(cls):
    """Returns `(can_decode, do_decode)` pairs in registration order."""
    return [(codec.can_decode, codec.do_decode) for codec in cls._codecs]

  def _map_structure(self, pyobj, coders):
    """Applies the first coder in `coders` that accepts `pyobj`.

    The selected coder receives a callback that re-applies this same coder
    list, so nested values are dispatched independently.

    Args:
      pyobj: Value (Python object or proto) to transform.
      coders: Sequence of `(predicate, transform)` pairs.

    Returns:
      Whatever the selected transform returns.

    Raises:
      NotEncodableError: If no predicate in `coders` accepts `pyobj`.
    """
    recurse = functools.partial(self._map_structure, coders=coders)
    for accepts, transform in coders:
      if accepts(pyobj):
        return transform(pyobj, recurse)
    raise NotEncodableError(
        "No encoder for object [%s] of type [%s]." % (str(pyobj), type(pyobj)))

  def encode_structure(self, nested_structure):
    """Converts `nested_structure` into a proto using the registered codecs.

    Args:
      nested_structure: Structure to encode.

    Returns:
      Encoded proto.

    Raises:
      NotEncodableError: If some part of the structure has no matching encoder.
    """
    encoders = self._get_encoders()
    return self._map_structure(nested_structure, encoders)

  def can_encode(self, nested_structure):
    """Reports whether `encode_structure` succeeds on `nested_structure`.

    Only `NotEncodableError` is interpreted as "not encodable"; any other
    exception raised while encoding propagates to the caller.

    Args:
      nested_structure: Structure to test.

    Returns:
      True if the nested structure can be encoded, False otherwise.
    """
    try:
      self.encode_structure(nested_structure)
    except NotEncodableError:
      return False
    else:
      return True

  def decode_proto(self, proto):
    """Converts `proto` back into a Python structure via registered codecs.

    Args:
      proto: Proto to decode.

    Returns:
      Decoded structure.

    Raises:
      NotEncodableError: If some part of the proto has no matching decoder.
    """
    decoders = self._get_decoders()
    return self._map_structure(proto, decoders)
115
116
class _ListCodec(object):
  """Maps Python lists to and from `StructuredValue.list_value`."""

  def can_encode(self, pyobj):
    """Accepts `list` instances, including subclasses."""
    return isinstance(pyobj, list)

  def do_encode(self, list_value, encode_fn):
    """Encodes every element with `encode_fn` into a `list_value` proto."""
    result = struct_pb2.StructuredValue()
    # Populate `list_value` up front; presumably this is what keeps an empty
    # list distinguishable (via HasField) from an unset value.
    result.list_value.CopyFrom(struct_pb2.ListValue())
    encoded_items = result.list_value.values
    for item in list_value:
      encoded_items.add().CopyFrom(encode_fn(item))
    return result

  def can_decode(self, value):
    """Accepts protos whose `list_value` field is set."""
    return value.HasField("list_value")

  def do_decode(self, value, decode_fn):
    """Decodes each element of `value.list_value` into a new Python list."""
    return list(map(decode_fn, value.list_value.values))
135
136
# Make Python lists encodable/decodable through every StructureCoder.
StructureCoder.register_codec(_ListCodec())
138
139
def _is_tuple(obj):
  """Returns True iff `obj` is a tuple that is not a `namedtuple`.

  Plain tuples and namedtuples are handled by different codecs, so the plain
  tuple check must exclude namedtuples.
  """
  return not _is_named_tuple(obj) and isinstance(obj, tuple)


def _is_named_tuple(instance):
  """Returns True iff `instance` is a `namedtuple`.

  The check is structural: any tuple (or tuple subclass) exposing a `_fields`
  sequence whose items are all strings is treated as a namedtuple.

  Args:
    instance: An instance of a Python object.

  Returns:
    True if `instance` is a `namedtuple`.
  """
  if not isinstance(instance, tuple):
    return False
  # The ABC aliases such as `collections.Sequence` were removed in Python 3.10;
  # they live in `collections.abc` (Python 3.3+). Fall back to `collections`
  # itself on Python 2, where `collections.abc` does not exist.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections
  return (hasattr(instance, "_fields") and
          isinstance(instance._fields, collections_abc.Sequence) and
          all(isinstance(f, six.string_types) for f in instance._fields))
158
159
class _TupleCodec(object):
  """Maps plain (non-named) tuples to and from `StructuredValue.tuple_value`."""

  def can_encode(self, pyobj):
    """Accepts tuples, but not namedtuples (see `_is_tuple`)."""
    return _is_tuple(pyobj)

  def do_encode(self, tuple_value, encode_fn):
    """Encodes every element with `encode_fn` into a `tuple_value` proto."""
    result = struct_pb2.StructuredValue()
    result.tuple_value.CopyFrom(struct_pb2.TupleValue())
    encoded_items = result.tuple_value.values
    for item in tuple_value:
      encoded_items.add().CopyFrom(encode_fn(item))
    return result

  def can_decode(self, value):
    """Accepts protos whose `tuple_value` field is set."""
    return value.HasField("tuple_value")

  def do_decode(self, value, decode_fn):
    """Decodes each element of `value.tuple_value` into a new tuple."""
    return tuple(map(decode_fn, value.tuple_value.values))
178
179
# Make plain (non-named) tuples encodable/decodable through StructureCoder.
StructureCoder.register_codec(_TupleCodec())
181
182
class _DictCodec(object):
  """Maps dicts to and from `StructuredValue.dict_value`."""

  def can_encode(self, pyobj):
    """Accepts `dict` instances, including subclasses."""
    return isinstance(pyobj, dict)

  def do_encode(self, dict_value, encode_fn):
    """Encodes each value with `encode_fn` under the same key in `fields`.

    NOTE(review): `fields` looks like a proto map, presumably keyed by string;
    non-string keys would then be rejected by protobuf. Confirm in struct.proto.
    """
    result = struct_pb2.StructuredValue()
    result.dict_value.CopyFrom(struct_pb2.DictValue())
    encoded_fields = result.dict_value.fields
    for name, item in dict_value.items():
      encoded_fields[name].CopyFrom(encode_fn(item))
    return result

  def can_decode(self, value):
    """Accepts protos whose `dict_value` field is set."""
    return value.HasField("dict_value")

  def do_decode(self, value, decode_fn):
    """Decodes every entry of `value.dict_value.fields` into a new dict."""
    decoded = {}
    for name, encoded in value.dict_value.fields.items():
      decoded[name] = decode_fn(encoded)
    return decoded
201
202
# Make dicts encodable/decodable through StructureCoder.
StructureCoder.register_codec(_DictCodec())
204
205
class _NamedTupleCodec(object):
  """Codec for namedtuples.

  Encoding and decoding a namedtuple reconstructs a namedtuple with a different
  actual Python type, but with the same `typename` and `fields`.
  """

  def can_encode(self, pyobj):
    """Accepts values recognized by `_is_named_tuple`."""
    return _is_named_tuple(pyobj)

  def do_encode(self, named_tuple_value, encode_fn):
    """Encodes the type name plus each `(field, value)` pair in field order.

    Args:
      named_tuple_value: A namedtuple (any tuple exposing string `_fields`).
      encode_fn: Callback used to encode each field value.

    Returns:
      A `StructuredValue` with `named_tuple_value` populated.
    """
    encoded_named_tuple = struct_pb2.StructuredValue()
    encoded_named_tuple.named_tuple_value.CopyFrom(struct_pb2.NamedTupleValue())
    encoded_named_tuple.named_tuple_value.name = (
        named_tuple_value.__class__.__name__)
    # Pair field names with the tuple's own elements. Calling `_asdict()` once
    # per field (as before) was quadratic in the number of fields, and
    # `_is_named_tuple` only guarantees `_fields`, not `_asdict`. For real
    # namedtuples, `_asdict()` is itself built from `zip(_fields, self)`.
    for key, field_value in zip(named_tuple_value._fields, named_tuple_value):
      pair = encoded_named_tuple.named_tuple_value.values.add()
      pair.key = key
      pair.value.CopyFrom(encode_fn(field_value))
    return encoded_named_tuple

  def can_decode(self, value):
    """Accepts protos whose `named_tuple_value` field is set."""
    return value.HasField("named_tuple_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a fresh namedtuple type from the stored name and field keys.

    Args:
      value: Proto whose `named_tuple_value` holds the name and key/value pairs.
      decode_fn: Callback used to decode each stored field value.

    Returns:
      An instance of a newly created `collections.namedtuple` type.
    """
    key_value_pairs = value.named_tuple_value.values
    items = [(pair.key, decode_fn(pair.value)) for pair in key_value_pairs]
    named_tuple_type = collections.namedtuple(value.named_tuple_value.name,
                                              [item[0] for item in items])
    return named_tuple_type(**dict(items))
236
237
# Make namedtuples encodable/decodable through StructureCoder.
StructureCoder.register_codec(_NamedTupleCodec())
239
240
class _Float64Codec(object):
  """Maps Python floats to and from `StructuredValue.float64_value`."""

  def can_encode(self, pyobj):
    """Accepts `float` instances (ints are left to the int codec)."""
    return isinstance(pyobj, float)

  def do_encode(self, float64_value, encode_fn):
    """Stores the float directly in `float64_value`."""
    del encode_fn  # Floats are leaves; nothing nested to encode.
    encoded = struct_pb2.StructuredValue()
    encoded.float64_value = float64_value
    return encoded

  def can_decode(self, value):
    """Accepts protos whose `float64_value` field is set."""
    return value.HasField("float64_value")

  def do_decode(self, value, decode_fn):
    """Returns the stored `float64_value` unchanged."""
    del decode_fn  # Floats are leaves; nothing nested to decode.
    stored = value.float64_value
    return stored
259
260
# Make Python floats encodable/decodable through StructureCoder.
StructureCoder.register_codec(_Float64Codec())
262
263
class _Int64Codec(object):
  """Codec for Python integers (limited to 64 bit values).

  `bool` is an `int` subclass, so booleans are explicitly excluded here and
  left to the bool codec. Integers outside the signed 64-bit range are also
  rejected by `can_encode`. Unless another codec accepts them, encoding then
  raises `NotEncodableError` (and `StructureCoder.can_encode` returns False)
  instead of presumably failing inside the protobuf `int64_value` setter.
  """

  # Inclusive bounds of a signed 64-bit integer.
  _MIN_VALUE = -(1 << 63)
  _MAX_VALUE = (1 << 63) - 1

  def can_encode(self, pyobj):
    """Accepts non-bool `int` values that fit in a signed 64-bit integer."""
    return (not isinstance(pyobj, bool) and isinstance(pyobj, int) and
            self._MIN_VALUE <= pyobj <= self._MAX_VALUE)

  def do_encode(self, int_value, encode_fn):
    """Stores the integer directly in `int64_value`."""
    del encode_fn  # Integers are leaves; nothing nested to encode.
    value = struct_pb2.StructuredValue()
    value.int64_value = int_value
    return value

  def can_decode(self, value):
    """Accepts protos whose `int64_value` field is set."""
    return value.HasField("int64_value")

  def do_decode(self, value, decode_fn):
    """Returns the stored `int64_value` as a Python `int`."""
    del decode_fn  # Integers are leaves; nothing nested to decode.
    return int(value.int64_value)
282
283
# Make (non-bool) Python ints encodable/decodable through StructureCoder.
StructureCoder.register_codec(_Int64Codec())
285
286
class _StringCodec(object):
  """Maps native `str` values to and from `StructuredValue.string_value`.

  See StructuredValue.string_value in proto/struct.proto for more detailed
  explanation.
  """

  def can_encode(self, pyobj):
    """Accepts only instances of the native `str` type."""
    return isinstance(pyobj, str)

  def do_encode(self, string_value, encode_fn):
    """Stores the string directly in `string_value`."""
    del encode_fn  # Strings are leaves; nothing nested to encode.
    encoded = struct_pb2.StructuredValue()
    encoded.string_value = string_value
    return encoded

  def can_decode(self, value):
    """Accepts protos whose `string_value` field is set."""
    return value.HasField("string_value")

  def do_decode(self, value, decode_fn):
    """Returns the stored string converted with `compat.as_str`."""
    del decode_fn  # Strings are leaves; nothing nested to decode.
    raw = value.string_value
    return compat.as_str(raw)
309
310
# Make native `str` values encodable/decodable through StructureCoder.
StructureCoder.register_codec(_StringCodec())
312
313
class _NoneCodec(object):
  """Maps `None` to and from `StructuredValue.none_value`."""

  def can_encode(self, pyobj):
    """Accepts exactly the `None` singleton."""
    return pyobj is None

  def do_encode(self, none_value, encode_fn):
    """Returns a `StructuredValue` with an empty `none_value` set."""
    del none_value, encode_fn  # Only one possible value; nothing to recurse.
    encoded = struct_pb2.StructuredValue()
    encoded.none_value.CopyFrom(struct_pb2.NoneValue())
    return encoded

  def can_decode(self, value):
    """Accepts protos whose `none_value` field is set."""
    return value.HasField("none_value")

  def do_decode(self, value, decode_fn):
    """Always decodes to `None`."""
    del value, decode_fn  # The field's presence is all that matters.
    return None
332
333
# Make `None` encodable/decodable through StructureCoder.
StructureCoder.register_codec(_NoneCodec())
335
336
class _BoolCodec(object):
  """Maps Python booleans to and from `StructuredValue.bool_value`."""

  def can_encode(self, pyobj):
    """Accepts `True` and `False` (but not other ints)."""
    return isinstance(pyobj, bool)

  def do_encode(self, bool_value, encode_fn):
    """Stores the boolean directly in `bool_value`."""
    del encode_fn  # Booleans are leaves; nothing nested to encode.
    encoded = struct_pb2.StructuredValue()
    encoded.bool_value = bool_value
    return encoded

  def can_decode(self, value):
    """Accepts protos whose `bool_value` field is set."""
    return value.HasField("bool_value")

  def do_decode(self, value, decode_fn):
    """Returns the stored `bool_value` unchanged."""
    del decode_fn  # Booleans are leaves; nothing nested to decode.
    stored = value.bool_value
    return stored
355
356
# Make Python booleans encodable/decodable through StructureCoder.
StructureCoder.register_codec(_BoolCodec())
358
359
class _TensorShapeCodec(object):
  """Maps `TensorShape` objects to and from `tensor_shape_value`."""

  def can_encode(self, pyobj):
    """Accepts `tensor_shape.TensorShape` instances."""
    return isinstance(pyobj, tensor_shape.TensorShape)

  def do_encode(self, tensor_shape_value, encode_fn):
    """Copies `tensor_shape_value.as_proto()` into `tensor_shape_value`."""
    del encode_fn  # The shape is stored as a single proto; no recursion.
    shape_proto = tensor_shape_value.as_proto()
    encoded = struct_pb2.StructuredValue()
    encoded.tensor_shape_value.CopyFrom(shape_proto)
    return encoded

  def can_decode(self, value):
    """Accepts protos whose `tensor_shape_value` field is set."""
    return value.HasField("tensor_shape_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a `TensorShape` from the stored shape proto."""
    del decode_fn  # The shape is stored as a single proto; no recursion.
    shape_proto = value.tensor_shape_value
    return tensor_shape.TensorShape(shape_proto)
379
380
# Make `TensorShape` objects encodable/decodable through StructureCoder.
StructureCoder.register_codec(_TensorShapeCodec())
382
383
class _TensorTypeCodec(object):
  """Maps `dtypes.DType` objects to and from `tensor_dtype_value`."""

  def can_encode(self, pyobj):
    """Accepts `dtypes.DType` instances."""
    return isinstance(pyobj, dtypes.DType)

  def do_encode(self, tensor_dtype_value, encode_fn):
    """Stores the dtype's `as_datatype_enum` in `tensor_dtype_value`."""
    del encode_fn  # A dtype is a leaf; nothing nested to encode.
    dtype_enum = tensor_dtype_value.as_datatype_enum
    encoded = struct_pb2.StructuredValue()
    encoded.tensor_dtype_value = dtype_enum
    return encoded

  def can_decode(self, value):
    """Accepts protos whose `tensor_dtype_value` field is set."""
    return value.HasField("tensor_dtype_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a `dtypes.DType` from the stored enum value."""
    del decode_fn  # A dtype is a leaf; nothing nested to decode.
    dtype_enum = value.tensor_dtype_value
    return dtypes.DType(dtype_enum)
402
403
# Make `DType` objects encodable/decodable through StructureCoder.
StructureCoder.register_codec(_TensorTypeCodec())
405
406
class _TensorSpecCodec(object):
  """Maps `TensorSpec` objects to and from `tensor_spec_value`."""

  def can_encode(self, pyobj):
    """Accepts `tensor_spec.TensorSpec` instances."""
    return isinstance(pyobj, tensor_spec.TensorSpec)

  def do_encode(self, tensor_spec_value, encode_fn):
    """Encodes shape and dtype through `encode_fn`; copies the name as-is."""
    # Route shape and dtype through the registered codecs (via `encode_fn`),
    # then unwrap the resulting StructuredValues into TensorSpecProto fields.
    shape_proto = encode_fn(tensor_spec_value.shape).tensor_shape_value
    dtype_enum = encode_fn(tensor_spec_value.dtype).tensor_dtype_value
    spec_proto = struct_pb2.TensorSpecProto(
        shape=shape_proto, dtype=dtype_enum, name=tensor_spec_value.name)
    encoded = struct_pb2.StructuredValue()
    encoded.tensor_spec_value.CopyFrom(spec_proto)
    return encoded

  def can_decode(self, value):
    """Accepts protos whose `tensor_spec_value` field is set."""
    return value.HasField("tensor_spec_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a `TensorSpec`, decoding shape then dtype via `decode_fn`."""
    spec = value.tensor_spec_value
    # Re-wrap each field in a StructuredValue so the registered codecs decode it.
    shape = decode_fn(struct_pb2.StructuredValue(tensor_shape_value=spec.shape))
    dtype = decode_fn(struct_pb2.StructuredValue(tensor_dtype_value=spec.dtype))
    return tensor_spec.TensorSpec(shape=shape, dtype=dtype, name=spec.name)
434
435
# Make `TensorSpec` objects encodable/decodable through StructureCoder.
StructureCoder.register_codec(_TensorSpecCodec())
437