1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Exports a SavedModel from a Trackable Python object."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import functools
23import gc
24import os
25
26from absl import logging
27from tensorflow.core.framework import versions_pb2
28from tensorflow.core.protobuf import meta_graph_pb2
29from tensorflow.core.protobuf import saved_model_pb2
30from tensorflow.core.protobuf import saved_object_graph_pb2
31from tensorflow.python.eager import context
32from tensorflow.python.eager import def_function
33from tensorflow.python.eager import function as defun
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import error_interpolation
37from tensorflow.python.framework import errors
38from tensorflow.python.framework import meta_graph
39from tensorflow.python.framework import ops
40from tensorflow.python.framework import tensor_util
41from tensorflow.python.framework import versions
42from tensorflow.python.lib.io import file_io
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import control_flow_ops
45from tensorflow.python.ops import resource_variable_ops
46from tensorflow.python.platform import tf_logging
47from tensorflow.python.saved_model import builder_impl
48from tensorflow.python.saved_model import constants
49from tensorflow.python.saved_model import function_serialization
50from tensorflow.python.saved_model import nested_structure_coder
51from tensorflow.python.saved_model import revived_types
52from tensorflow.python.saved_model import save_context
53from tensorflow.python.saved_model import save_options
54from tensorflow.python.saved_model import signature_constants
55from tensorflow.python.saved_model import signature_def_utils
56from tensorflow.python.saved_model import signature_serialization
57from tensorflow.python.saved_model import tag_constants
58from tensorflow.python.saved_model import utils_impl
59from tensorflow.python.training.saving import checkpoint_options
60from tensorflow.python.training.saving import functional_saver
61from tensorflow.python.training.saving import saveable_object_util
62from tensorflow.python.training.tracking import base
63from tensorflow.python.training.tracking import graph_view
64from tensorflow.python.training.tracking import tracking
65from tensorflow.python.training.tracking import util
66from tensorflow.python.util import compat
67from tensorflow.python.util import object_identity
68from tensorflow.python.util.tf_export import tf_export
69
# Dtypes which cannot be copied into the exported graph as constants; captures
# with these dtypes are expected to be resolved through the resource map
# instead (see `_SaveableView.map_resources`).
_UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))

# A container for an EagerTensor constant which has been copied to the exported
# Graph.
_CapturedConstant = collections.namedtuple("_CapturedConstant",
                                           ["eager_tensor", "graph_tensor"])

# Number of untraced functions to display to user in warning message.
_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5
79
80
class _AugmentedGraphView(graph_view.ObjectGraphView):
  """An extendable graph which also tracks functions attached to objects.

  Extensions through `add_object` appear in the object graph and any checkpoints
  generated from it, even if they are not dependencies of the node they were
  attached to in the saving program. For example a `.signatures` attribute is
  added to exported SavedModel root objects without modifying the root object
  itself.

  Also tracks functions attached to objects in the graph, through the caching
  `list_functions` method. Enumerating functions only through this method
  ensures that we get a consistent view of functions, even if object attributes
  create new functions every time they are accessed.
  """

  def __init__(self, root):
    """Creates a view rooted at the trackable object `root`."""
    # Only use a SaveableObject cache while graph building outside a function;
    # in other contexts no cache is passed to the parent view.
    if (not context.executing_eagerly() and not ops.inside_function()):
      saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
    else:
      saveables_cache = None
    super(_AugmentedGraphView, self).__init__(root, saveables_cache)
    # Object -> (name -> dep)
    self._extra_dependencies = object_identity.ObjectIdentityDictionary()
    # Object -> (attribute name -> function), memoized by `list_functions`.
    self._functions = object_identity.ObjectIdentityDictionary()
    # Cache shared between objects in the same object graph. This is passed to
    # each trackable object's `_list_extra_dependencies_for_serialization` and
    # `_list_functions_for_serialization` function.
    self._serialization_cache = object_identity.ObjectIdentityDictionary()

  def add_object(self, parent_node, name_in_parent, subgraph_root):
    """Attach an object to `parent_node`, overriding any existing dependency."""
    self._extra_dependencies.setdefault(parent_node,
                                        {})[name_in_parent] = subgraph_root

  def list_dependencies(self, obj):
    """Overrides a parent method to include `add_object` objects.

    Args:
      obj: A trackable object whose dependencies should be enumerated.

    Yields:
      `base.TrackableReference` tuples. Extra dependencies shadow same-named
      dependencies defined by the user, except that such a collision is only
      permitted for the reserved `.signatures` attribute.

    Raises:
      ValueError: If a reserved attribute name other than `.signatures`
        collides with a user-defined dependency name.
    """
    extra_dependencies = self.list_extra_dependencies(obj)
    extra_dependencies.update(self._extra_dependencies.get(obj, {}))

    used_names = set()
    for name, dep in super(_AugmentedGraphView, self).list_dependencies(obj):
      used_names.add(name)
      if name in extra_dependencies:
        # Extra dependencies (except for `.signatures`, which is always added
        # when saving) should not have naming conflicts with dependencies
        # defined by the user.
        if name != signature_serialization.SIGNATURE_ATTRIBUTE_NAME:
          raise ValueError(
              "Error when exporting object {} with identifier={}. The object"
              " has an attribute named {}, which is reserved. List of all "
              "reserved attributes: {}".format(
                  obj,
                  obj._object_identifier,  # pylint: disable=protected-access
                  name,
                  extra_dependencies.keys()))
        yield base.TrackableReference(name, extra_dependencies[name])
      else:
        yield base.TrackableReference(name, dep)
    # Emit any extra dependencies whose names did not appear among the
    # ordinary dependencies.
    for name, dep in extra_dependencies.items():
      if name in used_names:
        continue
      yield base.TrackableReference(name, dep)

  def list_extra_dependencies(self, obj):
    """Returns extra serialization-time dependencies declared by `obj`."""
    return obj._list_extra_dependencies_for_serialization(  # pylint: disable=protected-access
        self._serialization_cache)

  def list_functions(self, obj):
    """Returns (and caches) the serializable functions attached to `obj`."""
    obj_functions = self._functions.get(obj, None)
    if obj_functions is None:
      obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
          self._serialization_cache)
      self._functions[obj] = obj_functions
    return obj_functions
