1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Deprecated experimental Keras SavedModel implementation."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import os
21import warnings
22
23import six
24
25from tensorflow.python.client import session
26from tensorflow.python.framework import ops
27from tensorflow.python.keras import backend as K
28from tensorflow.python.keras import optimizer_v1
29from tensorflow.python.keras.optimizer_v2 import optimizer_v2
30from tensorflow.python.keras.saving import model_config
31from tensorflow.python.keras.saving import saving_utils
32from tensorflow.python.keras.utils import mode_keys
33from tensorflow.python.keras.utils.generic_utils import LazyLoader
34from tensorflow.python.ops import variables
35from tensorflow.python.platform import gfile
36from tensorflow.python.platform import tf_logging as logging
37from tensorflow.python.saved_model import builder as saved_model_builder
38from tensorflow.python.saved_model import constants
39from tensorflow.python.saved_model import model_utils
40from tensorflow.python.saved_model import save as save_lib
41from tensorflow.python.saved_model import utils_impl as saved_model_utils
42from tensorflow.python.training import saver as saver_lib
43from tensorflow.python.training.tracking import graph_view
44from tensorflow.python.util import compat
45from tensorflow.python.util import nest
46from tensorflow.python.util.tf_export import keras_export
47
48# To avoid circular dependencies between keras/engine and keras/saving,
49# code in keras/saving must delay imports.
50
51# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
52# once the issue with copybara is fixed.
53# pylint:disable=g-inconsistent-quotes
# Module handles created with `LazyLoader` so that these imports are delayed,
# per the circular-dependency note above, instead of being performed when this
# module itself is imported.
metrics_lib = LazyLoader("metrics_lib", globals(),
                         "tensorflow.python.keras.metrics")
models_lib = LazyLoader("models_lib", globals(),
                        "tensorflow.python.keras.models")
sequential = LazyLoader(
    "sequential", globals(),
    "tensorflow.python.keras.engine.sequential")
# pylint:enable=g-inconsistent-quotes
62
63
@keras_export(v1=['keras.experimental.export_saved_model'])
def export_saved_model(model,
                       saved_model_path,
                       custom_objects=None,
                       as_text=False,
                       input_signature=None,
                       serving_only=False):
  """Exports a `tf.keras.Model` as a TensorFlow SavedModel.

  Note that at this time, subclassed models can only be saved using
  `serving_only=True`.

  The exported `SavedModel` is a standalone serialization of TensorFlow objects,
  and is supported by TF language APIs and the TensorFlow Serving system.
  To load the model, use the function
  `tf.keras.experimental.load_from_saved_model`.

  The `SavedModel` contains:

  1. a checkpoint containing the model weights.
  2. a `SavedModel` proto containing the TensorFlow backend graph. Separate
     graphs are saved for prediction (serving), train, and evaluation. If
     the model has not been compiled (or was compiled with an unsupported
     optimizer), then only the graph computing predictions will be exported.
  3. the model's json config. If the model is subclassed, this will only be
     included if the model's `get_config()` method is overridden.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```

  Args:
    model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag
      `serving_only` must be set to True.
    saved_model_path: a string specifying the path to the SavedModel directory.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions (e.g. custom loss functions). Not used in serving-only mode.
    as_text: bool, `False` by default. Whether to write the `SavedModel` proto
      in text format. Currently unavailable in serving-only mode.
    input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used
      to specify the expected model inputs. See `tf.function` for more details.
    serving_only: bool, `False` by default. When this is true, only the
      prediction graph is saved.

  Raises:
    NotImplementedError: If the model is a subclassed model, and serving_only is
      False.
    ValueError: If the input signature cannot be inferred from the model.
    AssertionError: If the SavedModel directory already exists and isn't empty.
  """
  # The trailing space after "deprecated" is required: the adjacent literals
  # are concatenated, and without it the message reads "deprecatedand".
  warnings.warn('`tf.keras.experimental.export_saved_model` is deprecated '
                'and will be removed in a future version. '
                'Please use `model.save(..., save_format="tf")` or '
                '`tf.keras.models.save_model(..., save_format="tf")`.')
  if serving_only:
    # Only the traced prediction function is exported here; `custom_objects`
    # and `as_text` are not consulted on this path.
    save_lib.save(
        model,
        saved_model_path,
        signatures=saving_utils.trace_model_call(model, input_signature))
  else:
    _save_v1_format(model, saved_model_path, custom_objects, as_text,
                    input_signature)

  # The JSON config is best-effort: a subclassed model without `get_config()`
  # raises NotImplementedError, in which case the export still succeeds.
  try:
    _export_model_json(model, saved_model_path)
  except NotImplementedError:
    logging.warning('Skipped saving model JSON, subclassed model does not have '
                    'get_config() defined.')
