1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SavedModel utility functions implementation."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23from tensorflow.core.framework import types_pb2
24from tensorflow.core.protobuf import meta_graph_pb2
25from tensorflow.core.protobuf import struct_pb2
26from tensorflow.python.eager import context
27from tensorflow.python.framework import composite_tensor
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.lib.io import file_io
33from tensorflow.python.saved_model import constants
34from tensorflow.python.saved_model import nested_structure_coder
35from tensorflow.python.util import compat
36from tensorflow.python.util import deprecation
37from tensorflow.python.util import nest
38from tensorflow.python.util.tf_export import tf_export
39
40
41# TensorInfo helpers.
42
43
@tf_export(v1=["saved_model.build_tensor_info",
               "saved_model.utils.build_tensor_info"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.utils.build_tensor_info or "
    "tf.compat.v1.saved_model.build_tensor_info.")
def build_tensor_info(tensor):
  """Builds a TensorInfo proto describing `tensor` (graph mode only).

  Args:
    tensor: A Tensor, SparseTensor or other CompositeTensor.
      - A dense Tensor contributes its name, dtype and shape.
      - A SparseTensor contributes the names of its indices, values and
        dense_shape component tensors.
      - Any other CompositeTensor is encoded through its type spec plus its
        flattened component tensors.

  Returns:
    A `meta_graph_pb2.TensorInfo` proto describing `tensor`.

  Raises:
    RuntimeError: If called while eager execution is enabled.
  """
  if not context.executing_eagerly():
    return build_tensor_info_internal(tensor)
  raise RuntimeError("build_tensor_info is not supported in Eager mode.")
68
69
def build_tensor_info_internal(tensor):
  """Builds a TensorInfo proto for `tensor` without checking execution mode.

  Dense Tensors are recorded by name. SparseTensors are recorded by the names
  of their three component tensors in the `coo_sparse` field. Any other
  CompositeTensor is delegated to `_build_composite_tensor_info_internal`.

  Args:
    tensor: A Tensor, SparseTensor or CompositeTensor.

  Returns:
    A `meta_graph_pb2.TensorInfo` proto.
  """
  is_sparse = isinstance(tensor, sparse_tensor.SparseTensor)
  if not is_sparse and isinstance(tensor, composite_tensor.CompositeTensor):
    # SparseTensor keeps its dedicated COO encoding; every other composite
    # uses the generic type-spec encoding.
    return _build_composite_tensor_info_internal(tensor)

  info = meta_graph_pb2.TensorInfo(
      dtype=dtypes.as_dtype(tensor.dtype).as_datatype_enum,
      tensor_shape=tensor.get_shape().as_proto())
  if not is_sparse:
    info.name = tensor.name
    return info

  coo = info.coo_sparse
  coo.values_tensor_name = tensor.values.name
  coo.indices_tensor_name = tensor.indices.name
  coo.dense_shape_tensor_name = tensor.dense_shape.name
  return info
86
87
def _build_composite_tensor_info_internal(tensor):
  """Builds a TensorInfo proto for a non-sparse CompositeTensor.

  The tensor's TypeSpec is serialized with `nested_structure_coder`. Each
  component produced by flattening the tensor (with `expand_composites=True`)
  is encoded via `build_tensor_info_internal` and appended in flatten order.

  Args:
    tensor: A CompositeTensor (other than SparseTensor).

  Returns:
    A `meta_graph_pb2.TensorInfo` whose `composite_tensor` field holds the
    encoded type spec and one TensorInfo per flattened component.
  """
  type_spec = tensor._type_spec  # pylint: disable=protected-access
  result = meta_graph_pb2.TensorInfo()
  coder = nested_structure_coder.StructureCoder()
  encoded = coder.encode_structure(type_spec)

  composite_info = result.composite_tensor
  composite_info.type_spec.CopyFrom(encoded.type_spec_value)

  flat_components = nest.flatten(tensor, expand_composites=True)
  for piece in flat_components:
    composite_info.components.add().CopyFrom(build_tensor_info_internal(piece))
  return result
99
100
def build_tensor_info_from_op(op):
  """Utility function to build TensorInfo proto from an Op.

  Note that this function should be used with caution. It is strictly restricted
  to TensorFlow internal use-cases only. Please make sure you do need it before
  using it.

  This utility function overloads the TensorInfo proto by setting the name to
  the Op's name, dtype to DT_INVALID and tensor_shape to an unknown shape. One
  typical usage is for the Op of the call site for the defunned function:
  ```python
    @function.defun
    def some_variable_initialization_fn(value_a, value_b):
      a = value_a
      b = value_b

    value_a = constant_op.constant(1, name="a")
    value_b = constant_op.constant(2, name="b")
    op_info = utils.build_tensor_info_from_op(
        some_variable_initialization_fn(value_a, value_b))
  ```

  Args:
    op: An Op whose name is used to build the TensorInfo. The name that points
        to the Op could be fetched at run time in the Loader session.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "build_tensor_info_from_op is not supported in Eager mode.")
  # DT_INVALID plus an unknown shape signal that this TensorInfo names an Op
  # rather than a real tensor value.
  return meta_graph_pb2.TensorInfo(
      dtype=types_pb2.DT_INVALID,
      tensor_shape=tensor_shape.unknown_shape().as_proto(),
      name=op.name)
140
141
@tf_export(v1=["saved_model.get_tensor_from_tensor_info",
               "saved_model.utils.get_tensor_from_tensor_info"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or "
    "tf.compat.v1.saved_model.get_tensor_from_tensor_info.")
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Looks up the tensor that a TensorInfo proto refers to.

  Args:
    tensor_info: A TensorInfo proto whose `encoding` oneof is one of:
      - `name` (a dense Tensor);
      - `coo_sparse` (a SparseTensor);
      - `composite_tensor` (a CompositeTensor).
    graph: The tf.Graph to search. If None, the current default graph is
      used.
    import_scope: If not None, this scope is prepended to every tensor name
      in `tensor_info` before lookup.

  Returns:
    The Tensor, SparseTensor or CompositeTensor in `graph` described by
    `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed.
  """
  if not graph:
    graph = ops.get_default_graph()

  def lookup(tensor_name):
    scoped_name = ops.prepend_name_scope(
        tensor_name, import_scope=import_scope)
    return graph.get_tensor_by_name(scoped_name)

  encoding = tensor_info.WhichOneof("encoding")

  if encoding == "name":
    return lookup(tensor_info.name)

  if encoding == "coo_sparse":
    coo = tensor_info.coo_sparse
    indices = lookup(coo.indices_tensor_name)
    values = lookup(coo.values_tensor_name)
    dense_shape = lookup(coo.dense_shape_tensor_name)
    return sparse_tensor.SparseTensor(indices, values, dense_shape)

  if encoding == "composite_tensor":
    composite = tensor_info.composite_tensor
    decoder = nested_structure_coder.StructureCoder()
    spec = decoder.decode_proto(
        struct_pb2.StructuredValue(type_spec_value=composite.type_spec))
    flat_tensors = [lookup(part.name) for part in composite.components]
    return nest.pack_sequence_as(spec, flat_tensors, expand_composites=True)

  raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
190
191
def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Resolves the graph element (Op or Tensor) named by a TensorInfo proto.

  Args:
    tensor_info: A TensorInfo proto whose `name` field refers to an Op or a
      Tensor.
    graph: The tf.Graph to search. If None, the current default graph is
      used.
    import_scope: If not None, this scope is prepended to `tensor_info.name`
      before lookup.

  Returns:
    The Op or Tensor in `graph` described by `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to an op or tensor in
      `graph`.
  """
  target_graph = graph if graph else ops.get_default_graph()
  element_name = ops.prepend_name_scope(
      tensor_info.name, import_scope=import_scope)
  return target_graph.as_graph_element(element_name)
211
212
213# Path helpers.
214
215
def get_or_create_variables_dir(export_dir):
  """Returns the SavedModel's variables sub-directory, creating it if absent.

  Args:
    export_dir: Path to the SavedModel export directory.

  Returns:
    The path returned by `get_variables_dir(export_dir)`.
  """
  path = get_variables_dir(export_dir)
  if file_io.file_exists(path):
    return path
  file_io.recursive_create_dir(path)
  return path
222
223
def get_variables_dir(export_dir):
  """Returns the path of the variables sub-directory in the SavedModel.

  Args:
    export_dir: Path to the SavedModel export directory; converted with
      `compat.as_text`.

  Returns:
    `export_dir` joined with `constants.VARIABLES_DIRECTORY`, as text.
  """
  base_dir = compat.as_text(export_dir)
  sub_dir = compat.as_text(constants.VARIABLES_DIRECTORY)
  return os.path.join(base_dir, sub_dir)
229
230
def get_variables_path(export_dir):
  """Returns the prefix under which variable checkpoint files are stored.

  Args:
    export_dir: Path to the SavedModel export directory.

  Returns:
    The variables sub-directory joined with `constants.VARIABLES_FILENAME`,
    as text.
  """
  variables_dir = compat.as_text(get_variables_dir(export_dir))
  checkpoint_name = compat.as_text(constants.VARIABLES_FILENAME)
  return os.path.join(variables_dir, checkpoint_name)
236
237
def get_or_create_assets_dir(export_dir):
  """Returns the SavedModel's assets sub-directory, creating it if absent.

  Args:
    export_dir: Path to the SavedModel export directory.

  Returns:
    The path returned by `get_assets_dir(export_dir)`.
  """
  assets_dir = get_assets_dir(export_dir)
  already_exists = file_io.file_exists(assets_dir)
  if not already_exists:
    file_io.recursive_create_dir(assets_dir)
  return assets_dir
246
247
def get_assets_dir(export_dir):
  """Returns the path of the assets sub-directory in the SavedModel.

  Args:
    export_dir: Path to the SavedModel export directory; converted with
      `compat.as_text`.

  Returns:
    `export_dir` joined with `constants.ASSETS_DIRECTORY`, as text.
  """
  base_dir = compat.as_text(export_dir)
  assets_name = compat.as_text(constants.ASSETS_DIRECTORY)
  return os.path.join(base_dir, assets_name)
253
254
def get_or_create_debug_dir(export_dir):
  """Returns the SavedModel's debug sub-directory, creating it if absent.

  Args:
    export_dir: Path to the SavedModel export directory.

  Returns:
    The path returned by `get_debug_dir(export_dir)`.
  """
  path = get_debug_dir(export_dir)
  if file_io.file_exists(path):
    return path
  file_io.recursive_create_dir(path)
  return path
263
264
def get_saved_model_pbtxt_path(export_dir):
  """Returns the path of the SavedModel file `constants.SAVED_MODEL_FILENAME_PBTXT`.

  Unlike the directory helpers in this module, the result is bytes, not text.

  Args:
    export_dir: Path to the SavedModel export directory; normalized with
      `compat.path_to_str` and then converted to bytes.

  Returns:
    `export_dir` joined with `constants.SAVED_MODEL_FILENAME_PBTXT`, as bytes.
  """
  export_dir_bytes = compat.as_bytes(compat.path_to_str(export_dir))
  filename_bytes = compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT)
  return os.path.join(export_dir_bytes, filename_bytes)
269
270
def get_saved_model_pb_path(export_dir):
  """Returns the path of the SavedModel file `constants.SAVED_MODEL_FILENAME_PB`.

  Unlike the directory helpers in this module, the result is bytes, not text.

  Args:
    export_dir: Path to the SavedModel export directory; normalized with
      `compat.path_to_str` and then converted to bytes.

  Returns:
    `export_dir` joined with `constants.SAVED_MODEL_FILENAME_PB`, as bytes.
  """
  export_dir_bytes = compat.as_bytes(compat.path_to_str(export_dir))
  filename_bytes = compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)
  return os.path.join(export_dir_bytes, filename_bytes)
275
276
def get_debug_dir(export_dir):
  """Returns the path of the debug sub-directory in the SavedModel.

  Args:
    export_dir: Path to the SavedModel export directory; converted with
      `compat.as_text`.

  Returns:
    `export_dir` joined with `constants.DEBUG_DIRECTORY`, as text.
  """
  base_dir = compat.as_text(export_dir)
  debug_name = compat.as_text(constants.DEBUG_DIRECTORY)
  return os.path.join(base_dir, debug_name)
281