155
156
class _SaveableView(object):
  """Provides a frozen view over a trackable root.

  This class helps to create a single stable view over an object to save. The
  saving code should access properties and functions via this class and not via
  the original object as there are cases where an object construct their
  trackable attributes and functions dynamically per call and will yield
  different objects if invoked more than once.

  Changes to the graph, for example adding objects, must happen in
  `checkpoint_view` (an `_AugmentedGraphView`) before the `_SaveableView` is
  constructed. Changes after the `_SaveableView` has been constructed will be
  ignored.
  """

  def __init__(self, checkpoint_view, options, wrapped_functions=None):
    """Initializes a SaveableView.

    Args:
      checkpoint_view: A GraphView object.
      options: A SaveOptions instance.
      wrapped_functions: Dictionary that maps concrete functions to functions
        that do not capture cached variable values.
    """

    self.checkpoint_view = checkpoint_view
    self._options = options
    # Maps functions -> wrapped functions that capture variables
    self._wrapped_functions = wrapped_functions or {}
    # Run through the nodes in the object graph first for side effects of
    # creating variables.
    self._trace_all_concrete_functions()

    # Freeze the object graph: objects, their node paths/ids, and slot
    # variables are captured once here and not re-enumerated afterwards.
    (self._trackable_objects, self.node_paths, self._node_ids,
     self._slot_variables) = (
         self.checkpoint_view.objects_ids_and_slot_variables_and_paths())
    self._initialize_nodes_and_concrete_functions()

    # Maps names of concrete functions in the object to names of wrapped
    # functions. When writing the SavedFunction protos, the names of the
    # wrapped functions should be used in place of the original functions.
    self.function_name_map = {
        compat.as_text(original.name): compat.as_text(wrapped.name)
        for original, wrapped in self._wrapped_functions.items()}
    # Maps captured tensors (and assets' path tensors) to the id of the node
    # that owns them; populated by `map_resources`.
    self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()

  def _initialize_nodes_and_concrete_functions(self):
    """Creates graph with nodes for trackable objects and functions.

    Adds functions for each trackable object to `self.nodes` and associated
    concrete functions to `self.concrete_functions` for serialization. Also adds
    the object's save and restore functions for loading values from checkpoint.
    """
    self.nodes = list(self._trackable_objects)
    self.concrete_functions = []
    self._seen_function_names = set()
    self._untraced_functions = []
    # Maps node -> local name -> (save function, restore function)
    self._saveable_objects_map = object_identity.ObjectIdentityDictionary()

    for obj in self._trackable_objects:
      for function in self.checkpoint_view.list_functions(obj).values():
        self._add_function_to_graph(function)
      # Resource (and TPU/Mirrored) variables are automatically revived with
      # their saveables defined, so there is no need to trace the save
      # and restore functions.
      if resource_variable_ops.is_resource_variable(obj):
        continue
      # Trace object save and restore functions to populate `saveables_map`
      # field in the SavedModel proto.
      saveable_map = saveable_object_util.trace_save_restore_functions(obj)
      if saveable_map:
        for save_fn, restore_fn in saveable_map.values():
          self._add_function_to_graph(save_fn)
          self._add_function_to_graph(restore_fn)
        self._saveable_objects_map[obj] = saveable_map

    # Warn (once, with a truncated list) about tf.functions that have no
    # traced concrete functions and therefore cannot be serialized as
    # callables.
    if self._untraced_functions:
      logging.warning(
          "Found untraced functions such as %s while saving (showing %d of %d)."
          " These functions will not be directly callable after loading.",
          ", ".join(self._untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]),
          min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self._untraced_functions)),
          len(self._untraced_functions))

  def _add_function_to_graph(self, function):
    """Adds function to serialize to graph."""
    # Updates self.nodes, self._node_ids, self.concrete_functions,
    # and self._untraced_functions.
    if function not in self._node_ids:
      self._node_ids[function] = len(self.nodes)
      # Add the function to nodes as well.
      self.nodes.append(function)
    if isinstance(function, def_function.Function):
      concrete_functions = (
          function._list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access
    else:
      concrete_functions = [function]
    if not concrete_functions:
      # A tf.function with no traced concrete functions; remember its name for
      # the warning emitted in `_initialize_nodes_and_concrete_functions`.
      self._untraced_functions.append(function._name)  # pylint: disable=protected-access
    for concrete_function in concrete_functions:
      if concrete_function.name not in self._seen_function_names:
        self.concrete_functions.append(concrete_function)
        self._seen_function_names.add(concrete_function.name)

  def _trace_all_concrete_functions(self):
    """Trace concrete functions to force side-effects.

    Lists the concrete functions in order to:
      - populate the cache for functions that have an input_signature
        and have not been called
      - force side effects of creation of concrete functions, e.g. create
        variables on first run.
    """
    for obj in self.checkpoint_view.list_objects():
      for function in self.checkpoint_view.list_functions(obj).values():
        if isinstance(function, def_function.Function):
          function._list_all_concrete_functions_for_serialization()  # pylint: disable=protected-access

  @property
  def root(self):
    # The root object is always node 0 in the frozen object graph.
    return self.nodes[0]

  def fill_object_graph_proto(self, proto):
    """Populate the nodes, children and slot_variables of a SavedObjectGraph."""
    for node_id, node in enumerate(self.nodes):
      assert self._node_ids[node] == node_id
      object_proto = proto.nodes.add()
      object_proto.slot_variables.extend(self._slot_variables.get(node, ()))
      # Function and captured-constant nodes carry no children or saveable
      # objects, so only slot variables are recorded for them.
      if isinstance(
          node,
          (def_function.Function, defun.ConcreteFunction, _CapturedConstant)):
        continue
      for child in self.checkpoint_view.list_dependencies(node):
        child_proto = object_proto.children.add()
        child_proto.node_id = self._node_ids[child.ref]
        child_proto.local_name = child.name
      # Functions attached to the object are recorded as children as well,
      # referencing the function nodes added in `_add_function_to_graph`.
      for local_name, ref_function in (
          self.checkpoint_view.list_functions(node).items()):
        child_proto = object_proto.children.add()
        child_proto.node_id = self._node_ids[ref_function]
        child_proto.local_name = local_name

      if node not in self._saveable_objects_map:
        continue

      # Record traced save/restore functions by node id for objects with
      # custom saveables.
      for local_name, (save_fn, restore_fn) in (
          self._saveable_objects_map[node].items()):
        saveable_object_proto = object_proto.saveable_objects[local_name]
        saveable_object_proto.save_function = self._node_ids[save_fn]
        saveable_object_proto.restore_function = self._node_ids[restore_fn]

  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with resources.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})

    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.Asset):
        # Assets get an AssetFileDef plus a variable holding the path; record
        # the node id so captures of the path tensor resolve to this node.
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id
      elif isinstance(obj, base.Trackable):
        node_object_map, node_resource_map = obj._map_resources(self._options)  # pylint: disable=protected-access
        for capturable in node_resource_map.keys():
          self.captured_tensor_node_ids[capturable] = node_id
        object_map.update(node_object_map)
        resource_map.update(node_resource_map)

    # Note: some concrete functions can have been realized when tracing other
    # functions, and might closure-capture tensors from their parent functions.
    # This is normal, but it means those concrete functions can't be serialized
    # as their own independent endpoints, so we filter them out here.
    bad_functions = []
    for concrete_function in self.concrete_functions:
      if not concrete_function.graph.saveable:
        raise ValueError(
            ("Unable to save function {name} for the following reason(s):\n" +
             "\n".join(concrete_function.graph.saving_errors)).format(
                 name=concrete_function.name))
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tf_type(capture) and
            capture.dtype not in _UNCOPIABLE_DTYPES and
            capture not in self.captured_tensor_node_ids):
          if hasattr(capture, "_cached_variable"):
            # Presumably a cached variable value (see the `wrapped_functions`
            # docstring in `__init__`); wrap the function so it does not
            # capture the cached value, and record the name mapping.
            if concrete_function not in self._wrapped_functions:
              wrapped = self._wrapped_functions[concrete_function] = (
                  function_serialization.wrap_cached_variables(
                      concrete_function))
              self.function_name_map[compat.as_text(concrete_function.name)] = (
                  compat.as_text(wrapped.name))
            continue
          # Otherwise try to copy the captured eager tensor into the exported
          # graph as a constant; functions whose captures cannot be resolved
          # this way are dropped from the serialized endpoints below.
          capture_constant_value = tensor_util.constant_value(capture)
          if capture_constant_value is None:
            bad_functions.append(concrete_function)
            continue
          copied_tensor = constant_op.constant(capture_constant_value)
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self._node_ids[capture] = node_id
          self._node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    self.concrete_functions = [
        self._wrapped_functions.get(x, x) for x in self.concrete_functions
        if x not in bad_functions
    ]
    return object_map, resource_map, asset_info
