1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Options for saving SavedModels."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import enum
22import six
23
24from tensorflow.python.util import compat
25from tensorflow.python.util.tf_export import tf_export
26
27
@tf_export("saved_model.experimental.VariablePolicy")
class VariablePolicy(enum.Enum):
  """Enum defining options for variable handling when saving.

  NONE
    No policy applied: Distributed variables are saved as one variable, with no
    device attached.

  SAVE_VARIABLE_DEVICES
    When saving variables, also save their device assignment.
    This is useful if one wants to hardcode devices in saved models, but it also
    makes them non-portable if soft device placement is disabled (more details
    in `tf.config.set_soft_device_placement`). This is currently not
    fully supported by `saved_model.load`, and is mainly intended to be used
    when one will be reading the saved model at a lower API level. In the
    example below, the graph saved by the call to `saved_model.save` will have
    the variable devices correctly specified:
    ```python
    exported = tf.train.Checkpoint()
    with tf.device('/GPU:0'):
      exported.x_gpu = tf.Variable(1.0)
    with tf.device('/CPU:0'):
      exported.x_cpu = tf.Variable(1.0)
    tf.saved_model.save(exported, export_dir,
        options=tf.saved_model.SaveOptions(
            experimental_variable_policy=
              tf.saved_model.experimental.VariablePolicy.SAVE_VARIABLE_DEVICES))
    ```
    Distributed variables are still saved as one variable under this policy.

  EXPAND_DISTRIBUTED_VARIABLES
    Distributed variables will be saved with information about their components,
    allowing for their restoration on load. Also, the saved graph will contain
    references to those variables. This is useful when one wants to use the
    model for training in environments where the original distribution strategy
    is not available. Variable devices are saved under this policy as well.
  """

  NONE = None

  SAVE_VARIABLE_DEVICES = "save_variable_devices"

  EXPAND_DISTRIBUTED_VARIABLES = "expand_distributed_variables"

  def _save_variable_devices(self):
    """Checks whether variable devices should be saved.

    Returns:
      True for every policy other than `NONE`, i.e. expanding distributed
      variables implies saving variable devices too.
    """
    return self != VariablePolicy.NONE

  def _expand_distributed_variables(self):
    """Checks whether distributed variables should be expanded.

    Returns:
      True only for `EXPAND_DISTRIBUTED_VARIABLES`.
    """
    return self == VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES

  @staticmethod
  def from_obj(obj):
    """Tries to convert `obj` to a VariablePolicy instance.

    Args:
      obj: `None`, a `VariablePolicy` member, or an object whose `str()` form
        equals, ignoring case, one of the members' string values (for example
        "save_variable_devices" or "SAVE_VARIABLE_DEVICES").

    Returns:
      The corresponding `VariablePolicy` member. `None` maps to `NONE`.

    Raises:
      ValueError: If `obj` does not correspond to any member.
    """
    if obj is None:
      return VariablePolicy.NONE
    if isinstance(obj, VariablePolicy):
      return obj
    key = str(obj).lower()
    for policy in VariablePolicy:
      # NONE's value is None, which never equals a string key, so the string
      # "none" is rejected rather than mapped to NONE.
      if key == policy.value:
        return policy
    # Wrap `obj` in a 1-tuple: with a bare tuple, `%` would treat it as the
    # argument list and raise TypeError (or drop the parentheses) instead of
    # producing this ValueError.
    raise ValueError('Invalid VariablePolicy value "%s".' % (obj,))