147
148
def _export_model_json(model, saved_model_path):
  """Writes `model.to_json()` into the SavedModel's assets directory.

  Args:
    model: Keras model whose JSON config is written.
    saved_model_path: SavedModel directory; its assets directory is obtained
      through `saved_model_utils.get_or_create_assets_dir`.
  """
  # Serialize first: if `to_json()` raises, nothing is touched on disk.
  serialized_config = model.to_json()
  assets_dir = saved_model_utils.get_or_create_assets_dir(saved_model_path)
  json_name = compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)
  with gfile.Open(os.path.join(assets_dir, json_name), 'w') as json_file:
    json_file.write(serialized_config)
157
158
def _export_model_variables(model, saved_model_path):
  """Writes the model weights as a TF checkpoint under the variables folder.

  Args:
    model: Keras model whose weights are saved.
    saved_model_path: SavedModel directory that receives the checkpoint.

  Returns:
    The checkpoint prefix the weights were written to.
  """
  saved_model_utils.get_or_create_variables_dir(saved_model_path)
  prefix = saved_model_utils.get_variables_path(saved_model_path)
  model.save_weights(prefix, overwrite=True, save_format='tf')
  return prefix
165
166
def _save_v1_format(model, path, custom_objects, as_text, input_signature):
  """Exports `model` to `path` in the v1 (graph-based) SavedModel format.

  The weights are first written once as an object-based checkpoint. Then one
  MetaGraph per mode is added to a SavedModel builder. TRAIN and TEST are added
  when the model is compiled with a `TFOptimizer` or `OptimizerV2`, and PREDICT
  is always added.

  Args:
    model: Keras model to export. Must be a graph network or a built
      `Sequential` model.
    path: SavedModel directory to write.
    custom_objects: Optional dict of custom classes/functions forwarded to the
      per-mode model cloning.
    as_text: Whether the SavedModel proto is written in text format.
    input_signature: Optional nested `TensorSpec`s describing the inputs.

  Raises:
    NotImplementedError: If `model` is neither a graph network nor a
      `Sequential` model (i.e. it is subclassed).
    ValueError: If `model` is a `Sequential` model that is not built yet.
  """
  is_graph_network = model._is_graph_network  # pylint: disable=protected-access
  if not (is_graph_network or isinstance(model, sequential.Sequential)):
    raise NotImplementedError(
        'Subclassed models can only be exported for serving. Please set '
        'argument serving_only=True.')
  # Non-graph-network Sequential models must already have weights. The input
  # shapes of the export are inferred from the model itself.
  # TODO(kathywu): Build the model with input_signature to create the
  # weights before _export_model_variables().
  if not is_graph_network and not model.built:
    raise ValueError('Weights for sequential model have not yet been '
                     'created. Weights are created when the Model is first '
                     'called on inputs or `build()` is called with an '
                     '`input_shape`, or the first layer in the model has '
                     '`input_shape` during construction.')

  builder = saved_model_builder._SavedModelBuilder(path)  # pylint: disable=protected-access

  # Variables are saved by hand so they land in an object-based checkpoint.
  # `builder.add_meta_graph_and_variables()` would write a name-based one.
  # TODO(b/113134168): Add fn to Builder to save with object-based saver.
  # TODO(b/113178242): This should only export the model json structure. Only
  # one save is needed once the weights can be copied from the model to clone.
  checkpoint_path = _export_model_variables(model, path)

  # Modes use the `Estimator` ModeKeys so that Keras models and `Estimator`s
  # share one export format. The first mode exported re-saves the checkpoint
  # from its clone, so variables created while building that clone (e.g.
  # optimizer slot variables) end up in the checkpoint.
  export_args = dict(
      builder=builder,
      model=model,
      custom_objects=custom_objects,
      checkpoint_path=checkpoint_path,
      input_signature=input_signature)

  has_saved_vars = False
  optimizer = model.optimizer
  if optimizer:
    supported_optimizers = (optimizer_v1.TFOptimizer, optimizer_v2.OptimizerV2)
    if isinstance(optimizer, supported_optimizers):
      for train_or_eval_mode in (mode_keys.ModeKeys.TRAIN,
                                 mode_keys.ModeKeys.TEST):
        _export_mode(train_or_eval_mode, has_saved_vars, **export_args)
        has_saved_vars = True
    else:
      logging.warning(
          'Model was compiled with an optimizer, but the optimizer is not from '
          '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving '
          'graph was exported. The train and evaluate graphs were not added to '
          'the SavedModel.')
  _export_mode(mode_keys.ModeKeys.PREDICT, has_saved_vars, **export_args)

  builder.save(as_text)