391
392
def _tensor_dict_to_tensorinfo(tensor_dict):
  """Converts each value of `tensor_dict` to a TensorInfo, keeping keys."""
  tensorinfo_dict = {}
  for name, tensor in tensor_dict.items():
    tensorinfo_dict[name] = utils_impl.build_tensor_info_internal(tensor)
  return tensorinfo_dict
398
399
400def _map_captures_to_created_tensors(original_captures, resource_map):
401  """Maps eager tensors captured by a function to Graph resources for export.
402
403  Args:
404    original_captures: A dictionary mapping from tensors captured by the
405      function to interior placeholders for those tensors (inside the function
406      body).
407    resource_map: A dictionary mapping from resource tensors owned by the eager
408      context to resource tensors in the exported graph.
409
410  Returns:
411    A list of stand-in tensors which belong to the exported graph, corresponding
412    to the function's captures.
413
414  Raises:
415    AssertionError: If the function references a resource which is not part of
416      `resource_map`.
417  """
418  export_captures = []
419  for exterior, interior in original_captures:
420    mapped_resource = resource_map.get(exterior, None)
421    if mapped_resource is None:
422      trackable_referrers = []
423      # Try to figure out where the resource came from by iterating over objects
424      # which reference it. This is slow and doesn't help us figure out how to
425      # match it to other objects when loading the SavedModel as a checkpoint,
426      # so we can't continue saving. But we can at least tell the user what
427      # needs attaching.
428      for primary_referrer in gc.get_referrers(exterior):
429        if isinstance(primary_referrer, base.Trackable):
430          trackable_referrers.append(primary_referrer)
431        for secondary_referrer in gc.get_referrers(primary_referrer):
432          if isinstance(secondary_referrer, base.Trackable):
433            trackable_referrers.append(secondary_referrer)
434      raise AssertionError(
435          ("Tried to export a function which references untracked resource {}. "
436           "TensorFlow objects (e.g. tf.Variable) captured by functions must "
437           "be tracked by assigning them to an attribute of a tracked object "
438           "or assigned to an attribute of the main object directly.\n\n"
439           "Trackable Python objects referring to this tensor "
440           "(from gc.get_referrers, limited to two hops):\n{}"
441          ).format(interior,
442                   "\n".join([repr(obj) for obj in trackable_referrers])))
443    export_captures.append(mapped_resource)
444  return export_captures
445
446
def _map_function_arguments_to_created_inputs(function_arguments, signature_key,
                                              function_name):
  """Creates exterior placeholders in the exported graph for function arguments.

  Functions have two types of inputs: tensors captured from the outside (eager)
  context, and arguments to the function which we expect to receive from the
  user at each call. `_map_captures_to_created_tensors` replaces
  captured tensors with stand-ins (typically these are resource dtype tensors
  associated with variables). `_map_function_inputs_to_created_inputs` runs over
  every argument, creating a new placeholder for each which will belong to the
  exported graph rather than the function body.

  Args:
    function_arguments: A list of argument placeholders in the function body.
    signature_key: The name of the signature being exported, for error messages.
    function_name: The name of the function, for error messages.

  Returns:
    A tuple of (mapped_inputs, exterior_placeholders)
      mapped_inputs: A list with entries corresponding to `function_arguments`
        containing all of the inputs of the function gathered from the exported
        graph (both captured resources and arguments).
      exterior_argument_placeholders: A dictionary mapping from argument names
        to placeholders in the exported graph, containing the explicit arguments
        to the function which a user is expected to provide.

  Raises:
    ValueError: If argument names are not unique.
  """
  # These placeholders live outside the function body, directly in a MetaGraph
  # of the SavedModel, so Session-based APIs can feed and fetch them by name.
  # The function body contains nearly identical placeholders used when
  # actually running the function.
  exterior_argument_placeholders = {}
  mapped_inputs = []
  for body_placeholder in function_arguments:
    # Captures were handled separately and exhaustively, so anything reaching
    # this loop is a user-provided argument.
    arg_name = compat.as_str_any(
        body_placeholder.op.get_attr("_user_specified_name"))
    if arg_name != body_placeholder.op.name:
      # TensorFlow uniquified the internal placeholder name, which means one
      # user-specified argument name refers to multiple Tensors. The resulting
      # signatures would be confusing to call, so ask for explicit names
      # instead. (This should be unreachable, since concrete functions may not
      # be generated with non-unique argument names.)
      raise ValueError(
          ("Got non-flat/non-unique argument names for SavedModel "
           "signature '{}': more than one argument to '{}' was named '{}'. "
           "Signatures have one Tensor per named input, so to have "
           "predictable names Python functions used to generate these "
           "signatures should avoid *args and Tensors in nested "
           "structures unless unique names are specified for each. Use "
           "tf.TensorSpec(..., name=...) to provide a name for a Tensor "
           "input.").format(signature_key, compat.as_str_any(function_name),
                            arg_name))
    exterior_placeholder = array_ops.placeholder(
        shape=body_placeholder.shape,
        dtype=body_placeholder.dtype,
        name="{}_{}".format(signature_key, arg_name))
    exterior_argument_placeholders[arg_name] = exterior_placeholder
    mapped_inputs.append(exterior_placeholder)
  return mapped_inputs, exterior_argument_placeholders
