# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import a trackable object from a SavedModel."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os

from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import load_options
from tensorflow.python.saved_model import load_v1_in_v2
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def _unused_handle():
  """Returns a placeholder as a handle that is not supposed to be accessed."""
  error_message = ("Trying to access a placeholder that is not supposed to be "
                   "executed. This means you are executing a graph generated "
                   "from the cross-replica context in an in-replica context.")

  assert_op = control_flow_ops.Assert(
      array_ops.placeholder_with_default(False, shape=()),
      [error_message])

  with ops.control_dependencies([assert_op]):
    return array_ops.placeholder(dtype=dtypes.resource)


class _WrapperFunction(function.ConcreteFunction):
  """Wraps a concrete function to handle different distributed contexts.

  A concrete function needs wrapping because the `_captured_inputs` it uses
  differ between the in-replica and cross-replica contexts. When `load()` is
  called from within a tf.distribute.Strategy scope, the captured inputs are
  distributed variables, and those variables must be handled differently
  depending on whether the function is called in-replica or not. When it is
  in-replica, we naturally use the corresponding component of the distributed
  variable; when it is not in-replica, calling the function only constructs a
  graph that is never actually executed. A typical use case is constructing a
  functional model. In that case, we return a placeholder with a control
  dependency to ensure that it is never accessed.
  """

  def __init__(self, concrete_function):
    # Shallow copy the concrete_function
    self.__dict__.update(vars(concrete_function))

  def _call_flat(self, args, captured_inputs, cancellation_manager=None):

    def get_handle(x):
      return x.handle if distribute_utils.is_distributed_variable(x) else x

    def get_unused_handle(x):
      return _unused_handle() if distribute_utils.is_distributed_variable(x) \
          else x

    if (ds_context.get_replica_context() is not None or
        values_util.is_saving_non_distributed()):
      # If we're in the replica context or are saving a non-distributed version
      # of the model, we resolve the captured variables to the corresponding
      # resource handles. In both situations we call var.handle, but it behaves
      # differently. In the replica context, var.handle resolves to the handle
      # of the replica-local variable if the variable is replicated. When
      # saving a non-distributed version of the model, var.handle resolves to
      # the handle of the primary variable, since we only save one copy of a
      # replicated variable.
      captured_inputs = list(map(get_handle, captured_inputs))
    else:  # cross-replica context
      captured_inputs = list(map(get_unused_handle, captured_inputs))
    return super(_WrapperFunction, self)._call_flat(args, captured_inputs,
                                                    cancellation_manager)


class Loader(object):
  """Helper class to load an object-based SavedModel."""

  def __init__(self, object_graph_proto, saved_model_proto, export_dir,
               ckpt_options, filters):
    meta_graph = saved_model_proto.meta_graphs[0]
    self._asset_file_def = meta_graph.asset_file_def
    self._operation_attributes = {
        node.name: node.attr for node in meta_graph.graph_def.node}
    self._proto = object_graph_proto
    self._export_dir = export_dir
    self._concrete_functions = (
        function_deserialization.load_function_def_library(
            meta_graph.graph_def.library))
    self._checkpoint_options = ckpt_options

    # Stores user-defined node_filters argument.
    self._node_filters = filters
    # Stores map of string paths to integers.
    self._node_path_to_id = self._convert_node_paths_to_ints()
    self._loaded_nodes = {}
    if isinstance(filters, dict):
      # If node_filters is a dict, then the values may contain already created
      # trackable objects. In this case, create a dictionary mapping node IDs to
      # the already created nodes. This dict will be updated in
      # `_retrieve_all_filtered_nodes` with tracked dependencies.
      for node_path, node in filters.items():
        if isinstance(node, tuple):
          self._loaded_nodes[self._node_path_to_id[node_path]] = node
        else:
          self._loaded_nodes[self._node_path_to_id[node_path]] = (node, setattr)

    # Get a set of all integer node IDs to load, or None if all nodes should
    # be loaded. This set includes the IDs of child nodes.
    self._filtered_nodes = self._retrieve_all_filtered_nodes()

    for name, concrete_function in self._concrete_functions.items():
      # Wrap all the concrete functions so that they can deal with both the
      # in-replica and cross-replica cases.
      self._concrete_functions[name] = _WrapperFunction(concrete_function)

    self._load_all()
    self._restore_checkpoint()

    for node in self._nodes:
      if isinstance(node, tracking.CapturableResource):
        init_op = node._initialize()  # pylint: disable=protected-access
        if not context.executing_eagerly():
          ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)

  def _convert_node_paths_to_ints(self):
    """Maps all string node paths in node_filters to the int node ids."""
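    # For example, "root.child_layer.v" resolves to an integer id by walking
    # the object graph from node 0 ("root") through the children named
    # "child_layer" and then "v".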
    if self._node_filters is None:
      return None
    path_to_int = {}
    for node_id in self._node_filters:
      int_node_id = None
      if isinstance(node_id, str):
        node_path = node_id.split(".")
        if node_path[0] != "root":
          raise ValueError(
              "When passing string identifiers to node_filters, the first name"
              " must be root.")
        int_node_id = 0
        for n, name in enumerate(node_path[1:]):
          int_node_id = self._find_node_child(
              int_node_id, name, ".".join(node_path[:n+2]))
        path_to_int[node_id] = int_node_id
      else:
        raise TypeError("Elements in node_filters must be strings.")
    return path_to_int

  def _retrieve_all_filtered_nodes(self):
    """Traverses the object graph to get the IDs of all nodes to load.

    As a side effect, if node_filters is a dictionary that contains already-
    created objects, then the dependencies tracked by those objects will be
    added to node_filters.

    Returns:
      Set of the IDs of all nodes to load, or None if all nodes should be
      loaded.
    """
    if self._node_filters is None:
      return None  # All nodes should be loaded.

    all_filtered_nodes = set()
    nodes_to_visit = list(self._node_filters)

    while nodes_to_visit:
      node_path = nodes_to_visit.pop(0)
      node_id = self._node_path_to_id[node_path]
      if node_id in all_filtered_nodes:
        continue
      all_filtered_nodes.add(node_id)

      node, setter = self._loaded_nodes.get(node_id, (None, None))
      if node is not None:
        if not isinstance(node, base.Trackable):
          raise TypeError(
              "Error when processing dictionary values passed to nodes_to_load."
              " Object at {} is expected to be a checkpointable TensorFlow "
              "object (e.g. tf.Variable, tf.Module or Keras layer)."
              .format(node_path))
        node._maybe_initialize_trackable()  # pylint: disable=protected-access

      for reference in self._proto.nodes[node_id].children:
        child_object, _ = self._loaded_nodes.get(
            reference.node_id, (None, None))

        # See if node already tracks the child reference, in which case add the
        # child to the loaded_nodes dict.
        if child_object is None and node is not None:
          child_object = node._lookup_dependency(reference.local_name)  # pylint: disable=protected-access
          if isinstance(child_object, data_structures.TrackableDataStructure):
            # Make setattr a noop to avoid overwriting already existing data
            # structures.
            setter = lambda *args: None

            self._loaded_nodes[reference.node_id] = (child_object, setter)

        child_path = "{}.{}".format(node_path, reference.local_name)
        self._node_path_to_id[child_path] = reference.node_id
        nodes_to_visit.append(child_path)

    if 0 in all_filtered_nodes:
      return None
    return all_filtered_nodes

  def _find_node_child(self, node_id, child_name, path):
    for reference in self._proto.nodes[node_id].children:
      if reference.local_name == child_name:
        return reference.node_id
    raise ValueError("unable to find node {}".format(path))

  def _load_all(self):
    """Loads all nodes and functions from the SavedModel and their edges."""
    self._load_nodes()
    self._load_edges()
    # TODO(b/124045874): There are limitations with functions whose captures
    # trigger other functions to be executed. For now it is only guaranteed to
    # work if the captures of a function only trigger functions without
    # captures.
    self._setup_functions_structures()
    self._setup_functions_captures()

    self._create_saveable_object_factories()

  def _create_saveable_object_factories(self):
    for node_id, proto in self._iter_all_nodes():
      node = self.get(node_id)
      node._self_saveable_object_factories = {}  # pylint: disable=protected-access
      for name, saveable_object_proto in proto.saveable_objects.items():
        node._self_saveable_object_factories[name] = (  # pylint: disable=protected-access
            saveable_object_util.restored_saved_object_factory(
                self.get(saveable_object_proto.save_function),
                self.get(saveable_object_proto.restore_function)))

  def _load_edges(self):
    """Adds edges from objects to other objects and functions."""
    for node_id, object_proto in self._iter_all_nodes():
      self._add_object_graph_edges(object_proto, node_id)

    # If root object isn't loaded, then create edges from the root for
    # checkpoint compatibility.
    if self._filtered_nodes is not None and 0 not in self._filtered_nodes:
      root = self.get(0)
      for node_path in self._node_filters:
        loaded_node = self._nodes[self._node_path_to_id[node_path]]
        path = node_path.split(".")
        current_node = root
        for name in path[1:-1]:
          if not hasattr(current_node, name):
            setattr(current_node, name, self._recreate_base_user_object()[0])
          current_node = getattr(current_node, name)
        if not hasattr(current_node, path[-1]):
          setattr(current_node, path[-1], loaded_node)

  def _add_object_graph_edges(self, proto, node_id):
    """Adds edges from an object to its children."""
    obj = self._nodes[node_id]
    setter = self._node_setters[node_id]

    for reference in proto.children:
      setter(obj, reference.local_name, self._nodes[reference.node_id])
      # Note: if an object has an attribute `__call__`, we add a class method
      # so that `obj()` syntax works. This is done per instance to allow
      # `callable` to be used to find out whether an object is callable.
      if reference.local_name == "__call__" and not callable(obj):
        setattr(type(obj), "__call__", _call_attribute)

  def _setup_functions_structures(self):
    """Sets up the structure for inputs and outputs of restored functions."""
    coder = nested_structure_coder.StructureCoder()
    for name, proto in sorted(self._proto.concrete_functions.items()):
      concrete_function = self._concrete_functions[name]
      # By setting the structured_outputs directly, we can rely on this
      # function_lib.ConcreteFunction object to perform the output repacking
      # logic. The only limitation of that logic is that it only works with
      # outputs that are convertible to Tensors, and the conversion always
      # happens. For example, tf.TensorShape([2, 3]) will be converted to a
      # Tensor representing [2, 3].
      original_outputs = coder.decode_proto(proto.output_signature)
      # The original_outputs here had Tensors converted to TensorSpecs, so
      # the restored function's structured_outputs field will not be
      # exactly the same. Fortunately the repacking logic cares only about
      # the structure; and the unpacking logic cares only about structure
      # and types.
      concrete_function._func_graph.structured_outputs = original_outputs  # pylint: disable=protected-access
      concrete_function._func_graph.structured_input_signature = (  # pylint: disable=protected-access
          coder.decode_proto(proto.canonicalized_input_signature))
      concrete_function._initialize_function_spec()  # pylint: disable=protected-access

  def _setup_functions_captures(self):
    """Sets up captures and variables in restored functions."""
    concrete_functions = sorted(self._proto.concrete_functions.items())
    for name, proto in concrete_functions:
      concrete_function = self._concrete_functions[name]
      bound_inputs = [
          self._get_tensor_from_node(node_id, name)
          for node_id in proto.bound_inputs]
      bound_variables = [
          self._nodes[node_id]
          for node_id in proto.bound_inputs
          if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
      ]
      # TODO(andresp): This injects the captured inputs only into the concrete
      # function; note that the FuncGraph itself is not modified.
      concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
      concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
      if bound_inputs:
        for bound_input, internal_capture in zip(
            bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
          if distribute_utils.is_distributed_variable(bound_input):
            concrete_function.graph.capture_distributed_variable(
                bound_input, internal_capture)
          else:
            concrete_function.graph.replace_capture(bound_input,
                                                    internal_capture)
            if internal_capture.dtype == dtypes.resource:
              if resource_variable_ops.is_resource_variable(bound_input):
                try:
                  handle = bound_input.handle
                except ValueError:
                  # For mirrored variables we'll copy handle data for components
                  # as they get captured.
                  pass
                else:
                  custom_gradient.copy_handle_data(handle, internal_capture)
              else:
                custom_gradient.copy_handle_data(bound_input, internal_capture)
            # Setting "captures" first means "capture" won't create a new
            # placeholder for this input.
            concrete_function.graph.capture(bound_input)

  def _get_tensor_from_node(self, node_id, fn_name):
    """Resolves a node id into a tensor to be captured for a function."""
    if self._node_filters is not None and self._nodes[node_id] is None:
      raise ValueError(
          "Error when processing nodes_to_load. Function \"{}\" requires "
          "inputs/variables that are not loaded when nodes_to_load={}"
          .format(fn_name, self._node_filters))

    with ops.init_scope():
      obj = self._nodes[node_id]
      if distribute_utils.is_distributed_variable(obj):
        return obj
      elif resource_variable_ops.is_resource_variable(obj):
        return obj.handle
      elif isinstance(obj, tracking.Asset):
        return obj.asset_path
      elif tensor_util.is_tf_type(obj):
        return obj
      elif isinstance(obj, tracking.CapturableResource):
        # Note: this executes restored functions in the CapturableResource.
        return obj.resource_handle
      raise ValueError("Can't convert node %s to tensor" % (type(obj)))

  def _initialize_loaded_nodes(self):
    nodes = {}
    node_setters = {}
    for node_id, (node, setter) in self._loaded_nodes.items():
      nodes[node_id] = node
      node_setters[node_id] = setter
    return nodes, node_setters

  def _iter_all_nodes(self):
    if self._filtered_nodes is None:
      return enumerate(self._proto.nodes)
    else:
      return [(node_id, self._proto.nodes[node_id])
              for node_id in self._filtered_nodes]

  def _load_nodes(self):
    """Load all saved objects."""
    # `nodes` maps from node ids to recreated objects
    # `node_setters` maps from node ids to setter functions
    # (same signature as setattr) for setting dependencies.
    nodes, node_setters = self._initialize_loaded_nodes()

    # Figure out which objects are slot variables. These objects are created
    # with Optimizer.add_slot rather than _recreate_variable.
    slot_variable_node_ids = set()

    for _, proto in self._iter_all_nodes():
      for slot_variable_proto in proto.slot_variables:
        slot_variable_node_ids.add(slot_variable_proto.slot_variable_node_id)

    # Re-create everything except slot variables.
    for node_id, proto in self._iter_all_nodes():
      if node_id in slot_variable_node_ids or nodes.get(node_id) is not None:
        # Defer recreating slot variables so we can use the public Optimizer
        # interface.
        continue
      node, setter = self._recreate(proto, node_id)
      nodes[node_id] = node
      node_setters[node_id] = setter

    # Now that we have created the variables being optimized, we have enough
    # information to re-create slot variables for them.
    for node_id, proto in self._iter_all_nodes():
      optimizer_object = nodes[node_id]
      for slot_variable_proto in proto.slot_variables:
        optimized_variable = nodes[
            slot_variable_proto.original_variable_node_id]
        slot_variable = optimizer_object.add_slot(
            var=optimized_variable,
            slot_name=slot_variable_proto.slot_name)
        nodes[slot_variable_proto.slot_variable_node_id] = slot_variable
        node_setters[slot_variable_proto.slot_variable_node_id] = setattr

    # If root object is not loaded, add a dummy root object for checkpoint
    # compatibility.
    if 0 not in nodes:
      nodes[0] = self._recreate_base_user_object()[0]

    self._nodes = [nodes.get(node_id)
                   for node_id in range(len(self._proto.nodes))]
    self._node_setters = node_setters

  @property
  def _expect_partial_checkpoint(self):
    """Whether to expect that some objects aren't loaded.

    This should be set to True in subclasses of the Loader class which generate
    a trackable object with an object graph that is different from the graph
    in the SavedModel. Setting this property to True suppresses the warnings
    that are printed out when there are unused parts of the checkpoint or
    object.

    Returns:
      boolean
    """
    return False

  def _restore_checkpoint(self):
    """Load state from checkpoint into the deserialized objects."""
    variables_path = saved_model_utils.get_variables_path(self._export_dir)
    # TODO(andresp): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
    with ops.device("CPU"):
      saver._file_prefix_placeholder = constant_op.constant(variables_path)
    if self._expect_partial_checkpoint:
      load_status = saver.restore(variables_path,
                                  self._checkpoint_options).expect_partial()
    else:
      load_status = saver.restore(variables_path, self._checkpoint_options)
    load_status.assert_existing_objects_matched()
    checkpoint = load_status._checkpoint

    if not context.executing_eagerly():
      # When running in eager mode, the `restore` call above has already run
      # and restored the state of trackables, and calling
      # `position.restore_ops()` would re-run the restore. In graph mode, that
      # call returns a cached list of ops that must run to restore the object
      # at that position. We have to wire them into the initializers of the
      # objects so that they get initialized properly when following common
      # practices (e.g. the ones used by ManagedSession) without further user
      # action.
      for object_id, obj in dict(checkpoint.object_by_proto_id).items():
        position = base.CheckpointPosition(checkpoint=checkpoint,
                                           proto_id=object_id)
        restore_ops = position.restore_ops()
        if restore_ops:
          if resource_variable_ops.is_resource_variable(obj):
            if len(restore_ops) == 1:
              obj._initializer_op = restore_ops[0]
            else:
              obj._initializer_op = control_flow_ops.group(*restore_ops)
          elif isinstance(obj, lookup_ops.LookupInterface):
            # We don't need to check for eager execution here, since this code
            # path should only be taken if we are restoring in graph mode.
            ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
          else:
            raise NotImplementedError(
                ("Missing functionality to restore state of object "
                 "%r from the checkpoint." % obj))

  def adjust_debug_info_func_names(self, debug_info):
    """Rewrites function names in debug info to concrete function names."""
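    # Trace keys have the form "<node name>@<function name>"; only the
    # function part is rewritten, using the loaded concrete function's
    # signature name.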
    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    output_debug_info.files[:] = debug_info.files
    for key in debug_info.traces:
      node, func = key.split("@")
      new_func = ""
      if func in self._concrete_functions:
        new_func = self._concrete_functions[func].function_def.signature.name
      output_debug_info.traces[node + "@" + new_func].CopyFrom(
          debug_info.traces[key])
    return output_debug_info

  def get(self, node_id):
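    """Returns the recreated node for an int id or `root.*` path string."""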
    if isinstance(node_id, str):
      node_id = self._node_path_to_id[node_id]
    return self._nodes[node_id]

  def _recreate(self, proto, node_id):
    """Creates a Python object from a SavedObject protocol buffer."""
    factory = {
        "user_object": (
            lambda: self._recreate_user_object(proto.user_object, node_id)),
        "asset": lambda: self._recreate_asset(proto.asset),
        "function": lambda: self._recreate_function(proto.function),
        "bare_concrete_function": functools.partial(
            self._recreate_bare_concrete_function,
            proto.bare_concrete_function),
        "variable": lambda: self._recreate_variable(proto.variable),
        "constant": lambda: self._recreate_constant(proto.constant),
        "resource": lambda: self._recreate_resource(proto.resource),
    }
    kind = proto.WhichOneof("kind")
    if kind not in factory:
      raise ValueError("Unknown SavedObject type: %r" % kind)
    return factory[kind]()

  def _recreate_user_object(self, proto, node_id):
    """Instantiates a SavedUserObject."""
    looked_up = revived_types.deserialize(proto)
    if looked_up is None:
      return self._recreate_base_user_object(proto, node_id)
    return looked_up

  def _recreate_base_user_object(self, proto=None, node_id=None):
    del proto, node_id
    # Note: each user object has its own class. This allows making each one
    # individually callable by adding a `__call__` method to the class of
    # each object instance that has a `__call__` property.

    class _UserObject(tracking.AutoTrackable):
      pass

    return _UserObject(), setattr

  def _recreate_asset(self, proto):
    filename = os.path.join(
        saved_model_utils.get_assets_dir(self._export_dir),
        self._asset_file_def[proto.asset_file_def_index].filename)
    asset = tracking.Asset(filename)
    if not context.executing_eagerly():
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset.asset_path)
    return asset, setattr

  def _recreate_function(self, proto):
    return function_deserialization.recreate_function(
        proto, self._concrete_functions), setattr

  def _recreate_bare_concrete_function(self, proto):
    return function_deserialization.setup_bare_concrete_function(
        proto, self._concrete_functions), setattr

  def _recreate_variable(self, proto):
    name = proto.name if proto.name else None
    if name is not None:
      dbg_name = name
    else:
      dbg_name = "<variable loaded from saved model>"
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            proto.synchronization, proto.aggregation, proto.trainable,
            name=dbg_name))

    def uninitialized_variable_creator(next_creator, **kwargs):
      """A variable creator that creates uninitialized variables."""
      del next_creator
      return resource_variable_ops.UninitializedVariable(**kwargs)

    # Create a variable_creator_scope that creates uninitialized variables with
    # a lower priority such that a potential distributed variable_creator_scope
    # can take precedence.
    with ops.get_default_graph()._variable_creator_scope(  # pylint: disable=protected-access
        uninitialized_variable_creator,
        priority=50):
      return variables.Variable(
          shape=proto.shape,
          dtype=proto.dtype,
          name=name,
          trainable=trainable,
          synchronization=synchronization,
          aggregation=aggregation), setattr

  def _recreate_constant(self, proto):
    tensor_proto = self._operation_attributes[proto.operation]["value"].tensor
    ndarray = tensor_util.MakeNdarray(tensor_proto)
    if dtypes.as_dtype(tensor_proto.dtype) == dtypes.string:
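      # String tensors are not supported on GPUs, so pin string constants to
      # the CPU.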
      with ops.device("CPU"):
        imported_constant = constant_op.constant(ndarray)
    else:
      imported_constant = constant_op.constant(ndarray)
    return imported_constant, setattr

  def _recreate_resource(self, proto):
    return _RestoredResource(device=proto.device), setattr


# TODO(b/124205571,b/124092991): Solve destruction of resources.
class _RestoredResource(tracking.TrackableResource):
  """Restored SavedResource."""

  def __init__(self, device=""):
    super(_RestoredResource, self).__init__(device=device)
    self._destroy_resource_fn = None

  def _create_resource(self):
    raise RuntimeError()

  def _initialize(self):
    raise RuntimeError()

  @property
  def _destroy_resource(self):
    return self._destroy_resource_fn

  @_destroy_resource.setter
  def _destroy_resource(self, destroy_resource_fn):
    self._resource_deleter = tracking.CapturableResourceDeleter(
        destroy_resource_fn)
    self._destroy_resource_fn = destroy_resource_fn

  def _list_functions_for_serialization(self, unused_serialization_cache):
    # Override this method to prevent the base-class implementation from
    # re-wrapping the polymorphic functions in another layer of `tf.function`.
    functions = {
        "_create_resource": self._create_resource,
        "_initialize": self._initialize,
    }
    if self._destroy_resource:
      functions.update(_destroy_resource=self._destroy_resource)
    return functions


def _call_attribute(instance, *args, **kwargs):
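  """Forwards `obj()` calls to the instance's restored `__call__` attribute."""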
  return instance.__call__(*args, **kwargs)


@tf_export("__internal__.saved_model.load_partial", v1=[])
def load_partial(export_dir, filters, tags=None, options=None):
  """Partially load a SavedModel (saved from V2).

  Similar to `tf.saved_model.load`, but with an additional argument that
  lets you specify which nodes to load.
  `tf.saved_model.load_partial(export_dir, ["root"])` and
  `tf.saved_model.load(export_dir)` are equivalent.

  Note: This only works for SavedModels saved with TensorFlow V2 from
  `tf.saved_model.save` or Keras. This will not load SavedModels saved from
  the Estimator API.

  In TensorFlow V2, a SavedModel stores the **object graph** of the saved
  object. The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`,
  Keras layers, etc.) and edges that are the names of the attributes connecting
  the objects.

  *Example 1*

  ```
  model = tf.Module()
  model.child_layer = tf.Module()
  model.child_layer.v = tf.Variable(5.)
  tf.saved_model.save(model, '/tmp/model')
  loaded = tf.__internal__.saved_model.load_partial(
      '/tmp/model',
      ['root.child_layer', 'root.child_layer.v'])
  loaded['root.child_layer'].v.numpy()  # 5.
  loaded['root.child_layer'].v is loaded['root.child_layer.v']  # True
  ```

  *Example 2*

  ```
  model = tf.Module()
  model.child_layer = tf.Module()
  model.child_layer.v = tf.Variable(5.)
  tf.saved_model.save(model, '/tmp/model')
  # Create a variable
  new_variable = tf.Variable(0.)
  loaded = tf.__internal__.saved_model.load_partial(
      '/tmp/model',
      {'root.child_layer': None, 'root.child_layer.v': new_variable})
  loaded['root.child_layer'].v.numpy()  # 5.
  new_variable.numpy()  # 5.
  ```

  **Loading under different distribution strategies**
  You can load different parts of the model under different distribution
  strategies. Note that this is very experimental, so use it with care.

  ```
  model = tf.Module()
  model.layer_1 = tf.Module()
  model.layer_1.v = tf.Variable(5.)
  model.layer_2 = tf.Module()
  model.layer_2.v = tf.Variable(7.)
  tf.saved_model.save(model, '/tmp/model')

  # Load with no strategy.
  loaded = tf.__internal__.saved_model.load_partial(
      '/tmp/model',
      ['root.layer_1'])
  loaded['root.layer_1'].v
  # <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>

  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    loaded2 = tf.__internal__.saved_model.load_partial(
        '/tmp/model',
        ['root.layer_2'])
  loaded2['root.layer_2'].v
  # MirroredVariable:{
  #     0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
  # }
  ```

  Args:
    export_dir: The SavedModel directory to load from.
    filters: A list or dictionary where each element or key is a string
      path to a node that should be loaded. Node paths consist of all the
      child attribute names needed to reach that node, in the form
      `root.{attribute_name}`. The loader will load all of the specified nodes
      and their recursive descendants. When this option is defined, the loader
      will return a dictionary mapping the node paths to the loaded objects.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: A `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A dictionary mapping node paths from the filter to loaded objects.
  """
  return load_internal(export_dir, tags, options, filters=filters)


@tf_export("saved_model.load", v1=["saved_model.load_v2"])
def load(export_dir, tags=None, options=None):
  """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have trackable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

  _Loading Keras models_

  Keras models are trackable, so they can be saved to SavedModel. However, the
  object returned by `tf.saved_model.load` is not a Keras object (i.e. it
  doesn't have `.fit`, `.predict`, etc. methods). A few attributes and
  functions are still available: `.variables`, `.trainable_variables` and
  `.__call__`.

  ```python
  model = tf.keras.Model(...)
  tf.saved_model.save(model, path)
  imported = tf.saved_model.load(path)
  outputs = imported(inputs)
  ```

  Use `tf.keras.models.load_model` to restore the Keras model.

  _Importing SavedModels from TensorFlow 1.x_

  SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
  graph instead of `tf.function` objects. These SavedModels will be loaded with
  the following attributes:

  * `.signatures`: A dictionary mapping signature names to functions.
  * `.prune(feeds, fetches)`: A method which allows you to extract
    functions for new subgraphs. This is equivalent to importing the SavedModel
    and naming feeds and fetches in a Session from TensorFlow 1.x.

    ```python
    imported = tf.saved_model.load(path_to_v1_saved_model)
    pruned = imported.prune("x:0", "out:0")
    pruned(tf.ones([]))
    ```

    See `tf.compat.v1.wrap_function` for details.
  * `.variables`: A list of imported variables.
  * `.graph`: The whole imported graph.
  * `.restore(save_path)`: A function that restores variables from a checkpoint
    saved from `tf.compat.v1.Saver`.

  _Consuming SavedModels asynchronously_

  When consuming SavedModels asynchronously (the producer is a separate
  process), the SavedModel directory will appear before all files have been
  written, and `tf.saved_model.load` will fail if pointed at an incomplete
  SavedModel. Rather than checking for the directory, check for
  "saved_model_dir/saved_model.pb". This file is written atomically as the last
  `tf.saved_model.save` file operation.
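
  A minimal sketch of that check (illustrative only; it assumes the SavedModel
  under `export_dir` is produced by a separate process):

  ```python
  import os
  import time

  while not tf.io.gfile.exists(os.path.join(export_dir, "saved_model.pb")):
    time.sleep(1)
  imported = tf.saved_model.load(export_dir)
  ```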

  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: A `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A trackable object with a `signatures` attribute mapping from signature
    keys to functions. If the SavedModel was exported by `tf.saved_model.save`,
    it also points to trackable objects, functions, and debug info with which
    it was saved.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
  return load_internal(export_dir, tags, options)["root"]


def load_internal(export_dir, tags=None, options=None, loader_cls=Loader,
                  filters=None):
  """Loader implementation."""
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    meta_graph_def = saved_model_proto.meta_graphs[0]
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may omit "
           "it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
                            ckpt_options, filters)
      except errors.NotFoundError as err:
        raise FileNotFoundError(
            str(err) + "\nIf trying to load on a different device from the "
            "computational device, consider setting the "
            "`experimental_io_device` option in tf.saved_model.LoadOptions "
            "to the io_device, such as '/job:localhost'."
        )
      root = loader.get(0)
      if isinstance(loader, Loader):
        root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
  else:
    if filters:
      raise ValueError("SavedModels saved from TensorFlow V1 or Estimator "
                       "(any version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}
