1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SavedModel utility functions implementation."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23from tensorflow.core.framework import types_pb2
24from tensorflow.core.protobuf import meta_graph_pb2
25from tensorflow.python.eager import context
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.lib.io import file_io
31from tensorflow.python.saved_model import constants
32from tensorflow.python.util import compat
33from tensorflow.python.util import deprecation
34from tensorflow.python.util.tf_export import tf_export
35
36
37# TensorInfo helpers.
38
39
@tf_export(v1=["saved_model.build_tensor_info",
               "saved_model.utils.build_tensor_info"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.utils.build_tensor_info or "
    "tf.compat.v1.saved_model.build_tensor_info.")
def build_tensor_info(tensor):
  """Builds a TensorInfo proto describing a graph Tensor or SparseTensor.

  This is the public (v1-only) entry point. It refuses to run when eager
  execution is enabled and otherwise delegates to
  `build_tensor_info_internal`.

  Args:
    tensor: Tensor or SparseTensor whose name, dtype and shape populate the
        TensorInfo. For a SparseTensor, the names of its three constituent
        Tensors are recorded instead of a single name.

  Returns:
    A TensorInfo protocol buffer describing `tensor`.

  Raises:
    RuntimeError: If eager execution is enabled.
  """
  if not context.executing_eagerly():
    return build_tensor_info_internal(tensor)
  # Graph-mode tensor names are what TensorInfo records.
  raise RuntimeError("build_tensor_info is not supported in Eager mode.")
64
65
def build_tensor_info_internal(tensor):
  """Builds a TensorInfo proto from a Tensor or SparseTensor.

  Unlike `build_tensor_info`, this performs no eager-execution check.

  Args:
    tensor: A Tensor or SparseTensor. Its `dtype` and `get_shape()` fill the
        proto's `dtype` and `tensor_shape` fields.

  Returns:
    A TensorInfo whose encoding is `coo_sparse` (the names of the values,
    indices and dense_shape Tensors) when `tensor` is a SparseTensor, and
    `name` (the tensor's own name) otherwise.
  """
  dtype_enum = dtypes.as_dtype(tensor.dtype).as_datatype_enum
  shape_proto = tensor.get_shape().as_proto()
  info = meta_graph_pb2.TensorInfo(dtype=dtype_enum, tensor_shape=shape_proto)

  if not isinstance(tensor, sparse_tensor.SparseTensor):
    info.name = tensor.name
    return info

  # A SparseTensor has no single name; record each component tensor instead.
  coo = info.coo_sparse
  coo.values_tensor_name = tensor.values.name
  coo.indices_tensor_name = tensor.indices.name
  coo.dense_shape_tensor_name = tensor.dense_shape.name
  return info
78
79
def build_tensor_info_from_op(op):
  """Utility function to build TensorInfo proto from an Op.

  Note that this function should be used with caution. It is strictly restricted
  to TensorFlow internal use-cases only. Please make sure you do need it before
  using it.

  This utility function overloads the TensorInfo proto by setting the name to
  the Op's name, dtype to DT_INVALID and tensor_shape to an unknown shape. One
  typical usage is for the Op of the call site for the defunned function:
  ```python
    @function.defun
    def some_variable_initialization_fn(value_a, value_b):
      a = value_a
      b = value_b

    value_a = constant_op.constant(1, name="a")
    value_b = constant_op.constant(2, name="b")
    op_info = utils.build_tensor_info_from_op(
        some_variable_initialization_fn(value_a, value_b))
  ```

  Args:
    op: An Op whose name is used to build the TensorInfo. The name that points
        to the Op could be fetched at run time in the Loader session.

  Returns:
    A TensorInfo protocol buffer whose `name` is `op.name`, whose `dtype` is
    DT_INVALID and whose `tensor_shape` is an unknown shape.
  """
  # DT_INVALID plus an unknown shape marks this entry as naming an Op rather
  # than a real Tensor.
  return meta_graph_pb2.TensorInfo(
      dtype=types_pb2.DT_INVALID,
      tensor_shape=tensor_shape.unknown_shape().as_proto(),
      name=op.name)
113
114
@tf_export(v1=["saved_model.get_tensor_from_tensor_info",
               "saved_model.utils.get_tensor_from_tensor_info"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or "
    "tf.compat.v1.saved_model.get_tensor_from_tensor_info.")
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Resolves a TensorInfo proto to the Tensor or SparseTensor it describes.

  Args:
    tensor_info: A TensorInfo proto describing a Tensor or SparseTensor.
    graph: The tf.Graph to search. When falsy (e.g. None), the current
        default graph is used.
    import_scope: If not None, this scope is prepended to every name stored
        in `tensor_info` before lookup.

  Returns:
    For the `name` encoding, the Tensor in `graph`. For the `coo_sparse`
    encoding, a SparseTensor built from its indices, values and dense_shape
    Tensors.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed.
  """
  if not graph:
    graph = ops.get_default_graph()

  def _lookup(tensor_name):
    scoped_name = ops.prepend_name_scope(tensor_name, import_scope=import_scope)
    return graph.get_tensor_by_name(scoped_name)

  which = tensor_info.WhichOneof("encoding")
  if which == "coo_sparse":
    coo = tensor_info.coo_sparse
    indices = _lookup(coo.indices_tensor_name)
    values = _lookup(coo.values_tensor_name)
    dense_shape = _lookup(coo.dense_shape_tensor_name)
    return sparse_tensor.SparseTensor(indices, values, dense_shape)
  if which == "name":
    return _lookup(tensor_info.name)
  # Any other value, including an unset oneof, is unsupported here.
  raise ValueError("Invalid TensorInfo.encoding: %s" % which)
153
154
def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Resolves a TensorInfo proto to the graph element (Op or Tensor) it names.

  Only the `name` field of `tensor_info` is consulted, so this also works for
  protos that name an Op rather than a Tensor.

  Args:
    tensor_info: A TensorInfo proto describing an Op or Tensor by name.
    graph: The tf.Graph to search. When falsy (e.g. None), the current
      default graph is used.
    import_scope: If not None, this scope is prepended to the name stored in
      `tensor_info` before lookup.

  Returns:
    The Op or Tensor in `graph` described by `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to an op or tensor in `graph`
  """
  target_graph = graph or ops.get_default_graph()
  scoped_name = ops.prepend_name_scope(tensor_info.name,
                                       import_scope=import_scope)
  return target_graph.as_graph_element(scoped_name)
174
175
176# Path helpers.
177
178
def get_or_create_variables_dir(export_dir):
  """Returns the variables sub-directory of `export_dir`, creating it if absent.

  Args:
    export_dir: Path of the SavedModel export directory.

  Returns:
    The variables directory path, as computed by `get_variables_dir`.
  """
  target_dir = get_variables_dir(export_dir)
  already_exists = file_io.file_exists(target_dir)
  if not already_exists:
    file_io.recursive_create_dir(target_dir)
  return target_dir
185
186
def get_variables_dir(export_dir):
  """Returns the path of the variables sub-directory inside a SavedModel.

  Args:
    export_dir: Path of the SavedModel export directory.

  Returns:
    `export_dir` joined with `constants.VARIABLES_DIRECTORY`. Both parts are
    converted with `compat.as_text` before joining.
  """
  base = compat.as_text(export_dir)
  subdir = compat.as_text(constants.VARIABLES_DIRECTORY)
  return os.path.join(base, subdir)
192
193
def get_variables_path(export_dir):
  """Returns the variables path, used as the prefix for checkpoint files.

  Args:
    export_dir: Path of the SavedModel export directory.

  Returns:
    The variables sub-directory (see `get_variables_dir`) joined with
    `constants.VARIABLES_FILENAME`. Both parts are converted with
    `compat.as_text` before joining.
  """
  variables_dir = compat.as_text(get_variables_dir(export_dir))
  prefix_name = compat.as_text(constants.VARIABLES_FILENAME)
  return os.path.join(variables_dir, prefix_name)
199
200
def get_or_create_assets_dir(export_dir):
  """Returns the assets sub-directory of `export_dir`, creating it if absent.

  Args:
    export_dir: Path of the SavedModel export directory.

  Returns:
    The assets directory path, as computed by `get_assets_dir`.
  """
  target_dir = get_assets_dir(export_dir)
  already_exists = file_io.file_exists(target_dir)
  if not already_exists:
    file_io.recursive_create_dir(target_dir)
  return target_dir
209
210
def get_assets_dir(export_dir):
  """Returns the path of the assets directory inside a SavedModel.

  Args:
    export_dir: Path of the SavedModel export directory.

  Returns:
    `export_dir` joined with `constants.ASSETS_DIRECTORY`. Both parts are
    converted with `compat.as_text` before joining.
  """
  base = compat.as_text(export_dir)
  subdir = compat.as_text(constants.ASSETS_DIRECTORY)
  return os.path.join(base, subdir)
216