1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Loader implementation for SavedModel with hermetic, language-neutral exports.
16"""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23
24from google.protobuf import message
25from google.protobuf import text_format
26
27from tensorflow.core.protobuf import meta_graph_pb2
28from tensorflow.core.protobuf import saved_model_pb2
29from tensorflow.python.framework import ops
30from tensorflow.python.lib.io import file_io
31from tensorflow.python.ops import variables
32from tensorflow.python.platform import tf_logging
33from tensorflow.python.saved_model import constants
34from tensorflow.python.saved_model import signature_def_utils
35from tensorflow.python.saved_model import utils_impl as saved_model_utils
36from tensorflow.python.training import saver as tf_saver
37from tensorflow.python.util import compat
38from tensorflow.python.util import deprecation
39from tensorflow.python.util.tf_export import tf_export
40
41
def parse_saved_model(export_dir):
  """Reads the `SavedModel` protocol buffer stored in `export_dir`.

  The binary file (`constants.SAVED_MODEL_FILENAME_PB`) is tried first; the
  text-format file (`constants.SAVED_MODEL_FILENAME_PBTXT`) is only consulted
  when the binary file does not exist.

  Args:
    export_dir: Directory containing the SavedModel file.

  Returns:
    A `SavedModel` protocol buffer.

  Raises:
    IOError: If the file does not exist, or cannot be successfully parsed.
  """
  export_dir_bytes = compat.as_bytes(export_dir)
  # Build the path to the SavedModel in pbtxt format.
  path_to_pbtxt = os.path.join(
      export_dir_bytes,
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
  # Build the path to the SavedModel in pb format.
  path_to_pb = os.path.join(
      export_dir_bytes,
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))

  # Parse the SavedModel protocol buffer.
  saved_model = saved_model_pb2.SavedModel()
  if file_io.file_exists(path_to_pb):
    # Read inside a `with` block so the file handle is always closed, even if
    # reading raises.
    with file_io.FileIO(path_to_pb, "rb") as f:
      file_content = f.read()
    try:
      saved_model.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e)))
    return saved_model

  if file_io.file_exists(path_to_pbtxt):
    with file_io.FileIO(path_to_pbtxt, "rb") as f:
      file_content = f.read()
    try:
      text_format.Merge(file_content.decode("utf-8"), saved_model)
    except text_format.ParseError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e)))
    return saved_model

  raise IOError("SavedModel file does not exist at: %s/{%s|%s}" %
                (export_dir,
                 constants.SAVED_MODEL_FILENAME_PBTXT,
                 constants.SAVED_MODEL_FILENAME_PB))
84
85
# TODO(b/120594573): Make this symbol also available as private, so that
# tensorflow_transform and tensorflow_estimator do not break.
# Until then, the private name is kept as an alias bound to the same function
# object as the public `parse_saved_model`.
_parse_saved_model = parse_saved_model
89
90
def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
  """Maps each asset tensor name to the location of its asset file.

  Asset definitions are read from the `asset_file_def` field of the meta graph
  when it is non-empty. Otherwise they are unpacked from the
  `constants.ASSETS_KEY` collection, if that collection is present.

  Args:
    export_dir: Directory where the SavedModel is located.
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    import_scope: Optional `string` -- if specified, prepend this followed by
        '/' to all returned asset tensor names.

  Returns:
    A dictionary keyed by asset tensor name. Each value is the path of the
    asset file, formed by joining `export_dir`, the assets directory name and
    the asset's filename.
  """
  collections = meta_graph_def_to_load.collection_def

  if meta_graph_def_to_load.asset_file_def:
    asset_file_defs = meta_graph_def_to_load.asset_file_def
  else:
    asset_file_defs = []
    if constants.ASSETS_KEY in collections:
      # Assets stored in the collection are packed as `Any` protos.
      for packed_asset in collections[constants.ASSETS_KEY].any_list.value:
        unpacked_asset = meta_graph_pb2.AssetFileDef()
        packed_asset.Unpack(unpacked_asset)
        asset_file_defs.append(unpacked_asset)

  # Directory inside the SavedModel that holds the asset files.
  assets_root = os.path.join(
      compat.as_bytes(export_dir), compat.as_bytes(constants.ASSETS_DIRECTORY))

  def _scoped_name(asset_file_def):
    """Returns the asset's tensor name, prefixed by the import scope if any."""
    name = asset_file_def.tensor_info.name
    return "%s/%s" % (import_scope, name) if import_scope else name

  return {
      _scoped_name(asset_file_def): os.path.join(
          compat.as_bytes(assets_root),
          compat.as_bytes(asset_file_def.filename))
      for asset_file_def in asset_file_defs
  }
