1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Module that encodes (decodes) nested structures into (from) protos. 16 17The intended use is to serialize everything needed to restore a `Function` that 18was saved into a SavedModel. This may include concrete function inputs and 19outputs, signatures, function specs, etc. 20 21Example use: 22coder = nested_structure_coder.StructureCoder() 23# Encode into proto. 24signature_proto = coder.encode_structure(function.input_signature) 25# Decode into a Python object. 26restored_signature = coder.decode_proto(signature_proto) 27""" 28 29from __future__ import absolute_import 30from __future__ import division 31from __future__ import print_function 32 33import collections 34import functools 35import six 36 37from tensorflow.core.protobuf import struct_pb2 38from tensorflow.python.framework import dtypes 39from tensorflow.python.framework import tensor_shape 40from tensorflow.python.framework import tensor_spec 41from tensorflow.python.util import compat 42 43 44class NotEncodableError(Exception): 45 """Error raised when a coder cannot encode an object.""" 46 47 48class StructureCoder(object): 49 """Encoder and decoder for nested structures into protos.""" 50 51 _codecs = [] 52 53 @classmethod 54 def register_codec(cls, x): 55 cls._codecs.append(x) 56 57 @classmethod 58 def _get_encoders(cls): 59 return [(c.can_encode, c.do_encode) for c in cls._codecs] 60 61 @classmethod 62 def _get_decoders(cls): 63 return [(c.can_decode, c.do_decode) for c in cls._codecs] 64 65 def _map_structure(self, pyobj, coders): 66 for can, do in coders: 67 if can(pyobj): 68 recursion_fn = functools.partial(self._map_structure, coders=coders) 69 return do(pyobj, recursion_fn) 70 raise NotEncodableError( 71 "No encoder for object [%s] of type [%s]." % (str(pyobj), type(pyobj))) 72 73 def encode_structure(self, nested_structure): 74 """Encodes nested structures composed of encodable types into a proto. 75 76 Args: 77 nested_structure: Structure to encode. 78 79 Returns: 80 Encoded proto. 81 82 Raises: 83 NotEncodableError: For values for which there are no encoders. 84 """ 85 return self._map_structure(nested_structure, self._get_encoders()) 86 87 def can_encode(self, nested_structure): 88 """Determines whether a nested structure can be encoded into a proto. 89 90 Args: 91 nested_structure: Structure to encode. 92 93 Returns: 94 True if the nested structured can be encoded. 95 """ 96 try: 97 self.encode_structure(nested_structure) 98 except NotEncodableError: 99 return False 100 return True 101 102 def decode_proto(self, proto): 103 """Decodes proto representing a nested structure. 104 105 Args: 106 proto: Proto to decode. 107 108 Returns: 109 Decoded structure. 110 111 Raises: 112 NotEncodableError: For values for which there are no encoders. 113 """ 114 return self._map_structure(proto, self._get_decoders()) 115 116 117class _ListCodec(object): 118 """Codec for lists.""" 119 120 def can_encode(self, pyobj): 121 return isinstance(pyobj, list) 122 123 def do_encode(self, list_value, encode_fn): 124 encoded_list = struct_pb2.StructuredValue() 125 encoded_list.list_value.CopyFrom(struct_pb2.ListValue()) 126 for element in list_value: 127 encoded_list.list_value.values.add().CopyFrom(encode_fn(element)) 128 return encoded_list 129 130 def can_decode(self, value): 131 return value.HasField("list_value") 132 133 def do_decode(self, value, decode_fn): 134 return [decode_fn(element) for element in value.list_value.values] 135 136 137StructureCoder.register_codec(_ListCodec()) 138 139 140def _is_tuple(obj): 141 return not _is_named_tuple(obj) and isinstance(obj, tuple) 142 143 144def _is_named_tuple(instance): 145 """Returns True iff `instance` is a `namedtuple`. 146 147 Args: 148 instance: An instance of a Python object. 149 150 Returns: 151 True if `instance` is a `namedtuple`. 152 """ 153 if not isinstance(instance, tuple): 154 return False 155 return (hasattr(instance, "_fields") and 156 isinstance(instance._fields, collections.Sequence) and 157 all(isinstance(f, six.string_types) for f in instance._fields)) 158 159 160class _TupleCodec(object): 161 """Codec for tuples.""" 162 163 def can_encode(self, pyobj): 164 return _is_tuple(pyobj) 165 166 def do_encode(self, tuple_value, encode_fn): 167 encoded_tuple = struct_pb2.StructuredValue() 168 encoded_tuple.tuple_value.CopyFrom(struct_pb2.TupleValue()) 169 for element in tuple_value: 170 encoded_tuple.tuple_value.values.add().CopyFrom(encode_fn(element)) 171 return encoded_tuple 172 173 def can_decode(self, value): 174 return value.HasField("tuple_value") 175 176 def do_decode(self, value, decode_fn): 177 return tuple(decode_fn(element) for element in value.tuple_value.values) 178 179 180StructureCoder.register_codec(_TupleCodec()) 181 182 183class _DictCodec(object): 184 """Codec for dicts.""" 185 186 def can_encode(self, pyobj): 187 return isinstance(pyobj, dict) 188 189 def do_encode(self, dict_value, encode_fn): 190 encoded_dict = struct_pb2.StructuredValue() 191 encoded_dict.dict_value.CopyFrom(struct_pb2.DictValue()) 192 for key, value in dict_value.items(): 193 encoded_dict.dict_value.fields[key].CopyFrom(encode_fn(value)) 194 return encoded_dict 195 196 def can_decode(self, value): 197 return value.HasField("dict_value") 198 199 def do_decode(self, value, decode_fn): 200 return {key: decode_fn(val) for key, val in value.dict_value.fields.items()} 201 202 203StructureCoder.register_codec(_DictCodec()) 204 205 206class _NamedTupleCodec(object): 207 """Codec for namedtuples. 208 209 Encoding and decoding a namedtuple reconstructs a namedtuple with a different 210 actual Python type, but with same `typename` and `fields`. 211 """ 212 213 def can_encode(self, pyobj): 214 return _is_named_tuple(pyobj) 215 216 def do_encode(self, named_tuple_value, encode_fn): 217 encoded_named_tuple = struct_pb2.StructuredValue() 218 encoded_named_tuple.named_tuple_value.CopyFrom(struct_pb2.NamedTupleValue()) 219 encoded_named_tuple.named_tuple_value.name = \ 220 named_tuple_value.__class__.__name__ 221 for key in named_tuple_value._fields: 222 pair = encoded_named_tuple.named_tuple_value.values.add() 223 pair.key = key 224 pair.value.CopyFrom(encode_fn(named_tuple_value._asdict()[key])) 225 return encoded_named_tuple 226 227 def can_decode(self, value): 228 return value.HasField("named_tuple_value") 229 230 def do_decode(self, value, decode_fn): 231 key_value_pairs = value.named_tuple_value.values 232 items = [(pair.key, decode_fn(pair.value)) for pair in key_value_pairs] 233 named_tuple_type = collections.namedtuple(value.named_tuple_value.name, 234 [item[0] for item in items]) 235 return named_tuple_type(**dict(items)) 236 237 238StructureCoder.register_codec(_NamedTupleCodec()) 239 240 241class _Float64Codec(object): 242 """Codec for floats.""" 243 244 def can_encode(self, pyobj): 245 return isinstance(pyobj, float) 246 247 def do_encode(self, float64_value, encode_fn): 248 del encode_fn 249 value = struct_pb2.StructuredValue() 250 value.float64_value = float64_value 251 return value 252 253 def can_decode(self, value): 254 return value.HasField("float64_value") 255 256 def do_decode(self, value, decode_fn): 257 del decode_fn 258 return value.float64_value 259 260 261StructureCoder.register_codec(_Float64Codec()) 262 263 264class _Int64Codec(object): 265 """Codec for Python integers (limited to 64 bit values).""" 266 267 def can_encode(self, pyobj): 268 return not isinstance(pyobj, bool) and isinstance(pyobj, int) 269 270 def do_encode(self, int_value, encode_fn): 271 del encode_fn 272 value = struct_pb2.StructuredValue() 273 value.int64_value = int_value 274 return value 275 276 def can_decode(self, value): 277 return value.HasField("int64_value") 278 279 def do_decode(self, value, decode_fn): 280 del decode_fn 281 return int(value.int64_value) 282 283 284StructureCoder.register_codec(_Int64Codec()) 285 286 287class _StringCodec(object): 288 """Codec for strings. 289 290 See StructuredValue.string_value in proto/struct.proto for more detailed 291 explanation. 292 """ 293 294 def can_encode(self, pyobj): 295 return isinstance(pyobj, str) 296 297 def do_encode(self, string_value, encode_fn): 298 del encode_fn 299 value = struct_pb2.StructuredValue() 300 value.string_value = string_value 301 return value 302 303 def can_decode(self, value): 304 return value.HasField("string_value") 305 306 def do_decode(self, value, decode_fn): 307 del decode_fn 308 return compat.as_str(value.string_value) 309 310 311StructureCoder.register_codec(_StringCodec()) 312 313 314class _NoneCodec(object): 315 """Codec for None.""" 316 317 def can_encode(self, pyobj): 318 return pyobj is None 319 320 def do_encode(self, none_value, encode_fn): 321 del encode_fn, none_value 322 value = struct_pb2.StructuredValue() 323 value.none_value.CopyFrom(struct_pb2.NoneValue()) 324 return value 325 326 def can_decode(self, value): 327 return value.HasField("none_value") 328 329 def do_decode(self, value, decode_fn): 330 del decode_fn, value 331 return None 332 333 334StructureCoder.register_codec(_NoneCodec()) 335 336 337class _BoolCodec(object): 338 """Codec for booleans.""" 339 340 def can_encode(self, pyobj): 341 return isinstance(pyobj, bool) 342 343 def do_encode(self, bool_value, encode_fn): 344 del encode_fn 345 value = struct_pb2.StructuredValue() 346 value.bool_value = bool_value 347 return value 348 349 def can_decode(self, value): 350 return value.HasField("bool_value") 351 352 def do_decode(self, value, decode_fn): 353 del decode_fn 354 return value.bool_value 355 356 357StructureCoder.register_codec(_BoolCodec()) 358 359 360class _TensorShapeCodec(object): 361 """Codec for `TensorShape`.""" 362 363 def can_encode(self, pyobj): 364 return isinstance(pyobj, tensor_shape.TensorShape) 365 366 def do_encode(self, tensor_shape_value, encode_fn): 367 del encode_fn 368 encoded_tensor_shape = struct_pb2.StructuredValue() 369 encoded_tensor_shape.tensor_shape_value.CopyFrom( 370 tensor_shape_value.as_proto()) 371 return encoded_tensor_shape 372 373 def can_decode(self, value): 374 return value.HasField("tensor_shape_value") 375 376 def do_decode(self, value, decode_fn): 377 del decode_fn 378 return tensor_shape.TensorShape(value.tensor_shape_value) 379 380 381StructureCoder.register_codec(_TensorShapeCodec()) 382 383 384class _TensorTypeCodec(object): 385 """Codec for `TensorType`.""" 386 387 def can_encode(self, pyobj): 388 return isinstance(pyobj, dtypes.DType) 389 390 def do_encode(self, tensor_dtype_value, encode_fn): 391 del encode_fn 392 encoded_tensor_type = struct_pb2.StructuredValue() 393 encoded_tensor_type.tensor_dtype_value = tensor_dtype_value.as_datatype_enum 394 return encoded_tensor_type 395 396 def can_decode(self, value): 397 return value.HasField("tensor_dtype_value") 398 399 def do_decode(self, value, decode_fn): 400 del decode_fn 401 return dtypes.DType(value.tensor_dtype_value) 402 403 404StructureCoder.register_codec(_TensorTypeCodec()) 405 406 407class _TensorSpecCodec(object): 408 """Codec for `TensorSpec`.""" 409 410 def can_encode(self, pyobj): 411 return isinstance(pyobj, tensor_spec.TensorSpec) 412 413 def do_encode(self, tensor_spec_value, encode_fn): 414 encoded_tensor_spec = struct_pb2.StructuredValue() 415 encoded_tensor_spec.tensor_spec_value.CopyFrom( 416 struct_pb2.TensorSpecProto( 417 shape=encode_fn(tensor_spec_value.shape).tensor_shape_value, 418 dtype=encode_fn(tensor_spec_value.dtype).tensor_dtype_value, 419 name=tensor_spec_value.name)) 420 return encoded_tensor_spec 421 422 def can_decode(self, value): 423 return value.HasField("tensor_spec_value") 424 425 def do_decode(self, value, decode_fn): 426 return tensor_spec.TensorSpec( 427 shape=decode_fn( 428 struct_pb2.StructuredValue( 429 tensor_shape_value=value.tensor_spec_value.shape)), 430 dtype=decode_fn( 431 struct_pb2.StructuredValue( 432 tensor_dtype_value=value.tensor_spec_value.dtype)), 433 name=value.tensor_spec_value.name) 434 435 436StructureCoder.register_codec(_TensorSpecCodec()) 437