514
515
def _call_function_with_mapped_captures(function, args, resource_map):
  """Calls `function` in the exported graph, using mapped resource captures."""
  mapped_captures = _map_captures_to_created_tensors(function.graph.captures,
                                                     resource_map)
  # The function is invoked via its flat-call internals rather than normally,
  # because we need to feed the newly created captured resource tensors which
  # were not part of the original function definition.
  return function._call_flat(args, mapped_captures)  # pylint: disable=protected-access
527
528
def _generate_signatures(signature_functions, resource_map):
  """Validates and calls `signature_functions` in the default graph.

  Args:
    signature_functions: A dictionary mapping string keys to concrete TensorFlow
      functions (e.g. from `signature_serialization.canonicalize_signatures`)
      which will be used to generate SignatureDefs.
    resource_map: A dictionary mapping from resource tensors in the eager
      context to resource tensors in the Graph being exported. This dictionary
      is used to re-bind resources captured by functions to tensors which will
      exist in the SavedModel.

  Returns:
    Each function in the `signature_functions` dictionary is called with
    placeholder Tensors, generating a function call operation and output
    Tensors. The placeholder Tensors, the function call operation, and the
    output Tensors from the function call are part of the default Graph.

    This function then returns a dictionary with the same structure as
    `signature_functions`, with the concrete functions replaced by SignatureDefs
    implicitly containing information about how to call each function from a
    TensorFlow 1.x Session / the C++ Loader API. These SignatureDefs reference
    the generated placeholders and Tensor outputs by name.

    The caller is expected to include the default Graph set while calling this
    function as a MetaGraph in a SavedModel, including the returned
    SignatureDefs as part of that MetaGraph.
  """
  signatures = {}
  # Iterate in sorted key order for deterministic graph construction.
  for signature_key, function in sorted(signature_functions.items()):
    captures = function.graph.captures
    # Captures appear at the end of `graph.inputs`; everything before them is
    # a user-provided argument.
    if captures:
      argument_inputs = function.graph.inputs[:-len(captures)]
    else:
      argument_inputs = function.graph.inputs
    mapped_inputs, exterior_argument_placeholders = (
        _map_function_arguments_to_created_inputs(argument_inputs,
                                                  signature_key, function.name))
    outputs = _call_function_with_mapped_captures(function, mapped_inputs,
                                                  resource_map)
    signatures[signature_key] = signature_def_utils.build_signature_def(
        _tensor_dict_to_tensorinfo(exterior_argument_placeholders),
        _tensor_dict_to_tensorinfo(outputs),
        method_name=signature_constants.PREDICT_METHOD_NAME)
  return signatures
573
574
def _trace_resource_initializers(accessible_objects):
  """Create concrete functions from `CapturableResource` objects."""

  def _run_initializer(resource):
    # Invoke the resource's initialization logic and emit a dummy constant so
    # the traced function has a control output.
    resource._initialize()  # pylint: disable=protected-access
    return constant_op.constant(1.)  # Dummy control output

  resource_initializers = []
  for candidate in accessible_objects:
    if not isinstance(candidate, tracking.CapturableResource):
      continue
    resource_initializers.append(
        def_function.function(
            # Bind `candidate` as a default argument so each lambda closes
            # over its own resource rather than the loop variable.
            lambda resource=candidate: _run_initializer(resource),
            # All inputs are captures.
            input_signature=[]).get_concrete_function())
  return resource_initializers
594
595
# Accumulated state describing external asset files referenced by the object
# graph; populated by `_process_asset` during `map_resources`.
_AssetInfo = collections.namedtuple(
    "_AssetInfo",
    [
        # List of AssetFileDef protocol buffers
        "asset_defs",
        # Map from asset variable resource Tensors to their init ops
        "asset_initializers_by_resource",
        # Map from base asset filenames to full paths
        "asset_filename_map",
        # Map from Asset to index of corresponding AssetFileDef
        "asset_index"
    ])