132
133
def _get_main_op_tensor(
    meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY):
  """Looks up the single op registered under `init_op_key`, if any.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    init_op_key: name of collection to check; should be one of MAIN_OP_KEY
      or the deprecated LEGACY_INIT_OP_KEY

  Returns:
    The first element of the graph collection named `init_op_key` when the
    meta graph def declares that collection, and `None` otherwise.

  Raises:
    RuntimeError: If the collection def corresponding to the key does not
        list exactly one node.
  """
  # TODO(kathywu): Rename this method to _get_op_from_collection when
  # dependency from SavedModelEstimator is removed.
  collection_def = meta_graph_def_to_load.collection_def
  if init_op_key not in collection_def:
    return None

  node_names = collection_def[init_op_key].node_list.value
  if len(node_names) != 1:
    raise RuntimeError("Expected exactly one SavedModel init op. "
                       "Found: {}".format(node_names))
  # The op itself is fetched from the graph collection, not from the proto.
  return ops.get_collection(init_op_key)[0]
161
162
def _get_op_from_collection(meta_graph_def, op_key):
  """Returns the op stored in collection `op_key`, or `None` if not present.

  Delegates to `_get_main_op_tensor`, passing `op_key` as the collection name.
  """
  return _get_main_op_tensor(meta_graph_def, op_key)
165
166
167def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope):
168  """Retrieve op stored in the imported meta graph's signature def."""
169  if op_signature_key in meta_graph_def.signature_def:
170    return signature_def_utils.load_op_from_signature_def(
171        meta_graph_def.signature_def[op_signature_key], op_signature_key,
172        import_scope)
173  else:
174    return None
175
176
def get_init_op(meta_graph_def, import_scope=None):
  """Returns the init op of the imported meta graph, or `None` if there is none.

  The op is looked up in the following order, and the first one found wins:
    1. the signature def stored under `constants.INIT_OP_SIGNATURE_KEY`,
    2. the collection named `constants.MAIN_OP_KEY`,
    3. the collection named `constants.LEGACY_INIT_OP_KEY`.

  Args:
    meta_graph_def: The meta graph def that was imported.
    import_scope: Optional `string` -- forwarded to the signature-def lookup.

  Returns:
    The init op, or `None` if none of the locations above define one.

  Raises:
    RuntimeError: If a consulted collection does not list exactly one node.
  """
  init_op = _get_op_from_signature_def(
      meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope)
  # Compare against None explicitly, as get_train_op does, rather than chaining
  # with `or`: the truth value of a graph element is not a reliable "found"
  # signal (a `tf.Tensor` raises when converted to bool in graph mode).
  if init_op is None:
    init_op = _get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY)
  if init_op is None:
    init_op = _get_op_from_collection(
        meta_graph_def, constants.LEGACY_INIT_OP_KEY)
  return init_op
182
183
def get_train_op(meta_graph_def, import_scope=None):
  """Returns the train op of the imported meta graph, or `None` if absent.

  The signature def stored under `constants.TRAIN_OP_SIGNATURE_KEY` takes
  precedence; the `constants.TRAIN_OP_KEY` collection is only consulted when
  the signature-def lookup yields `None`.

  Args:
    meta_graph_def: The meta graph def that was imported.
    import_scope: Optional `string` -- forwarded to the signature-def lookup.

  Returns:
    The train op, or `None` if neither location defines one.
  """
  op_from_signature = _get_op_from_signature_def(
      meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope)
  if op_from_signature is not None:
    return op_from_signature
  return _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY)
190
191
@tf_export(v1=[
    "saved_model.contains_saved_model",
    "saved_model.maybe_saved_model_directory",
    "saved_model.loader.maybe_saved_model_directory"
])
@deprecation.deprecated_endpoints(
    "saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Only the existence of the SavedModel file (in text or binary format) is
  checked; nothing is read or parsed. A `False` result therefore means the
  directory definitely does not contain a SavedModel, while a `True` result
  means it may contain one, with no guarantee that it can be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
                '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  candidate_paths = [
      os.path.join(export_dir, filename)
      for filename in (constants.SAVED_MODEL_FILENAME_PBTXT,
                       constants.SAVED_MODEL_FILENAME_PB)
  ]
  return any(file_io.file_exists(path) for path in candidate_paths)