223
224
def _get_var_list(model):
  """Lists the saveable objects that an object-based checkpoint of `model` holds.

  Args:
    model: Trackable object (a Keras model here) to walk.

  Returns:
    The first element of `ObjectGraphView(model).serialize_object_graph()`.
  """
  object_graph = graph_view.ObjectGraphView(model)
  saveables, _, _ = object_graph.serialize_object_graph()
  return saveables
229
230
def create_placeholder(spec):
  """Returns a Keras placeholder with the shape, dtype and name of `spec`."""
  placeholder_kwargs = dict(shape=spec.shape, dtype=spec.dtype, name=spec.name)
  return K.placeholder(**placeholder_kwargs)
233
234
def _export_mode(
    mode, has_saved_vars, builder, model, custom_objects, checkpoint_path,
    input_signature):
  """Exports a model, and optionally saves new vars from the clone model.

  The steps are:
  1. Clone `model` into a fresh graph under the learning phase for `mode`.
  2. Load the weights stored at `checkpoint_path` into the clone.
  3. Add the clone's graph to `builder` as a MetaGraph tagged with
     `model_utils.EXPORT_TAG_MAP[mode]`.

  When `has_saved_vars` is False, the clone's weights are also written back to
  `checkpoint_path`. The checkpoint then contains any variables created while
  building the clone.

  Args:
    mode: A `tf.estimator.ModeKeys` string.
    has_saved_vars: A `boolean` indicating whether the SavedModel has already
      exported variables.
    builder: A `SavedModelBuilder` object.
    model: A `tf.keras.Model` object.
    custom_objects: A dictionary mapping string names to custom classes
      or functions.
    checkpoint_path: String path to checkpoint.
    input_signature: Nested TensorSpec containing the expected inputs. Can be
      `None`, in which case the signature will be inferred from the model.

  Raises:
    ValueError: If the train/eval mode is being exported, but the model does
      not have an optimizer.
  """
  compile_clone = (mode != mode_keys.ModeKeys.PREDICT)
  if compile_clone and not model.optimizer:
    raise ValueError(
        'Model does not have an optimizer. Cannot export mode %s' % mode)

  model_graph = ops.get_default_graph()
  with ops.Graph().as_default() as g, K.learning_phase_scope(
      mode == mode_keys.ModeKeys.TRAIN):

    if input_signature is None:
      input_tensors = None
    else:
      input_tensors = nest.map_structure(create_placeholder, input_signature)

    # Clone the model into blank graph. This will create placeholders for inputs
    # and targets.
    clone = models_lib.clone_and_build_model(
        model, input_tensors=input_tensors, custom_objects=custom_objects,
        compile_clone=compile_clone)

    # Make sure that iterations variable is added to the global step collection,
    # to ensure that, when the SavedModel graph is loaded, the iterations
    # variable is returned by `tf.compat.v1.train.get_global_step()`. This is
    # required for compatibility with the SavedModelEstimator.
    if compile_clone:
      g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)

    # Extract update and train ops from train/test/predict functions.
    train_op = None
    if mode == mode_keys.ModeKeys.TRAIN:
      clone._make_train_function()  # pylint: disable=protected-access
      train_op = clone.train_function.updates_op
    elif mode == mode_keys.ModeKeys.TEST:
      clone._make_test_function()  # pylint: disable=protected-access
    else:
      clone._make_predict_function()  # pylint: disable=protected-access
    g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)

    # Use the session itself as the context manager. `Session.as_default()`
    # does not close the session on exit, so each exported mode used to leave
    # a session open until garbage collection. Entering the session installs
    # it as the default session, and exiting closes it.
    with session.Session():
      clone_var_list = _get_var_list(clone)
      if has_saved_vars:
        # Confirm all variables in the clone have an entry in the checkpoint.
        status = clone.load_weights(checkpoint_path)
        status.assert_existing_objects_matched()
      else:
        # Confirm that variables between the clone and model match up exactly,
        # not counting optimizer objects. Optimizer objects are ignored because
        # if the model has not trained, the slot variables will not have been
        # created yet.
        # TODO(b/113179535): Replace with trackable equivalence.
        _assert_same_non_optimizer_objects(model, model_graph, clone, g)

        # TODO(b/113178242): Use value transfer for trackable objects.
        clone.load_weights(checkpoint_path)

        # Re-save from the clone so the checkpoint also covers variables that
        # only exist in the clone, then mark the builder as having variables.
        # TODO(b/113134168): Switch to add_meta_graph_and_variables.
        clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
        builder._has_saved_variables = True  # pylint: disable=protected-access

      # Add graph to the SavedModel builder.
      builder.add_meta_graph(
          model_utils.EXPORT_TAG_MAP[mode],
          signature_def_map=_create_signature_def_map(clone, mode),
          saver=saver_lib.Saver(
              clone_var_list,
              # Allow saving Models with no variables. This is somewhat odd, but
              # it's not necessarily a bug.
              allow_empty=True),
          init_op=variables.local_variables_initializer(),
          train_op=train_op)