608
609
def _process_asset(trackable_asset, asset_info, resource_map):
  """Add `trackable_asset` to `asset_info` and `resource_map`."""
  path_tensor = trackable_asset.asset_path
  original_path = tensor_util.constant_value(path_tensor)
  # `constant_value` may yield a numpy array; convert to a plain string. When
  # it is already a string there is no `astype` and nothing to do.
  if hasattr(original_path, "astype"):
    original_path = str(original_path.astype(str))
  asset_filename = builder_impl.get_asset_filename_to_add(
      asset_filepath=original_path,
      asset_filename_map=asset_info.asset_filename_map)
  # TODO(andresp): Instead of mapping 1-1 between trackable asset
  # and asset in the graph def consider deduping the assets that
  # point to the same file.
  path_placeholder = array_ops.placeholder(
      shape=path_tensor.shape,
      dtype=dtypes.string,
      name="asset_path_initializer")
  asset_variable = resource_variable_ops.ResourceVariable(path_placeholder)
  asset_info.asset_filename_map[asset_filename] = original_path
  file_def = meta_graph_pb2.AssetFileDef()
  file_def.filename = asset_filename
  file_def.tensor_info.name = path_placeholder.name
  asset_info.asset_defs.append(file_def)
  asset_info.asset_initializers_by_resource[path_tensor] = (
      asset_variable.initializer)
  asset_info.asset_index[trackable_asset] = len(asset_info.asset_defs) - 1
  # Re-bind the eager path tensor to the new asset variable in the exported
  # graph.
  resource_map[path_tensor] = asset_variable