217
218
@tf_export("saved_model.contains_saved_model", v1=[])
def contains_saved_model(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  This is a thin wrapper that delegates to `maybe_saved_model_directory`.

  Note that the method does not load any data by itself. If the method returns
  `False`, the export directory definitely does not contain a SavedModel. If the
  method returns `True`, the export directory may contain a SavedModel but
  provides no guarantee that it can be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
                '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  return maybe_saved_model_directory(export_dir)
236
237
@tf_export(v1=["saved_model.load", "saved_model.loader.load"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.loader.load or "
    "tf.compat.v1.saved_model.load. There will be a new function for importing "
    "SavedModels in Tensorflow 2.0.")
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  Convenience entry point: builds a `SavedModelLoader` for `export_dir` and
  delegates to its `load` method.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
        correspond to the tags used when saving the variables using the
        SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.
    import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.
  """
  return SavedModelLoader(export_dir).load(
      sess, tags, import_scope, **saver_kwargs)
270
271
class SavedModelLoader(object):
  """Load graphs and restore variable values from a `SavedModel`."""

  def __init__(self, export_dir):
    """Creates a `SavedModelLoader`.

    The SavedModel protocol buffer is parsed eagerly at construction time.

    Args:
      export_dir: Directory in which the SavedModel protocol buffer and
        variables to be loaded are located.

    Raises:
      IOError: If `parse_saved_model` cannot find or parse the SavedModel file.
    """
    self._export_dir = export_dir
    self._variables_path = saved_model_utils.get_variables_path(export_dir)
    self._saved_model = parse_saved_model(export_dir)

  @property
  def export_dir(self):
    """Directory containing the SavedModel."""
    return self._export_dir

  @property
  def variables_path(self):
    """Path to variable checkpoint files."""
    return self._variables_path

  @property
  def saved_model(self):
    """SavedModel object parsed from the export directory."""
    return self._saved_model

  def get_meta_graph_def_from_tags(self, tags):
    """Return MetaGraphDef with the exact specified tags.

    Tags are compared as sets, so order and duplicates are ignored. The first
    MetaGraphDef whose tag set equals the requested one is returned.

    Args:
      tags: A list, set or other iterable of string tags that identify the
        MetaGraphDef.

    Returns:
      MetaGraphDef with the same tags.

    Raises:
      RuntimeError: if no metagraphs were found with the associated tags.
    """
    # Build the requested tag set once. Besides avoiding repeated work, this
    # keeps the comparison correct when `tags` is a one-shot iterable, which
    # would otherwise be exhausted after the first candidate.
    requested_tags = set(tags)
    for meta_graph_def in self._saved_model.meta_graphs:
      if set(meta_graph_def.meta_info_def.tags) == requested_tags:
        return meta_graph_def

    raise RuntimeError(
        "MetaGraphDef associated with tags " + str(tags).strip("[]") +
        " could not be found in SavedModel. To inspect available tag-sets in"
        " the SavedModel, please use the SavedModel CLI: `saved_model_cli`"
    )

  def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
    """Load ops and nodes from SavedModel MetaGraph into graph.

    Args:
      graph: tf.Graph object.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        the elements imported into `graph`, but it is *not* written through to
        the static `MetaGraphDef` protocol buffer.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      A tuple of
        * Saver defined by the MetaGraph, which can be used to restore the
          variable values.
        * List of `Operation`/`Tensor` objects returned from
          `tf.import_graph_def` (may be `None`).

    Raises:
      RuntimeError: if no metagraphs were found with the associated tags.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with graph.as_default():
      return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
          meta_graph_def, import_scope=import_scope, **saver_kwargs)

  def restore_variables(self, sess, saver, import_scope=None):
    """Restore SavedModel variable values into the session.

    Args:
      sess: tf.Session to restore variable values.
      saver: a tf.train.Saver object. Can be None if there are no variables in
        graph. This may be the saver returned by the load_graph() function, or a
        default `tf.train.Saver()`.
      import_scope: Optional `string` -- scope used to look up the saveable
        objects in the session's graph when no saver is given.

    Raises:
      ValueError: if `saver` is not a `tf.train.Saver` instance and it is not
        the case that `saver` is None with no saveable objects in scope.
    """
    with sess.graph.as_default():
      if (saver is None and
          not variables._all_saveable_objects(scope=import_scope)):  # pylint: disable=protected-access
        # Nothing to restore: no saver and no saveable objects in scope.
        tf_logging.info("The specified SavedModel has no variables; no "
                        "checkpoints were restored.")
      elif isinstance(saver, tf_saver.Saver):
        saver.restore(sess, self._variables_path)
      else:
        raise ValueError(
            "No tf.train.Saver object was passed to the function "
            "SavedModelLoader.restore_variables. Since there are variables in "
            "the graph, a saver is required.")

  def run_init_ops(self, sess, tags, import_scope=None):
    """Run initialization ops defined in the `MetaGraphDef`.

    The init op, if any, is run with the asset tensors fed with the paths of
    their asset files.

    Args:
      sess: tf.Session in which the init op is run.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- the scope that was prepended, followed
        by '/', to tensor names when the graph was loaded. It is used to
        resolve the asset tensors and the init op.

    Raises:
      RuntimeError: if no metagraphs were found with the associated tags.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with sess.graph.as_default():
      # Get asset tensors, if any.
      asset_tensors_dictionary = get_asset_tensors(
          self._export_dir, meta_graph_def, import_scope=import_scope)

      init_op = get_init_op(meta_graph_def, import_scope)
      if init_op is not None:
        sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary)

  def load(self, sess, tags, import_scope=None, **saver_kwargs):
    """Load the MetaGraphDef graph and restore variable values into the session.

    Runs, in order: `load_graph`, `restore_variables` and `run_init_ops`.

    Args:
      sess: tf.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      `MetagraphDef` proto of the graph that was loaded.

    Raises:
      RuntimeError: if no metagraphs were found with the associated tags.
    """
    with sess.graph.as_default():
      saver, _ = self.load_graph(sess.graph, tags, import_scope,
                                 **saver_kwargs)
      self.restore_variables(sess, saver, import_scope)
      self.run_init_ops(sess, tags, import_scope)
    return self.get_meta_graph_def_from_tags(tags)
424