92
93
@tf_export("saved_model.SaveOptions")
class SaveOptions(object):
  """Options for saving to SavedModel.

  Instances of this class may be passed as the `options` argument to functions
  that save a SavedModel (`tf.saved_model.save`, `tf.keras.models.save_model`).
  """

  # Define object attributes in __slots__ for improved memory and performance.
  # Instances therefore cannot carry attributes beyond the ones listed here.
  __slots__ = ("namespace_whitelist", "save_debug_info", "function_aliases",
               "experimental_io_device", "experimental_variable_policy")

  def __init__(self,
               namespace_whitelist=None,
               save_debug_info=False,
               function_aliases=None,
               experimental_io_device=None,
               experimental_variable_policy=None):
    """Creates an object that stores options for SavedModel saving.

    Args:
      namespace_whitelist: List of strings containing op namespaces to whitelist
        when saving a model. Saving an object that uses namespaced ops must
        explicitly add all namespaces to the whitelist. The namespaced ops must
        be registered into the framework when loading the SavedModel. If
        `None`, an empty list is stored.
      save_debug_info: Boolean indicating whether debug information is saved. If
        True, then a debug/saved_model_debug_info.pb file will be written with
        the contents of a GraphDebugInfo binary protocol buffer containing stack
        trace information for all ops and functions that are saved.
      function_aliases: Python dict. Mapping from string to object returned by
        @tf.function. A single tf.function can generate many ConcreteFunctions.
        If a downstream tool wants to refer to all concrete functions generated
        by a single tf.function you can use the `function_aliases` argument to
        store a map from the alias name to all concrete function names. If
        `None` or empty, a new empty dict is stored.
        E.g.

        >>> class Adder(tf.Module):
        ...   @tf.function
        ...   def double(self, x):
        ...     return x + x

        >>> model = Adder()
        >>> _ = model.double.get_concrete_function(
        ...   tf.TensorSpec(shape=[], dtype=tf.float32, name="float_input"))
        >>> _ = model.double.get_concrete_function(
        ...   tf.TensorSpec(shape=[], dtype=tf.string, name="string_input"))

        >>> options = tf.saved_model.SaveOptions(
        ...   function_aliases={'double': model.double})
        >>> tf.saved_model.save(model, '/tmp/adder', options=options)

      experimental_io_device: string. Applies in a distributed setting.
        TensorFlow device to use to access the filesystem. If `None` (default)
        then for each variable the filesystem is accessed from the CPU:0 device
        of the host where that variable is assigned. If specified, the
        filesystem is instead accessed from that device for all variables.

        This is for example useful if you want to save to a local directory,
        such as "/tmp" when running in a distributed setting. In that case pass
        a device for the host where the "/tmp" directory is accessible.
      experimental_variable_policy: The policy to apply to variables when
        saving. This is either a `saved_model.experimental.VariablePolicy` enum
        instance or one of its value strings (case is not important). See that
        enum documentation for details. A value of `None` corresponds to the
        default policy, `VariablePolicy.NONE`.

    Raises:
      TypeError: If `namespace_whitelist` is neither `None` nor a list.
      ValueError: If an entry of `namespace_whitelist` is not a string, or if
        `experimental_variable_policy` is not a recognized policy.
    """
    self.namespace_whitelist = _validate_namespace_whitelist(
        namespace_whitelist)
    self.save_debug_info = save_debug_info
    # Falsy values (None or an empty mapping) are normalized to a fresh dict so
    # instances never share a mutable default.
    self.function_aliases = function_aliases if function_aliases else dict()
    self.experimental_io_device = experimental_io_device
    self.experimental_variable_policy = (
        VariablePolicy.from_obj(experimental_variable_policy))
167
168
169def _validate_namespace_whitelist(namespace_whitelist):
170  """Validates namespace whitelist argument."""
171  if namespace_whitelist is None:
172    return []
173  if not isinstance(namespace_whitelist, list):
174    raise TypeError("Namespace whitelist must be a list of strings.")
175
176  processed = []
177  for namespace in namespace_whitelist:
178    if not isinstance(namespace, six.string_types):
179      raise ValueError("Whitelisted namespace must be a string. Got: {} of type"
180                       " {}.".format(namespace, type(namespace)))
181    processed.append(compat.as_str(namespace))
182  return processed
183