1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""TensorSignature class and utilities (deprecated).
17
18This module and all its submodules are deprecated. See
19[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
20for migration instructions.
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import collections
28
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import parsing_ops
35
36
class TensorSignature(collections.namedtuple(
    "TensorSignature", ["dtype", "shape", "is_sparse"])):
  """Signature of the `Tensor` object.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Useful to check compatibility of tensors.

  Example:

  ```python
  examples = tf.placeholder(...)
  inputs = {'a': var_a, 'b': var_b}
  signatures = tensor_signature.create_signatures(inputs)
  result = tensor_signature.create_example_parser_from_signatures(
      signatures, examples)
  self.assertTrue(tensor_signature.tensors_compatible(result, signatures))
  ```

  Attributes:
    dtype: `DType` object.
    shape: `TensorShape` object, or `None` when the signature was created from
      a `SparseTensor`.
    is_sparse: `True` when the signature was created from a `SparseTensor`,
      `False` otherwise.
  """

  def __new__(cls, tensor):
    """Builds the signature of `tensor`.

    Args:
      tensor: `Tensor` or `SparseTensor` to describe.

    Returns:
      A `TensorSignature`. For a `SparseTensor` the dtype is taken from its
      `values` and the shape is left as `None`.
    """
    if isinstance(tensor, sparse_tensor.SparseTensor):
      return super(TensorSignature, cls).__new__(
          cls, dtype=tensor.values.dtype, shape=None, is_sparse=True)
    return super(TensorSignature, cls).__new__(
        cls, dtype=tensor.dtype, shape=tensor.get_shape(), is_sparse=False)

  def is_compatible_with(self, other):
    """Returns True if signatures are compatible.

    A sparse signature is only compatible with another sparse signature of a
    compatible dtype. Dense signatures additionally need compatible shapes,
    where dimension 0 is not compared.

    Args:
      other: `TensorSignature` to compare against.

    Returns:
      True if this signature is compatible with `other`, False otherwise.
    """

    def _shape_is_compatible_0dim(this, other):
      """Checks that shapes are compatible skipping dim 0."""
      other = tensor_shape.as_shape(other)
      # If shapes are None (unknown) they may be compatible.
      if this.dims is None or other.dims is None:
        return True
      if this.ndims != other.ndims:
        return False
      # Dimension 0 is deliberately left out of the comparison.
      return all(x_dim.is_compatible_with(y_dim)
                 for x_dim, y_dim in zip(this.dims[1:], other.dims[1:]))

    if other.is_sparse:
      return self.is_sparse and self.dtype.is_compatible_with(other.dtype)
    # A sparse signature stores `shape=None`, so it has to be rejected before
    # the shape comparison; otherwise the helper would dereference `None`.
    if self.is_sparse:
      return False
    return (self.dtype.is_compatible_with(other.dtype) and
            _shape_is_compatible_0dim(self.shape, other.shape))

  def get_placeholder(self):
    """Creates a placeholder matching this signature.

    Returns:
      A sparse placeholder of `dtype` for sparse signatures; otherwise a
      placeholder of `dtype` whose first dimension is `None` and whose
      remaining dimensions are taken from `shape[1:]`.
    """
    if self.is_sparse:
      return array_ops.sparse_placeholder(dtype=self.dtype)
    return array_ops.placeholder(dtype=self.dtype,
                                 shape=[None] + list(self.shape[1:]))

  def get_feature_spec(self):
    """Creates the parsing feature spec matching this signature.

    Returns:
      A `VarLenFeature` for sparse signatures, otherwise a `FixedLenFeature`
      with shape `shape[1:]`. `int32` is widened to `int64` and `float64` is
      narrowed to `float32` in the returned spec.
    """
    dtype = self.dtype
    # Convert, because example parser only supports float32, int64 and string.
    if dtype == dtypes.int32:
      dtype = dtypes.int64
    if dtype == dtypes.float64:
      dtype = dtypes.float32
    if self.is_sparse:
      return parsing_ops.VarLenFeature(dtype=dtype)
    return parsing_ops.FixedLenFeature(shape=self.shape[1:], dtype=dtype)
110
111
def tensors_compatible(tensors, signatures):
  """Check that tensors are compatible with signatures.

  Args:
    tensors: Dict of `Tensor` objects or single `Tensor` object.
    signatures: Dict of `TensorSignature` objects or
                single `TensorSignature` object.

  Returns:
    True if all tensors are compatible, False otherwise. A `None` input is
    only compatible with `None` signatures. For dict inputs every key of
    `signatures` must be present in `tensors`; keys that appear only in
    `tensors` are not checked.
  """
  # No tensors at all: only matches the absence of signatures.
  if tensors is None:
    return signatures is None

  signatures_is_dict = isinstance(signatures, dict)

  if not isinstance(tensors, dict):
    # Single tensor as input: needs a single, non-None signature.
    if signatures is None or signatures_is_dict:
      return False
    return TensorSignature(tensors).is_compatible_with(signatures)

  # Dict of Tensors as input: needs a dict of signatures.
  if not signatures_is_dict:
    return False
  return all(
      name in tensors and
      TensorSignature(tensors[name]).is_compatible_with(signatures[name])
      for name in signatures)
141
142
def create_signatures(tensors):
  """Creates TensorSignature objects for given tensors.

  Args:
    tensors: Dict of `Tensor` objects or single `Tensor`.

  Returns:
    Dict of `TensorSignature` objects (same keys as `tensors`) or single
    `TensorSignature`. `None` is passed through as `None`.
  """
  if tensors is None:
    return None
  if not isinstance(tensors, dict):
    return TensorSignature(tensors)
  result = {}
  for name in tensors:
    result[name] = TensorSignature(tensors[name])
  return result
158
159
def create_placeholders_from_signatures(signatures):
  """Creates placeholders from given signatures.

  Args:
    signatures: Dict of `TensorSignature` objects or single `TensorSignature`,
      or `None`.

  Returns:
    Dict of `tf.placeholder` objects or single `tf.placeholder`, or `None`.
    The result mirrors the structure of `signatures`; each placeholder is
    obtained from the signature's `get_placeholder()`.
  """
  if signatures is None:
    return None
  if isinstance(signatures, dict):
    placeholders = {}
    for name in signatures:
      placeholders[name] = signatures[name].get_placeholder()
    return placeholders
  return signatures.get_placeholder()
177
178
def create_example_parser_from_signatures(signatures, examples_batch,
                                          single_feature_name="feature"):
  """Creates example parser from given signatures.

  Args:
    signatures: Dict of `TensorSignature` objects or single `TensorSignature`.
    examples_batch: string `Tensor` of serialized `Example` proto.
    single_feature_name: string, single feature name. Only used as the
      feature key when `signatures` is a single `TensorSignature`.

  Returns:
    features: `Tensor` or `dict` of `Tensor` objects. Parsed features whose
      dtype is not compatible with the signature's dtype are cast to it.
  """
  is_single = not isinstance(signatures, dict)

  # Build the parsing spec: one entry per signature.
  if is_single:
    feature_spec = {single_feature_name: signatures.get_feature_spec()}
  else:
    feature_spec = {}
    for key in signatures:
      feature_spec[key] = signatures[key].get_feature_spec()

  features = parsing_ops.parse_example(examples_batch, feature_spec)

  if is_single:
    # Returns single feature, casts if needed.
    feature = features[single_feature_name]
    if signatures.dtype.is_compatible_with(feature.dtype):
      return feature
    return math_ops.cast(feature, signatures.dtype)

  # Returns dict of features, casts if needed.
  for name in features:
    wanted_dtype = signatures[name].dtype
    if not wanted_dtype.is_compatible_with(features[name].dtype):
      features[name] = math_ops.cast(features[name], wanted_dtype)
  return features
209