1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Import a trackable object from a SavedModel."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import os
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.ops import init_ops
28from tensorflow.python.ops import resource_variable_ops
29from tensorflow.python.ops import variables
30from tensorflow.python.saved_model import function_deserialization
31from tensorflow.python.saved_model import load_v1_in_v2
32from tensorflow.python.saved_model import loader_impl
33from tensorflow.python.saved_model import nested_structure_coder
34from tensorflow.python.saved_model import revived_types
35from tensorflow.python.saved_model import utils_impl as saved_model_utils
36from tensorflow.python.training.tracking import base
37from tensorflow.python.training.tracking import graph_view
38from tensorflow.python.training.tracking import tracking
39from tensorflow.python.training.tracking import util
40from tensorflow.python.util import nest
41from tensorflow.python.util.tf_export import tf_export
42
43
class _Loader(object):
  """Helper class to load an object-based SavedModel.

  All of the work happens in the constructor: objects are recreated from the
  object graph proto, restored functions are wired up, and checkpointed state
  is restored. Afterwards `get(node_id)` returns the recreated objects.
  """

  def __init__(self, object_graph_proto, saved_model_proto, export_dir):
    """Recreates the objects described by `object_graph_proto`.

    Args:
      object_graph_proto: Object graph proto whose `nodes` and
        `concrete_functions` fields describe what to recreate.
      saved_model_proto: SavedModel proto; only `meta_graphs[0]` is read.
      export_dir: SavedModel directory, used to locate assets and variables.
    """
    meta_graph = saved_model_proto.meta_graphs[0]
    self._asset_file_def = meta_graph.asset_file_def
    # Lets `_recreate_constant` find a graph node's attributes by node name.
    self._operation_attributes = {
        node.name: node.attr for node in meta_graph.graph_def.node}
    self._proto = object_graph_proto
    self._export_dir = export_dir
    self._concrete_functions = (
        function_deserialization.load_function_def_library(
            meta_graph.graph_def.library))
    self._load_all()
    # TODO(b/124045874): There are limitations with functions whose captures
    # trigger other functions to be executed. For now it is only guaranteed to
    # work if the captures of a function only trigger functions without
    # captures.
    self._setup_functions_structures()
    self._setup_functions_captures()
    self._restore_checkpoint()

    # Register the initializer of every restored resource in the
    # TABLE_INITIALIZERS collection.
    for node in self._nodes:
      if isinstance(node, tracking.TrackableResource):
        init_op = node._initialize()  # pylint: disable=protected-access
        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)

  def _setup_functions_structures(self):
    """Setup structure for inputs and outputs of restored functions."""
    coder = nested_structure_coder.StructureCoder()
    # Sorted by name so the functions are processed in a deterministic order.
    for name, proto in sorted(self._proto.concrete_functions.items()):
      concrete_function = self._concrete_functions[name]
      # By setting the structured_outputs directly, we can rely on this
      # function_lib.ConcreteFunction object to perform the output repacking
      # logic. The only limitation of that logic is that it only works
      # with output that is convertible to Tensors and the conversion
      # always happens. For example tf.TensorShape([2, 3]) will be
      # converted to Tensor representing [2, 3].
      original_outputs = coder.decode_proto(proto.output_signature)
      # The original_outputs here had Tensors converted to TensorSpecs, so
      # the restored function's structured_outputs field will not be
      # exactly the same. Fortunately the repacking logic cares only about
      # the structure.
      # TODO(vbardiovsky): Should we just replicate the structures, with
      # Nones instead of real objects?
      concrete_function._func_graph.structured_outputs = original_outputs  # pylint: disable=protected-access
      concrete_function._func_graph.structured_input_signature = (  # pylint: disable=protected-access
          coder.decode_proto(proto.canonicalized_input_signature))

  def _setup_functions_captures(self):
    """Setup captures and variables in restored functions.

    For every concrete function listed in the object graph, the node ids in
    its `bound_inputs` are resolved to tensors and stored as the function's
    captured inputs. The bound nodes of kind "variable" are additionally
    recorded as the variables of the function's graph.
    """
    for name, proto in sorted(self._proto.concrete_functions.items()):
      concrete_function = self._concrete_functions[name]
      bound_inputs = [
          self._get_tensor_from_node(node_id)
          for node_id in proto.bound_inputs]
      bound_variables = [
          self._nodes[node_id]
          for node_id in proto.bound_inputs
          if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
      ]
      # TODO(andresp): This is only injecting the captured inputs into the
      # concrete function, note that we did not modify the FuncGraph
      # itself.
      concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
      concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access

  def _get_tensor_from_node(self, node_id):
    """Resolves a node id into a tensor to be captured for a function.

    Args:
      node_id: Index into `self._nodes` of the object to capture.

    Returns:
      The variable handle, asset path, tensor, or resource handle for the
      recreated object, depending on its type.

    Raises:
      ValueError: If the object is none of the supported types.
    """
    with ops.init_scope():
      obj = self._nodes[node_id]
      if resource_variable_ops.is_resource_variable(obj):
        return obj.handle
      if isinstance(obj, tracking.TrackableAsset):
        return obj.asset_path
      if tensor_util.is_tensor(obj):
        return obj
      if isinstance(obj, tracking.TrackableResource):
        # Note: this executes restored functions in the TrackableResource.
        return obj.resource_handle
      raise ValueError("Can't convert node %s to tensor" % (type(obj)))

  def _load_all(self):
    """Load all saved objects and wire their properties.

    Fills `self._nodes` with one recreated object per node of the object
    graph proto, in node-id order, and attaches each node's children to it
    under their saved local names.
    """
    # Maps from node ids to recreated objects
    nodes = {}
    # Maps from node ids to setter functions (same signature as setattr) for
    # setting dependencies.
    node_setters = {}

    # Figure out which objects are slot variables. These objects are created
    # with Optimizer.add_slot rather than _recreate_variable.
    slot_variable_node_ids = set()
    for proto in self._proto.nodes:
      for slot_variable_proto in proto.slot_variables:
        slot_variable_node_ids.add(slot_variable_proto.slot_variable_node_id)

    # Re-create everything except slot variables.
    for node_id, proto in enumerate(self._proto.nodes):
      if node_id in slot_variable_node_ids:
        # Defer recreating slot variables so we can use the public Optimizer
        # interface.
        continue
      node, setter = self._recreate(proto)
      nodes[node_id] = node
      node_setters[node_id] = setter

    # Now that we have created the variables being optimized, we have enough
    # information to re-create slot variables for them.
    for node_id, proto in enumerate(self._proto.nodes):
      if not proto.slot_variables:
        # Only nodes that own slot variables are looked up in `nodes`. A
        # deferred slot-variable node may not have been created yet when its
        # id comes before its optimizer's, and must not be indexed here.
        continue
      optimizer_object = nodes[node_id]
      for slot_variable_proto in proto.slot_variables:
        optimized_variable = nodes[
            slot_variable_proto.original_variable_node_id]
        slot_variable = optimizer_object.add_slot(
            var=optimized_variable,
            slot_name=slot_variable_proto.slot_name)
        nodes[slot_variable_proto.slot_variable_node_id] = slot_variable
        node_setters[slot_variable_proto.slot_variable_node_id] = setattr

    self._nodes = []

    # After creating the objects, construct the edges between the objects.
    for node_id, object_proto in enumerate(self._proto.nodes):
      obj = nodes[node_id]
      setter = node_setters[node_id]
      self._nodes.append(obj)

      for reference in object_proto.children:
        setter(obj, reference.local_name, nodes[reference.node_id])
        # Note: if an object has an attribute `__call__`, add a `__call__`
        # method to its class so that `obj()` syntax works. Objects built by
        # `_recreate_user_object` each get their own class, so `callable`
        # can still be used to find out whether one such object is callable.
        if reference.local_name == "__call__":
          setattr(type(obj), "__call__", _call_attribute)

  def _restore_checkpoint(self):
    """Load state from checkpoint into the deserialized objects.

    Raises:
      NotImplementedError: If restore ops are produced for an object that is
        not a resource variable.
    """
    variables_path = saved_model_utils.get_variables_path(self._export_dir)
    # TODO(andresp): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
    saver._file_prefix_placeholder = constant_op.constant(variables_path)
    load_status = saver.restore(variables_path)
    load_status.assert_existing_objects_matched()
    checkpoint = load_status._checkpoint

    # When running in eager mode, the `restore` call above has already run and
    # restored the state of trackables, call `position.restore_ops()` will
    # return an empty list as there is nothing left to do. In graph mode, that
    # will return the list of ops that must run to restore the object on that
    # position. We have to wire them in the initializers of the objects so that
    # they get initialized properly when using common practices (e.g. the ones
    # used by ManagedSession) without further user action.
    # The mapping is copied into a plain dict before iterating over it.
    for object_id, obj in dict(checkpoint.object_by_proto_id).items():
      position = base.CheckpointPosition(checkpoint=checkpoint,
                                         proto_id=object_id)
      restore_ops = position.restore_ops()
      if restore_ops:
        if resource_variable_ops.is_resource_variable(obj):
          obj._initializer_op = restore_ops
        else:
          raise NotImplementedError(
              ("Missing functionality to restore state of object "
               "%r from the checkpoint." % obj))

  def get(self, node_id):
    """Returns the recreated object for `node_id` (0 is used as the root)."""
    return self._nodes[node_id]

  def _recreate(self, proto):
    """Creates a Python object from a SavedObject protocol buffer.

    Args:
      proto: A SavedObject proto; the field set in its "kind" oneof selects
        which `_recreate_*` method handles it.

    Returns:
      Whatever the selected `_recreate_*` method returns: an
      `(object, setter)` pair, where `setter` has the signature of `setattr`.

    Raises:
      ValueError: If the "kind" oneof is unset or not a supported kind.
    """
    # Each key is both the oneof field name and the name of the sub-message
    # that is handed to the matching method.
    factory = {
        "user_object": self._recreate_user_object,
        "asset": self._recreate_asset,
        "function": self._recreate_function,
        "bare_concrete_function": self._recreate_bare_concrete_function,
        "variable": self._recreate_variable,
        "constant": self._recreate_constant,
        "resource": self._recreate_resource,
    }
    kind = proto.WhichOneof("kind")
    if kind not in factory:
      raise ValueError("Unknown SavedObject type: %r" % kind)
    return factory[kind](getattr(proto, kind))

  def _recreate_user_object(self, proto):
    """Instantiates a SavedUserObject.

    Args:
      proto: The `user_object` sub-message of a SavedObject.

    Returns:
      The result of `revived_types.deserialize(proto)` when it is not None
      (presumably an `(object, setter)` pair, as `_load_all` unpacks it);
      otherwise a new `AutoTrackable` instance paired with `setattr`.
    """
    looked_up = revived_types.deserialize(proto)
    if looked_up is None:
      # Note: each user object has its own class. This allows to make each one
      # individually callable by adding a `__call__` method to the classes of
      # the objects instances that have a `__call__` property.

      class _UserObject(tracking.AutoTrackable):
        pass

      return _UserObject(), setattr
    return looked_up

  def _recreate_asset(self, proto):
    """Creates a TrackableAsset pointing into the export's assets directory."""
    filename = os.path.join(
        saved_model_utils.get_assets_dir(self._export_dir),
        self._asset_file_def[proto.asset_file_def_index].filename)
    return tracking.TrackableAsset(filename), setattr

  def _recreate_function(self, proto):
    """Recreates a function from the loaded concrete functions."""
    return function_deserialization.recreate_function(
        proto, self._concrete_functions), setattr

  def _recreate_bare_concrete_function(self, proto):
    """Sets up a bare concrete function from the loaded concrete functions."""
    return function_deserialization.setup_bare_concrete_function(
        proto, self._concrete_functions), setattr

  def _recreate_variable(self, proto):
    """Creates a zero-initialized variable with the saved dtype and shape."""
    # TODO(andresp): Can we use the checkpointed value as initializer?
    dummy_value = init_ops.Zeros(dtype=proto.dtype)(shape=proto.shape)
    return variables.Variable(dummy_value, trainable=proto.trainable), setattr

  def _recreate_constant(self, proto):
    """Rebuilds a constant from the "value" attr of the named graph node."""
    tensor_proto = self._operation_attributes[proto.operation]["value"].tensor
    imported_constant = constant_op.constant(
        tensor_util.MakeNdarray(tensor_proto))
    return imported_constant, setattr

  def _recreate_resource(self, proto):
    """Creates a placeholder resource; nothing is read from `proto`."""
    del proto
    return _RestoredResource(), setattr