328
329
def _create_signature_def_map(model, mode):
  """Builds the SignatureDef map for a Keras `model` exported under `mode`.

  The signature inputs are the model inputs plus, when the model has an
  optimizer, its non-None target tensors. The export outputs are built from
  three parts:
  - the model predictions;
  - the total loss, when the model has an optimizer;
  - the metrics returned by `saving_utils.extract_model_metrics`.

  Any `Metric` instance among those metrics is replaced by a
  `(result_tensor, first_update_op)` tuple. Its variables are added to the
  `LOCAL_VARIABLES` collection so the local-variables initializer covers them.

  Args:
    model: Compiled or uncompiled Keras model (the per-mode clone).
    mode: A `ModeKeys` value selecting which export outputs to build.

  Returns:
    The result of `model_utils.build_all_signature_defs`.
  """
  signature_inputs = dict(zip(model.input_names, model.inputs))
  if model.optimizer:
    for target in model._targets:  # pylint: disable=protected-access
      if target is not None:
        signature_inputs[target.name.split(':')[0]] = target
  predictions = dict(zip(model.output_names, model.outputs))
  metrics = saving_utils.extract_model_metrics(model)

  # Metric variables are not in any collection by default. Register them as
  # local variables so that they get initialized.
  already_local = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
  metric_variables = set()
  if metrics is not None:
    for metric_name, metric in list(metrics.items()):
      if not isinstance(metric, metrics_lib.Metric):
        continue
      metric_variables.update(metric.variables)
      metrics[metric_name] = (metric.result(), metric.updates[0])
  for var in metric_variables - already_local:
    ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, var)

  total_loss = model.total_loss if model.optimizer else None
  export_outputs = model_utils.export_outputs_for_mode(
      mode,
      predictions=predictions,
      loss=total_loss,
      metrics=metrics)
  return model_utils.build_all_signature_defs(
      signature_inputs,
      export_outputs=export_outputs,
      serving_only=(mode == mode_keys.ModeKeys.PREDICT))
366
367
368def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph):  # pylint: disable=unused-argument
369  """Asserts model and clone contain the same trackable objects."""
370
371  # TODO(fchollet, kathywu): make sure this works in eager mode.
372  return True
373
374
@keras_export(v1=['keras.experimental.load_from_saved_model'])
def load_from_saved_model(saved_model_path, custom_objects=None):
  """Loads a keras Model from a SavedModel created by `export_saved_model()`.

  This function reinstantiates model state by:
  1) loading model topology from json (this will eventually come
     from metagraph).
  2) loading model weights from checkpoint.

  Example:

  ```python
  import tensorflow as tf

  # Create a tf.keras model.
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Dense(1, input_shape=[10]))
  model.summary()

  # Save the tf.keras model in the SavedModel format.
  path = '/tmp/simple_keras_model'
  tf.keras.experimental.export_saved_model(model, path)

  # Load the saved keras model back.
  new_model = tf.keras.experimental.load_from_saved_model(path)
  new_model.summary()
  ```

  Args:
    saved_model_path: a string specifying the path to an existing SavedModel.
    custom_objects: Optional dictionary mapping names
        (strings) to custom classes or functions to be
        considered during deserialization.

  Returns:
    a keras.Model instance.
  """
  # The trailing space after "deprecated" is required: the adjacent literals
  # are concatenated, and without it the message reads "deprecatedand".
  warnings.warn('`tf.keras.experimental.load_from_saved_model` is deprecated '
                'and will be removed in a future version. '
                'Please switch to `tf.keras.models.load_model`.')
  # Restore the model topology from the JSON config stored in the assets
  # directory.
  model_json_filepath = os.path.join(
      compat.as_bytes(saved_model_path),
      compat.as_bytes(constants.ASSETS_DIRECTORY),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
  with gfile.Open(model_json_filepath, 'r') as f:
    model_json = f.read()
  model = model_config.model_from_json(
      model_json, custom_objects=custom_objects)

  # Restore the weights from the checkpoint prefix computed by the same helper
  # that `_export_model_variables` uses when writing them, so the reader and
  # the writer cannot drift apart.
  checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path)
  model.load_weights(checkpoint_prefix)
  return model
432