1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Loader implementation for SavedModel with hermetic, language-neutral exports.
16"""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23
24from google.protobuf import message
25from google.protobuf import text_format
26
27from tensorflow.core.protobuf import graph_debug_info_pb2
28from tensorflow.core.protobuf import meta_graph_pb2
29from tensorflow.core.protobuf import saved_model_pb2
30from tensorflow.python.framework import ops
31from tensorflow.python.lib.io import file_io
32from tensorflow.python.ops import variables
33from tensorflow.python.platform import tf_logging
34from tensorflow.python.saved_model import constants
35from tensorflow.python.saved_model import signature_def_utils
36from tensorflow.python.saved_model import utils_impl as saved_model_utils
37from tensorflow.python.training import saver as tf_saver
38from tensorflow.python.util import compat
39from tensorflow.python.util import deprecation
40from tensorflow.python.util.tf_export import tf_export
41
42
def parse_saved_model_with_debug_info(export_dir):
  """Parses the SavedModel in `export_dir` together with its GraphDebugInfo.

  The debug info file is optional: when it is absent, an empty
  `GraphDebugInfo` proto is returned alongside the parsed SavedModel.

  Args:
    export_dir: Directory containing the SavedModel and GraphDebugInfo files.

  Returns:
    A `(SavedModel, GraphDebugInfo)` tuple of protocol buffers.

  Raises:
    IOError: If the SavedModel file does not exist or cannot be parsed, or if
      the debug info file exists but cannot be parsed. A missing debug info
      file is not an error.
  """
  saved_model = _parse_saved_model(export_dir)
  debug_info = graph_debug_info_pb2.GraphDebugInfo()

  debug_info_path = os.path.join(
      saved_model_utils.get_debug_dir(export_dir),
      constants.DEBUG_INFO_FILENAME_PB)
  if not file_io.file_exists(debug_info_path):
    # No debug info was exported; hand back the empty proto.
    return (saved_model, debug_info)

  with file_io.FileIO(debug_info_path, "rb") as debug_file:
    serialized_debug_info = debug_file.read()
  try:
    debug_info.ParseFromString(serialized_debug_info)
  except message.DecodeError as e:
    raise IOError("Cannot parse file %s: %s." % (debug_info_path, str(e)))
  return (saved_model, debug_info)
70
71
def parse_saved_model(export_dir):
  """Reads the savedmodel.pb or savedmodel.pbtxt file containing `SavedModel`.

  The binary file (`constants.SAVED_MODEL_FILENAME_PB`) is preferred; the text
  file (`constants.SAVED_MODEL_FILENAME_PBTXT`) is only consulted when the
  binary file does not exist.

  Args:
    export_dir: String or Pathlike, path to the directory containing the
    SavedModel file.

  Returns:
    A `SavedModel` protocol buffer.

  Raises:
    IOError: If the file does not exist, or cannot be successfully parsed
      (including a text file that is not valid UTF-8).
  """
  # Normalize once; str, bytes and path-like inputs all become bytes here.
  export_dir_bytes = compat.as_bytes(compat.path_to_str(export_dir))
  # Build the path to the SavedModel in pbtxt format.
  path_to_pbtxt = os.path.join(
      export_dir_bytes, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
  # Build the path to the SavedModel in pb format.
  path_to_pb = os.path.join(
      export_dir_bytes, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))

  # Parse the SavedModel protocol buffer.
  saved_model = saved_model_pb2.SavedModel()
  if file_io.file_exists(path_to_pb):
    # Use a context manager so the file handle is closed deterministically
    # instead of whenever the unreferenced FileIO object is collected.
    with file_io.FileIO(path_to_pb, "rb") as pb_file:
      file_content = pb_file.read()
    try:
      saved_model.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e)))
    return saved_model
  elif file_io.file_exists(path_to_pbtxt):
    with file_io.FileIO(path_to_pbtxt, "rb") as pbtxt_file:
      file_content = pbtxt_file.read()
    try:
      text_format.Merge(file_content.decode("utf-8"), saved_model)
    except (text_format.ParseError, UnicodeDecodeError) as e:
      # A non-UTF-8 text file is also an unparseable SavedModel; surface it
      # as the documented IOError rather than a raw UnicodeDecodeError.
      raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e)))
    return saved_model
  else:
    raise IOError("SavedModel file does not exist at: %s/{%s|%s}" %
                  (export_dir,
                   constants.SAVED_MODEL_FILENAME_PBTXT,
                   constants.SAVED_MODEL_FILENAME_PB))
115
116
# TODO(b/120594573): Make this symbol also available as private, so that
# tensorflow_transform and tensorflow_estimator do not break.
# Private alias of `parse_saved_model`. It is the same function object, so the
# two names can never diverge. Per the TODO above, external packages still
# reference the underscore-prefixed name, so do not remove it.
_parse_saved_model = parse_saved_model
120
121
def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
  """Gets the asset tensors, if defined in the meta graph def to load.

  Asset definitions are taken from `meta_graph_def_to_load.asset_file_def`
  when that field is non-empty. Otherwise they are taken from the
  `constants.ASSETS_KEY` collection, whose entries are `AssetFileDef` protos
  packed as `Any`.

  Args:
    export_dir: Directory where the SavedModel is located. May be a `str`,
      `bytes`, or path-like object (the same inputs `parse_saved_model`
      accepts).
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    import_scope: Optional `string` -- if specified, prepend this followed by
        '/' to all returned asset tensor names.

  Returns:
    A dictionary of asset tensors, keyed by the name of the asset tensor. Each
    value is the path (as `bytes`) of the asset file: `export_dir` joined with
    `constants.ASSETS_DIRECTORY` and the asset's filename. The path is
    absolute only if `export_dir` is.
  """
  # Collection-def that may contain the assets key.
  collection_def = meta_graph_def_to_load.collection_def

  asset_tensor_dict = {}
  asset_protos = []

  if meta_graph_def_to_load.asset_file_def:
    asset_protos = meta_graph_def_to_load.asset_file_def
  elif constants.ASSETS_KEY in collection_def:
    assets_any_proto = collection_def[constants.ASSETS_KEY].any_list.value
    for asset_any_proto in assets_any_proto:
      asset_proto = meta_graph_pb2.AssetFileDef()
      asset_any_proto.Unpack(asset_proto)
      asset_protos.append(asset_proto)

  # Location of the assets for SavedModel. `path_to_str` lets path-like
  # export dirs through; plain `as_bytes` would reject them.
  assets_directory = os.path.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.ASSETS_DIRECTORY))
  # Process each asset and add it to the asset tensor dictionary.
  for asset_proto in asset_protos:
    tensor_name = asset_proto.tensor_info.name
    if import_scope:
      tensor_name = "%s/%s" % (import_scope, tensor_name)
    # `assets_directory` is already bytes, so only the filename needs encoding.
    asset_tensor_dict[tensor_name] = os.path.join(
        assets_directory, compat.as_bytes(asset_proto.filename))

  return asset_tensor_dict
163
164
def _get_main_op_tensor(
    meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY):
  """Returns the single op registered under `init_op_key`, or `None`.

  The MetaGraphDef's `collection_def` is used only to decide whether the
  collection exists and holds exactly one node. The returned object itself is
  fetched from the collection of the current default graph
  (`ops.get_collection`). This means the graph must already have been imported.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    init_op_key: name of the collection to check; should be one of MAIN_OP_KEY
      or the deprecated LEGACY_INIT_OP_KEY

  Returns:
    The first element of the default graph's `init_op_key` collection when the
    MetaGraphDef declares that collection, otherwise `None`.

  Raises:
    RuntimeError: If the MetaGraphDef's collection for `init_op_key` does not
        contain exactly one node.
  """
  # TODO(kathywu): Rename this method to _get_op_from_collection when
  # dependency from SavedModelEstimator is removed.
  collection_def = meta_graph_def_to_load.collection_def
  if init_op_key not in collection_def:
    return None

  node_names = collection_def[init_op_key].node_list.value
  if len(node_names) != 1:
    raise RuntimeError("Expected exactly one SavedModel init op. "
                       "Found: {}".format(node_names))
  return ops.get_collection(init_op_key)[0]
192
193
def _get_op_from_collection(meta_graph_def, op_key):
  """Returns the single op stored in collection `op_key`, or `None` if absent.

  Thin wrapper over `_get_main_op_tensor`. That function raises RuntimeError
  when the collection does not hold exactly one node.
  """
  return _get_main_op_tensor(meta_graph_def, init_op_key=op_key)
196
197
198def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope):
199  """Retrieve op stored in the imported meta graph's signature def."""
200  if op_signature_key in meta_graph_def.signature_def:
201    return signature_def_utils.load_op_from_signature_def(
202        meta_graph_def.signature_def[op_signature_key], op_signature_key,
203        import_scope)
204  else:
205    return None
206
207
def get_init_op(meta_graph_def, import_scope=None):
  """Returns the SavedModel init op for `meta_graph_def`, or `None`.

  Sources are tried in order, and the first non-`None` result wins:
    1. the signature def keyed by `constants.INIT_OP_SIGNATURE_KEY`;
    2. the `constants.MAIN_OP_KEY` collection;
    3. the deprecated `constants.LEGACY_INIT_OP_KEY` collection.

  Args:
    meta_graph_def: The MetaGraphDef that was imported into the default graph.
    import_scope: Optional name scope the graph was imported under; used only
      for the signature-def lookup.

  Returns:
    The init op (graph element), or `None` if no source defines one.

  Raises:
    RuntimeError: If a consulted collection does not hold exactly one node.
  """
  # Compare against None explicitly instead of chaining with `or`. The element
  # loaded from the signature def is not guaranteed to be an Operation; it may
  # be a graph Tensor, and TensorFlow forbids using a graph Tensor as a Python
  # bool. This mirrors `get_train_op`.
  init_op = _get_op_from_signature_def(
      meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope)
  if init_op is None:
    init_op = _get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY)
  if init_op is None:
    init_op = _get_op_from_collection(
        meta_graph_def, constants.LEGACY_INIT_OP_KEY)
  return init_op