640
641
def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
                         namespace_whitelist):
  """Generates a MetaGraph which calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    saveable_view: The _SaveableView being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.
    namespace_whitelist: List of strings containing whitelisted op namespaces.

  Returns:
    A tuple of (_AssetInfo, Graph) containing the captured assets and
    exported Graph generated from tracing the saveable_view.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  accessible_objects = saveable_view.nodes
  resource_initializer_functions = _trace_resource_initializers(
      accessible_objects)
  exported_graph = ops.Graph()
  resource_initializer_ops = []
  with exported_graph.as_default():
    object_map, resource_map, asset_info = saveable_view.map_resources()
    for resource_initializer_function in resource_initializer_functions:
      asset_dependencies = []
      # If an initializer captures an asset path tensor, run that asset's own
      # initializer first via a control dependency.
      for capture in resource_initializer_function.graph.external_captures:
        asset_initializer = asset_info.asset_initializers_by_resource.get(
            capture, None)
        if asset_initializer is not None:
          asset_dependencies.append(asset_initializer)
      with ops.control_dependencies(asset_dependencies):
        resource_initializer_ops.append(
            _call_function_with_mapped_captures(resource_initializer_function,
                                                [], resource_map))
    resource_initializer_ops.extend(
        asset_info.asset_initializers_by_resource.values())
    # A single no-op gated on every initializer acts as the graph's init op.
    with ops.control_dependencies(resource_initializer_ops):
      init_op = control_flow_ops.no_op()
    # Add the same op to the main_op collection and to the init_op
    # signature. The collection is for compatibility with older loader APIs;
    # only one will be executed.
    meta_graph_def.collection_def[constants.MAIN_OP_KEY].node_list.value.append(
        init_op.name)
    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
        signature_def_utils.op_signature_def(init_op,
                                             constants.INIT_OP_SIGNATURE_KEY))

  # Saving an object-based checkpoint again gathers variables. We need to do the
  # gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be in
  # the exported graph (thus the `to_graph` argument).
  saver = functional_saver.MultiDeviceSaver(
      saveable_view.checkpoint_view.frozen_saveable_objects(
          object_map=object_map, to_graph=exported_graph,
          call_with_mapped_captures=functools.partial(
              _call_function_with_mapped_captures, resource_map=resource_map)))

  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    # Register each traced concrete function in the exported graph's function
    # library.
    for concrete_function in saveable_view.concrete_functions:
      concrete_function.add_to_graph()
    saver_def = saver.to_proto()
    meta_graph_def.saver_def.CopyFrom(saver_def)
  graph_def = exported_graph.as_graph_def(add_shapes=True)
  # Fail before writing anything if non-whitelisted namespaced ops are used.
  _verify_ops(graph_def, namespace_whitelist)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
  meta_graph_def.meta_info_def.tensorflow_version = versions.__version__
  meta_graph_def.meta_info_def.tensorflow_git_version = (
      versions.__git_version__)
  # We currently always strip default attributes.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
  meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
      meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def))
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  return asset_info, exported_graph
723
724
def _verify_ops(graph_def, namespace_whitelist):
  """Raises ValueError if the graph uses non-whitelisted namespaced ops."""
  invalid_ops = []
  invalid_namespaces = set()

  # Ops whose names contain ">" are treated as namespaced ("Namespace>Op");
  # ops without a ">" are never checked against the whitelist.
  for op_name in meta_graph.ops_used_by_graph_def(graph_def):
    if ">" not in op_name:
      continue
    namespace = op_name.split(">")[0]
    if namespace not in namespace_whitelist:
      invalid_ops.append(op_name)
      invalid_namespaces.add(namespace)

  if not invalid_ops:
    return
  raise ValueError(
      "Attempted to save ops from non-whitelisted namespaces to SavedModel: "
      "{}.\nPlease verify that these ops should be saved, since they must be "
      "available when loading the SavedModel. If loading from Python, you "
      "must import the library defining these ops. From C++, link the custom "
      "ops to the serving binary. Once you've confirmed this, please add the "
      "following namespaces to the `namespace_whitelist` argument in "
      "tf.saved_model.SaveOptions: {}.".format(invalid_ops,
                                               invalid_namespaces))
749
750
def _serialize_object_graph(saveable_view, asset_file_def_index):
  """Save a SavedObjectGraph proto for `root`."""
  # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  object_graph = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(object_graph)

  structure_coder = nested_structure_coder.StructureCoder()
  for fn in saveable_view.concrete_functions:
    fn_name = compat.as_text(fn.name)
    fn_name = saveable_view.function_name_map.get(fn_name, fn_name)
    serialized_fn = function_serialization.serialize_concrete_function(
        fn, saveable_view.captured_tensor_node_ids, structure_coder)
    if serialized_fn is not None:
      object_graph.concrete_functions[fn_name].CopyFrom(serialized_fn)

  # Every node proto must be written, so this loop deliberately does not
  # short-circuit once metadata has been seen.
  saved_object_metadata = False
  for node, node_proto in zip(saveable_view.nodes, object_graph.nodes):
    if _write_object_proto(node, node_proto, asset_file_def_index,
                           saveable_view.function_name_map):
      saved_object_metadata = True
  return object_graph, saved_object_metadata
773
774
def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
  """Saves an object into SavedObject proto.

  Args:
    obj: A trackable object to serialize.
    proto: The SavedObject proto to fill in.
    asset_file_def_index: Dictionary mapping Asset objects to the index of the
      corresponding AssetFileDef.
    function_name_map: Dictionary mapping original function names to the names
      under which they are saved.

  Returns:
    True if legacy metadata was written from `obj._tracking_metadata` (the
    metadata field will be deprecated), False otherwise.
  """
  has_saved_object_metadata = False  # The metadata field will be deprecated.
  # Resolve the active save options up front: they are needed both by the
  # resource-variable branch and by the `_write_object_proto` hook at the
  # bottom. Previously `options` was assigned only inside the variable branch,
  # which raised UnboundLocalError when a non-variable object defined the hook.
  options = save_context.get_save_options()
  if isinstance(obj, tracking.Asset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    # Only ":0"-suffixed names round-trip through _op_name on restore.
    if not obj.name.endswith(":0"):
      raise ValueError("Cowardly refusing to save variable {} because of"
                       " unexpected suffix which won't be restored.".format(
                           obj.name))
    proto.variable.name = meta_graph._op_name(obj.name)  # pylint: disable=protected-access
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.synchronization = obj.synchronization.value
    proto.variable.aggregation = obj.aggregation.value
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
    if options.experimental_variable_policy._save_variable_devices(  # pylint: disable=protected-access
    ):
      if hasattr(obj, "device"):
        proto.variable.device = obj.device
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(function_serialization.serialize_function(
        obj, function_name_map))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(
            obj, function_name_map))
  elif isinstance(obj, _CapturedConstant):
    proto.constant.operation = obj.graph_tensor.op.name
  elif isinstance(obj, tracking.CapturableResource):
    proto.resource.device = obj._resource_device  # pylint: disable=protected-access
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration
      # pylint:disable=protected-access
      metadata = obj._tracking_metadata
      if metadata:
        has_saved_object_metadata = True
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier=obj._object_identifier,
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]),
          metadata=metadata)
      # pylint:enable=protected-access
    proto.user_object.CopyFrom(registered_type_proto)

  # Give the object a chance to modify the SavedObject proto.
  # This is currently used by MirroredVariables to optionally write their
  # component variables to the proto.
  #
  # This is not yet an official Trackable method, the only current use case
  # being MirroredVariables. See the method implementation there for more
  # documentation.
  if hasattr(obj, "_write_object_proto"):
    obj._write_object_proto(proto, options)  # pylint: disable=protected-access
  return has_saved_object_metadata
835
836
def _export_debug_info(exported_graph, export_dir):
  """Exports debug information from graph to file.

  Builds a GraphDebugInfo proto with traces for the ops in every eagerly
  defined function of `exported_graph` and writes it into the SavedModel
  debug directory.

  Args:
    exported_graph: A Graph that has been created by tracing a saveable view.
    export_dir: SavedModel directory in which to write the debug info.
  """
  traced_operations = []
  # pylint: disable=protected-access
  for function_name in exported_graph._functions:
    function = exported_graph._get_function(function_name)
    if isinstance(function, defun._EagerDefinedFunction):
      traced_operations.extend(
          (function_name, op) for op in function.graph.get_operations())
  # pylint: enable=protected-access

  debug_info = error_interpolation.create_graph_debug_info_def(
      traced_operations)
  file_io.atomic_write_string_to_file(
      os.path.join(
          utils_impl.get_or_create_debug_dir(export_dir),
          constants.DEBUG_INFO_FILENAME_PB),
      debug_info.SerializeToString(deterministic=True))
864
865
@tf_export(
    "saved_model.save",
    v1=["saved_model.save", "saved_model.experimental.save"])
def save(obj, export_dir, signatures=None, options=None):
  # pylint: disable=line-too-long
  """Exports a [tf.Module](https://www.tensorflow.org/api_docs/python/tf/Module) (and subclasses) `obj` to [SavedModel format](https://www.tensorflow.org/guide/saved_model#the_savedmodel_format_on_disk).

  The `obj` must inherit from the [`Trackable` class](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/tracking/base.py#L591).

  Example usage:

  >>> class Adder(tf.Module):
  ...   @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
  ...   def add(self, x):
  ...     return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(model, '/tmp/adder')

  The resulting SavedModel is then servable with an input named "x", a scalar
  with dtype float32.

  _Signatures_

  Signatures define the input and output types for a computation. The optional
  save `signatures` argument controls which methods in `obj` will be
  available to programs which consume `SavedModel`s, for example, serving
  APIs. Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  Example:

  >>> class Adder(tf.Module):
  ...   @tf.function
  ...   def add(self, x):
  ...     return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(
  ...   model, '/tmp/adder', signatures=model.add.get_concrete_function(
  ...     tf.TensorSpec([], tf.float32)))

  If a `@tf.function` does not have an input signature and
  `get_concrete_function` is not called on that method, the function will not
  be directly callable in the restored SavedModel.

  Example:

  >>> class Adder(tf.Module):
  ...   @tf.function
  ...   def add(self, x):
  ...     return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(model, '/tmp/adder')
  >>> restored = tf.saved_model.load('/tmp/adder')
  >>> restored.add(1.)
  Traceback (most recent call last):
  ...
  ValueError: Found zero restored functions for caller function.

  If the `signatures` argument is omitted, `obj` will be searched for
  `@tf.function`-decorated methods. If exactly one traced `@tf.function` is
  found, that method will be used as the default signature for the SavedModel.
  Else, any `@tf.function` attached to `obj` or its dependencies will be
  exported for use with `tf.saved_model.load`.

  When invoking a signature in an exported SavedModel, `Tensor` arguments are
  identified by name. These names will come from the Python function's argument
  names by default. They may be overridden by specifying a `name=...` argument
  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
  multiple `Tensor`s are passed through a single argument to the Python
  function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  _Using `tf.saved_model.save` with Keras models_

  While Keras has its own [saving and loading API](https://www.tensorflow.org/guide/keras/save_and_serialize),
  this function can be used to export Keras models. For example, exporting with
  a signature specified:

  >>> class Adder(tf.keras.Model):
  ...   @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  ...   def concat(self, x):
  ...      return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(model, '/tmp/adder')

  Exporting from a function without a fixed signature:

  >>> class Adder(tf.keras.Model):
  ...   @tf.function
  ...   def concat(self, x):
  ...      return x + x

  >>> model = Adder()
  >>> tf.saved_model.save(
  ...   model, '/tmp/adder',
  ...   signatures=model.concat.get_concrete_function(
  ...     tf.TensorSpec(shape=[], dtype=tf.string, name="string_input")))

  `tf.keras.Model` instances constructed from inputs and outputs already have a
  signature and so do not require a `@tf.function` decorator or a `signatures`
  argument. If neither are specified, the model's forward pass is exported.

  >>> x = tf.keras.layers.Input((4,), name="x")
  >>> y = tf.keras.layers.Dense(5, name="out")(x)
  >>> model = tf.keras.Model(x, y)
  >>> tf.saved_model.save(model, '/tmp/saved_model/')

  The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  with shape [None, 5]

  _Variables and Checkpoints_

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory.

  `tf.function` does not hard-code device annotations from outside the function
  body, instead using the calling context's device. This means for example
  that exporting a model that runs on a GPU and serving it on a CPU will
  generally work, with some exceptions:

    * `tf.device` annotations inside the body of the function will be hard-coded
      in the exported model; this type of annotation is discouraged.
    * Device-specific operations, e.g. with "cuDNN" in the name or with
      device-specific layouts, may cause issues.
    * For `ConcreteFunctions`, active distribution strategies will cause device
      placements to be hard-coded in the function.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the consumer
  of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  Args:
    obj: A trackable object (e.g. tf.Module or tf.train.Checkpoint) to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, one of three types:
      * a `tf.function` with an input signature specified, which will use the
        default serving signature key,
      * the result of `f.get_concrete_function` on a `@tf.function`-decorated
        function `f`, in which case `f` will be used to generate a signature for
        the SavedModel under the default serving signature key,
      * a dictionary, which maps signature keys to either `tf.function`
        instances with input signatures or concrete functions. Keys of such a
        dictionary may be arbitrary strings, but will typically be from the
        `tf.saved_model.signature_constants` module.
    options: `tf.saved_model.SaveOptions` object for configuring save options.

  Raises:
    ValueError: If `obj` is not trackable.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
  # pylint: enable=line-too-long
  # Delegates to the full implementation; warnings about legacy Keras
  # metadata are enabled on this public endpoint.
  save_and_return_nodes(obj, export_dir, signatures, options,
                        raise_metadata_warning=True)
1049
1050
def save_and_return_nodes(obj,
                          export_dir,
                          signatures=None,
                          options=None,
                          raise_metadata_warning=False,
                          experimental_skip_checkpoint=False):
  """Saves a SavedModel while returning all saved nodes and their paths.

  Please see `tf.saved_model.save` for details.

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: A function or dictionary of functions to save in the SavedModel
      as signatures.
    options: `tf.saved_model.SaveOptions` object for configuring save options.
    raise_metadata_warning: Whether to raise the metadata warning. This arg will
      be removed in TF 2.5.
    experimental_skip_checkpoint: If set to `True`, the checkpoint will not
      be written.

  Returns:
    A tuple of (a list of saved nodes in the order they are serialized to the
      `SavedObjectGraph`, dictionary mapping nodes to one possible path from
      the root node to the key node)

  Raises:
    FileNotFoundError: If an async save operation could not find its target
      device or path (see the `experimental_io_device` note below).
  """
  options = options or save_options.SaveOptions()
  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  saved_model = saved_model_pb2.SavedModel()
  meta_graph_def = saved_model.meta_graphs.add()

  _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
      _build_meta_graph(obj, signatures, options, meta_graph_def,
                        raise_metadata_warning))
  saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION

  # Write the checkpoint, copy assets into the assets directory, and write out
  # the SavedModel proto itself.
  if not experimental_skip_checkpoint:
    utils_impl.get_or_create_variables_dir(export_dir)
    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    object_saver.save(
        utils_impl.get_variables_path(export_dir), options=ckpt_options)
    builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                                export_dir)
  # Note that this needs to be the last file operation when saving the
  # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
  # indication that the SavedModel is completely written.
  if context.executing_eagerly():
    try:
      context.async_wait()  # Ensure save operations have completed.
    except errors.NotFoundError as err:
      # Fixed wording: previously read "consider using setting the".
      raise FileNotFoundError(
          str(err) + "\n If trying to save on a different device from the "
          "computational device, consider setting the "
          "`experimental_io_device` option on tf.saved_model.SaveOptions "
          "to the io_device such as '/job:localhost'."
      )

  path = os.path.join(
      compat.as_str(export_dir),
      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
  file_io.atomic_write_string_to_file(
      path, saved_model.SerializeToString(deterministic=True))
  # Save debug info, if requested.
  if options.save_debug_info:
    _export_debug_info(exported_graph, export_dir)

  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector. Before this point, we need to keep references to captured
  # constants in the saved graph.
  ops.dismantle_graph(exported_graph)

  return saved_nodes, node_paths
1128
1129
def export_meta_graph(obj, filename, signatures=None, options=None):
  """Exports the MetaGraph proto of the `obj` to a file.

  This function goes through the same procedures saved_model.save goes to
  produce the given object's MetaGraph, then saves it to the given file. It
  skips saving checkpoint information, and is useful when all one wants is the
  graph defining the model.

  Args:
    obj: A trackable object to build the MetaGraph from.
    filename: The file into which to write the MetaGraph.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving.
  """
  if options is None:
    options = save_options.SaveOptions()
  meta_graph_proto, traced_graph, _, _, _, _ = _build_meta_graph(
      obj, signatures, options)

  file_io.atomic_write_string_to_file(
      filename, meta_graph_proto.SerializeToString(deterministic=True))

  # Save debug info, if requested. Debug info lives next to the MetaGraph
  # file, in its containing directory.
  if options.save_debug_info:
    _export_debug_info(traced_graph, os.path.dirname(filename))

  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector. Before this point, we need to keep references to captured
  # constants in the saved graph.
  ops.dismantle_graph(traced_graph)
1169
1170
def _build_meta_graph_impl(obj,
                           signatures,
                           options,
                           meta_graph_def=None,
                           raise_metadata_warning=True):
  """Creates a MetaGraph containing the resources and functions of an object.

  Args:
    obj: A trackable object to build the MetaGraph from.
    signatures: A function, concrete function, or dictionary of functions to
      export as signatures. If None, a signature is searched for on `obj`.
    options: `tf.saved_model.SaveOptions` object for configuring save options.
    meta_graph_def: Optional, the MetaGraphDef proto to fill. A fresh proto is
      created when not provided.
    raise_metadata_warning: Whether to warn when user objects contain non-empty
      legacy metadata.

  Raises:
    AssertionError: If called inside a traced @tf.function.
    ValueError: If `obj` is not trackable.

  Returns:
    A tuple of (meta_graph_def, exported_graph, object_saver, asset_info,
    saved nodes, node paths).
  """
  if ops.inside_function():
    raise AssertionError(
        "tf.saved_model.save is not supported inside a traced @tf.function. "
        "Move the call to the outer eagerly-executed context.")
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected a Trackable object for export, got {}.".format(obj))
  meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()

  checkpoint_graph_view = _AugmentedGraphView(obj)
  if signatures is None:
    signatures = signature_serialization.find_function_to_export(
        checkpoint_graph_view)

  signatures, wrapped_functions = (
      signature_serialization.canonicalize_signatures(signatures))
  signature_serialization.validate_saveable_view(checkpoint_graph_view)
  signature_map = signature_serialization.create_signature_map(signatures)
  # Attach the signature map under the reserved `.signatures` attribute so it
  # is serialized with the object graph.
  checkpoint_graph_view.add_object(
      parent_node=checkpoint_graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=signature_map)

  # Use _SaveableView to provide a frozen listing of properties and functions.
  saveable_view = _SaveableView(checkpoint_graph_view, options,
                                wrapped_functions)
  object_saver = util.TrackableSaver(checkpoint_graph_view)
  asset_info, exported_graph = _fill_meta_graph_def(meta_graph_def,
                                                    saveable_view, signatures,
                                                    options.namespace_whitelist)
  if options.function_aliases:
    function_aliases = meta_graph_def.meta_info_def.function_aliases
    for alias, func in options.function_aliases.items():
      # Record the alias for every traced concrete function of the aliased
      # tf.function, from both the stateful and stateless caches.
      for fdef in func._stateful_fn._function_cache.all_values():  # pylint: disable=protected-access
        function_aliases[fdef.name] = alias
      for fdef in func._stateless_fn._function_cache.all_values():  # pylint: disable=protected-access
        function_aliases[fdef.name] = alias

  object_graph_proto, saved_object_metadata = _serialize_object_graph(
      saveable_view, asset_info.asset_index)
  meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)

  if saved_object_metadata and raise_metadata_warning:
    # Fixed typo in this warning: "metadta" -> "metadata".
    tf_logging.warn(
        'FOR KERAS USERS: The object that you are saving contains one or more '
        'Keras models or layers. If you are loading the SavedModel with '
        '`tf.keras.models.load_model`, continue reading (otherwise, you may '
        'ignore the following instructions). Please change your code to save '
        'with `tf.keras.models.save_model` or `model.save`, and confirm that '
        'the file "keras.metadata" exists in the export directory. In the '
        'future, Keras will only load the SavedModels that have this file. In '
        'other words, `tf.saved_model.save` will no longer write SavedModels '
        'that can be recovered as Keras models (this will apply in TF 2.5).'
        '\n\nFOR DEVS: If you are overwriting _tracking_metadata in your class,'
        ' this property has been used to save metadata in the SavedModel. The '
        'metadata field will be deprecated soon, so please move the metadata '
        'to a different file.')

  return (meta_graph_def, exported_graph, object_saver, asset_info,
          saveable_view.nodes, saveable_view.node_paths)
1238
1239
def _build_meta_graph(obj,
                      signatures,
                      options,
                      meta_graph_def=None,
                      raise_metadata_warning=True):
  """Creates a MetaGraph under a save context.

  Args:
    obj: A trackable object to build the MetaGraph from.
    signatures: Can be a `tf.function` with an input signature specified or the
      result of `f.get_concrete_function` on a `@tf.function`-decorated function
      `f`. `signatures` may also be a dictionary, in which case it maps from
      signature keys to `tf.function` instances. If None, finds signature to
      export from the `@tf.function`-decorated methods in `obj`.
    options: `tf.saved_model.SaveOptions` object that specifies options for
      saving.
    meta_graph_def: Optional, the MetaGraphDef proto to fill.
    raise_metadata_warning: Whether to raise a warning when user objects contain
      non-empty metadata.

  Raises:
    AssertionError: If `export_meta_graph` is executing inside a `tf.function`.
    ValueError: If `obj` is not trackable.

  Returns:
    meta_graph_def: Filled MetaGraphDef proto
    exported_graph: `tf.Graph` object generated from `obj`.
    object_saver: `util.TrackableSaver` of the `obj` and its dependencies.
    asset_info: `_AssetInfo` tuple containing external assets in the `obj`.
    saved_nodes: A list of saved nodes, in the order they are serialized to
      the `SavedObjectGraph`.
    node_paths: Dictionary mapping saved nodes to one possible path from the
      root node to the node.
  """

  # All of the serialization machinery reads save options through
  # `save_context.get_save_options()`, so run the build under a save context.
  with save_context.save_context(options):
    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
                                  raise_metadata_warning)
1274