274
275
276# TODO(b/124205571,b/124092991): Solve destruction of resources.
class _RestoredResource(tracking.TrackableResource):
  """Restored SavedResource.

  `_create_resource` and `_initialize` are placeholders that always raise.
  They are expected to be shadowed by restored functions attached to the
  instance under the same names while the object graph is being loaded.
  """

  def _create_resource(self):
    """Placeholder; raises unless shadowed by a restored function."""
    raise RuntimeError(
        "_create_resource was not replaced by a restored function.")

  def _initialize(self):
    """Placeholder; raises unless shadowed by a restored function."""
    raise RuntimeError(
        "_initialize was not replaced by a restored function.")

  def _list_functions_for_serialization(self):
    """Returns the two resource functions under their attribute names."""
    # Overwrite this method to avoid the implementation of
    # base class to re-wrap the polymorphic functions into
    # another layer of `tf.function`.
    return {
        "_create_resource": self._create_resource,
        "_initialize": self._initialize,
    }
294
295
def _call_attribute(instance, *args, **kwargs):
  """Forwards a call to the `__call__` attribute of `instance`.

  Args:
    instance: Object whose `__call__` attribute should be invoked.
    *args: Positional arguments passed through unchanged.
    **kwargs: Keyword arguments passed through unchanged.

  Returns:
    Whatever `instance.__call__` returns.
  """
  target = instance.__call__
  return target(*args, **kwargs)
298
299
@tf_export("saved_model.load", v1=["saved_model.load_v2"])
def load(export_dir, tags=None):
  """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have trackable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.

  Returns:
    A trackable object with a `signatures` attribute mapping from signature
    keys to functions. If the SavedModel was exported by `tf.saved_model.save`,
    it also points to trackable objects and functions which were attached
    to the exported object.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto = loader_impl.parse_saved_model(export_dir)
  meta_graphs = saved_model_proto.meta_graphs
  # The object-based loader is used only for a SavedModel with exactly one
  # MetaGraph that carries an object graph; anything else goes through the
  # v1 loading path.
  if len(meta_graphs) == 1 and meta_graphs[0].HasField("object_graph_def"):
    meta_graph_def = meta_graphs[0]
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may omit "
           "it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def
    with ops.init_scope():
      loader = _Loader(object_graph_proto,
                       saved_model_proto,
                       export_dir)
      root = loader.get(0)
  else:
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
  return root
366