213
214
def get_train_op(meta_graph_def, import_scope=None):
  """Returns the train op for `meta_graph_def`, or `None` if none is defined.

  The signature def keyed by `constants.TRAIN_OP_SIGNATURE_KEY` takes
  precedence. The `constants.TRAIN_OP_KEY` collection is the fallback.
  """
  signature_train_op = _get_op_from_signature_def(
      meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope)
  if signature_train_op is not None:
    return signature_train_op
  return _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY)
221
222
@tf_export(v1=[
    "saved_model.contains_saved_model",
    "saved_model.maybe_saved_model_directory",
    "saved_model.loader.maybe_saved_model_directory"
])
@deprecation.deprecated_endpoints(
    "saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Note that the method does not load any data by itself. If the method returns
  `False`, the export directory definitely does not contain a SavedModel. If
  the method returns `True`, the export directory may contain a SavedModel but
  provides no guarantee that it can be loaded.

  Args:
    export_dir: Absolute path to possible export location, as a `str`, `bytes`
                or path-like object. For example, '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  # Build the candidate paths the same way `parse_saved_model` does, so that
  # any directory it accepts is also recognized here. A plain
  # `os.path.join(export_dir, <str>)` raises TypeError for bytes input.
  export_dir_bytes = compat.as_bytes(compat.path_to_str(export_dir))
  txt_path = os.path.join(
      export_dir_bytes, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
  pb_path = os.path.join(
      export_dir_bytes, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
  return file_io.file_exists(txt_path) or file_io.file_exists(pb_path)
248
249
@tf_export("saved_model.contains_saved_model", v1=[])
def contains_saved_model(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  This only checks for the presence of a SavedModel file; nothing is loaded.
  A `False` result means the directory definitely holds no SavedModel. A
  `True` result means one may be present, with no guarantee that it loads.

  Args:
    export_dir: Absolute string path to possible export location. For example,
                '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  # TF2 endpoint; shares its implementation with the v1 function.
  has_saved_model = maybe_saved_model_directory(export_dir)
  return has_saved_model
267
268
@tf_export(v1=["saved_model.load", "saved_model.loader.load"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.loader.load or "
    "tf.compat.v1.saved_model.load. There will be a new function for importing "
    "SavedModels in Tensorflow 2.0.")
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  Convenience wrapper: builds a `SavedModelLoader` for `export_dir` and calls
  its `load` method.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
        correspond to the tags used when saving the variables using the
        SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.
    import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.
  """
  return SavedModelLoader(export_dir).load(
      sess, tags, import_scope, **saver_kwargs)
301
302
class SavedModelLoader(object):
  """Load graphs and restore variable values from a `SavedModel`.

  The SavedModel proto is parsed once, in the constructor. Later calls select
  a MetaGraphDef from that cached proto by its exact tag set.
  """

  def __init__(self, export_dir):
    """Creates a `SavedModelLoader`.

    Args:
      export_dir: Directory in which the SavedModel protocol buffer and
        variables to be loaded are located.

    Raises:
      IOError: If the SavedModel file does not exist or cannot be parsed
        (propagated from `parse_saved_model`).
    """
    self._export_dir = export_dir
    self._variables_path = saved_model_utils.get_variables_path(export_dir)
    self._saved_model = parse_saved_model(export_dir)

  @property
  def export_dir(self):
    """Directory containing the SavedModel."""
    return self._export_dir

  @property
  def variables_path(self):
    """Path to variable checkpoint files."""
    return self._variables_path

  @property
  def saved_model(self):
    """SavedModel object parsed from the export directory."""
    return self._saved_model

  def get_meta_graph_def_from_tags(self, tags):
    """Return MetaGraphDef with the exact specified tags.

    Matching is by set equality, so order and duplicates in `tags` are
    ignored. The first MetaGraphDef whose tag set matches is returned.

    Args:
      tags: A list or set of string tags that identify the MetaGraphDef. Any
        iterable of strings works; it is consumed exactly once.

    Returns:
      MetaGraphDef with the same tags.

    Raises:
      RuntimeError: if no metagraphs were found with the associated tags.
    """
    # Build the requested tag set once, outside the loop. Rebuilding it per
    # iteration would exhaust a one-shot iterable on the first comparison, and
    # every later MetaGraphDef would then be compared against an empty set.
    requested_tags = set(tags)
    available_tags = []
    for meta_graph_def in self._saved_model.meta_graphs:
      meta_graph_tags = set(meta_graph_def.meta_info_def.tags)
      available_tags.append(meta_graph_tags)
      if meta_graph_tags == requested_tags:
        return meta_graph_def

    raise RuntimeError(
        "MetaGraphDef associated with tags " + str(tags).strip("[]") +
        " could not be found in SavedModel. To inspect available tag-sets in"
        " the SavedModel, please use the SavedModel CLI: `saved_model_cli`"
        "\navailable_tags: " + str(available_tags))

  def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
    """Load ops and nodes from SavedModel MetaGraph into graph.

    Args:
      graph: tf.Graph object.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      A tuple of
        * Saver defined by the MetaGraph, which can be used to restore the
          variable values.
        * List of `Operation`/`Tensor` objects returned from
          `tf.import_graph_def` (may be `None`).

    Raises:
      RuntimeError: if no MetaGraphDef matches `tags`.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    # Import into `graph` rather than whatever graph is currently the default.
    with graph.as_default():
      return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
          meta_graph_def, import_scope=import_scope, **saver_kwargs)

  def restore_variables(self, sess, saver, import_scope=None):
    """Restore SavedModel variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      saver: a tf.compat.v1.train.Saver object. Can be None if there are no
        variables in graph. This may be the saver returned by the load_graph()
        function, or a default `tf.compat.v1.train.Saver()`.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.

    Raises:
      ValueError: if `saver` is None while the graph has saveable objects in
        `import_scope`, or if `saver` is neither None nor a `tf_saver.Saver`
        instance.
    """
    with sess.graph.as_default():
      if (saver is None and
          not variables._all_saveable_objects(scope=import_scope)):  # pylint: disable=protected-access
        # Nothing to restore: a variable-free model legitimately has no saver.
        tf_logging.info("The specified SavedModel has no variables; no "
                        "checkpoints were restored.")
      elif isinstance(saver, tf_saver.Saver):
        saver.restore(sess, self._variables_path)
      else:
        raise ValueError(
            "No tf.train.Saver object was passed to the function "
            "SavedModelLoader.restore_variables. Since there are variables in "
            "the graph, a saver is required.")

  def run_init_ops(self, sess, tags, import_scope=None):
    """Run initialization ops defined in the `MetaGraphDef`.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.

    Raises:
      RuntimeError: if no MetaGraphDef matches `tags`.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with sess.graph.as_default():
      # Get asset tensors, if any. They are fed to the init op so that it sees
      # the asset file paths inside this export directory.
      asset_tensors_dictionary = get_asset_tensors(
          self._export_dir, meta_graph_def, import_scope=import_scope)

      init_op = get_init_op(meta_graph_def, import_scope)
      if init_op is not None:
        sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary)

  def load(self, sess, tags, import_scope=None, **saver_kwargs):
    """Load the MetaGraphDef graph and restore variable values into the session.

    Imports the graph into `sess.graph`, restores variables, and then runs the
    init op (if any), in that order.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      `MetagraphDef` proto of the graph that was loaded.
    """
    with sess.graph.as_default():
      saver, _ = self.load_graph(sess.graph, tags, import_scope,
                                 **saver_kwargs)
      self.restore_variables(sess, saver, import_scope)
      self.run_init_ops(sess, tags, import_scope)
    return self.get_meta_graph_def_from_tags(tags)
457