1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Deprecated experimental Keras SavedModel implementation.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import os 21import warnings 22 23import six 24 25from tensorflow.python.client import session 26from tensorflow.python.framework import ops 27from tensorflow.python.keras import backend as K 28from tensorflow.python.keras import optimizer_v1 29from tensorflow.python.keras.optimizer_v2 import optimizer_v2 30from tensorflow.python.keras.saving import model_config 31from tensorflow.python.keras.saving import saving_utils 32from tensorflow.python.keras.utils import mode_keys 33from tensorflow.python.keras.utils.generic_utils import LazyLoader 34from tensorflow.python.ops import variables 35from tensorflow.python.platform import gfile 36from tensorflow.python.platform import tf_logging as logging 37from tensorflow.python.saved_model import builder as saved_model_builder 38from tensorflow.python.saved_model import constants 39from tensorflow.python.saved_model import model_utils 40from tensorflow.python.saved_model import save as save_lib 41from tensorflow.python.saved_model import utils_impl as saved_model_utils 42from tensorflow.python.training import saver as saver_lib 43from tensorflow.python.training.tracking import graph_view 44from tensorflow.python.util import compat 45from tensorflow.python.util import nest 46from tensorflow.python.util.tf_export import keras_export 47 48# To avoid circular dependencies between keras/engine and keras/saving, 49# code in keras/saving must delay imports. 50 51# TODO(b/134426265): Switch back to single-quotes to match the rest of the file 52# once the issue with copybara is fixed. 53# pylint:disable=g-inconsistent-quotes 54metrics_lib = LazyLoader("metrics_lib", globals(), 55 "tensorflow.python.keras.metrics") 56models_lib = LazyLoader("models_lib", globals(), 57 "tensorflow.python.keras.models") 58sequential = LazyLoader( 59 "sequential", globals(), 60 "tensorflow.python.keras.engine.sequential") 61# pylint:enable=g-inconsistent-quotes 62 63 64@keras_export(v1=['keras.experimental.export_saved_model']) 65def export_saved_model(model, 66 saved_model_path, 67 custom_objects=None, 68 as_text=False, 69 input_signature=None, 70 serving_only=False): 71 """Exports a `tf.keras.Model` as a Tensorflow SavedModel. 72 73 Note that at this time, subclassed models can only be saved using 74 `serving_only=True`. 75 76 The exported `SavedModel` is a standalone serialization of Tensorflow objects, 77 and is supported by TF language APIs and the Tensorflow Serving system. 78 To load the model, use the function 79 `tf.keras.experimental.load_from_saved_model`. 80 81 The `SavedModel` contains: 82 83 1. a checkpoint containing the model weights. 84 2. a `SavedModel` proto containing the Tensorflow backend graph. Separate 85 graphs are saved for prediction (serving), train, and evaluation. If 86 the model has not been compiled, then only the graph computing predictions 87 will be exported. 88 3. the model's json config. If the model is subclassed, this will only be 89 included if the model's `get_config()` method is overwritten. 90 91 Example: 92 93 ```python 94 import tensorflow as tf 95 96 # Create a tf.keras model. 97 model = tf.keras.Sequential() 98 model.add(tf.keras.layers.Dense(1, input_shape=[10])) 99 model.summary() 100 101 # Save the tf.keras model in the SavedModel format. 102 path = '/tmp/simple_keras_model' 103 tf.keras.experimental.export_saved_model(model, path) 104 105 # Load the saved keras model back. 106 new_model = tf.keras.experimental.load_from_saved_model(path) 107 new_model.summary() 108 ``` 109 110 Args: 111 model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag 112 `serving_only` must be set to True. 113 saved_model_path: a string specifying the path to the SavedModel directory. 114 custom_objects: Optional dictionary mapping string names to custom classes 115 or functions (e.g. custom loss functions). 116 as_text: bool, `False` by default. Whether to write the `SavedModel` proto 117 in text format. Currently unavailable in serving-only mode. 118 input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used 119 to specify the expected model inputs. See `tf.function` for more details. 120 serving_only: bool, `False` by default. When this is true, only the 121 prediction graph is saved. 122 123 Raises: 124 NotImplementedError: If the model is a subclassed model, and serving_only is 125 False. 126 ValueError: If the input signature cannot be inferred from the model. 127 AssertionError: If the SavedModel directory already exists and isn't empty. 128 """ 129 warnings.warn('`tf.keras.experimental.export_saved_model` is deprecated' 130 'and will be removed in a future version. ' 131 'Please use `model.save(..., save_format="tf")` or ' 132 '`tf.keras.models.save_model(..., save_format="tf")`.') 133 if serving_only: 134 save_lib.save( 135 model, 136 saved_model_path, 137 signatures=saving_utils.trace_model_call(model, input_signature)) 138 else: 139 _save_v1_format(model, saved_model_path, custom_objects, as_text, 140 input_signature) 141 142 try: 143 _export_model_json(model, saved_model_path) 144 except NotImplementedError: 145 logging.warning('Skipped saving model JSON, subclassed model does not have ' 146 'get_config() defined.') 147 148 149def _export_model_json(model, saved_model_path): 150 """Saves model configuration as a json string under assets folder.""" 151 model_json = model.to_json() 152 model_json_filepath = os.path.join( 153 saved_model_utils.get_or_create_assets_dir(saved_model_path), 154 compat.as_text(constants.SAVED_MODEL_FILENAME_JSON)) 155 with gfile.Open(model_json_filepath, 'w') as f: 156 f.write(model_json) 157 158 159def _export_model_variables(model, saved_model_path): 160 """Saves model weights in checkpoint format under variables folder.""" 161 saved_model_utils.get_or_create_variables_dir(saved_model_path) 162 checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path) 163 model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True) 164 return checkpoint_prefix 165 166 167def _save_v1_format(model, path, custom_objects, as_text, input_signature): 168 """Exports model to v1 SavedModel format.""" 169 if not model._is_graph_network: # pylint: disable=protected-access 170 if isinstance(model, sequential.Sequential): 171 # If input shape is not directly set in the model, the exported model 172 # will infer the expected shapes of the input from the model. 173 if not model.built: 174 raise ValueError('Weights for sequential model have not yet been ' 175 'created. Weights are created when the Model is first ' 176 'called on inputs or `build()` is called with an ' 177 '`input_shape`, or the first layer in the model has ' 178 '`input_shape` during construction.') 179 # TODO(kathywu): Build the model with input_signature to create the 180 # weights before _export_model_variables(). 181 else: 182 raise NotImplementedError( 183 'Subclassed models can only be exported for serving. Please set ' 184 'argument serving_only=True.') 185 186 builder = saved_model_builder._SavedModelBuilder(path) # pylint: disable=protected-access 187 188 # Manually save variables to export them in an object-based checkpoint. This 189 # skips the `builder.add_meta_graph_and_variables()` step, which saves a 190 # named-based checkpoint. 191 # TODO(b/113134168): Add fn to Builder to save with object-based saver. 192 # TODO(b/113178242): This should only export the model json structure. Only 193 # one save is needed once the weights can be copied from the model to clone. 194 checkpoint_path = _export_model_variables(model, path) 195 196 # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that 197 # Keras models and `Estimator`s are exported with the same format. 198 # Every time a mode is exported, the code checks to see if new variables have 199 # been created (e.g. optimizer slot variables). If that is the case, the 200 # checkpoint is re-saved to include the new variables. 201 export_args = {'builder': builder, 202 'model': model, 203 'custom_objects': custom_objects, 204 'checkpoint_path': checkpoint_path, 205 'input_signature': input_signature} 206 207 has_saved_vars = False 208 if model.optimizer: 209 if isinstance(model.optimizer, (optimizer_v1.TFOptimizer, 210 optimizer_v2.OptimizerV2)): 211 _export_mode(mode_keys.ModeKeys.TRAIN, has_saved_vars, **export_args) 212 has_saved_vars = True 213 _export_mode(mode_keys.ModeKeys.TEST, has_saved_vars, **export_args) 214 else: 215 logging.warning( 216 'Model was compiled with an optimizer, but the optimizer is not from ' 217 '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving ' 218 'graph was exported. The train and evaluate graphs were not added to ' 219 'the SavedModel.') 220 _export_mode(mode_keys.ModeKeys.PREDICT, has_saved_vars, **export_args) 221 222 builder.save(as_text) 223 224 225def _get_var_list(model): 226 """Returns list of all checkpointed saveable objects in the model.""" 227 var_list, _, _ = graph_view.ObjectGraphView(model).serialize_object_graph() 228 return var_list 229 230 231def create_placeholder(spec): 232 return K.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name) 233 234 235def _export_mode( 236 mode, has_saved_vars, builder, model, custom_objects, checkpoint_path, 237 input_signature): 238 """Exports a model, and optionally saves new vars from the clone model. 239 240 Args: 241 mode: A `tf.estimator.ModeKeys` string. 242 has_saved_vars: A `boolean` indicating whether the SavedModel has already 243 exported variables. 244 builder: A `SavedModelBuilder` object. 245 model: A `tf.keras.Model` object. 246 custom_objects: A dictionary mapping string names to custom classes 247 or functions. 248 checkpoint_path: String path to checkpoint. 249 input_signature: Nested TensorSpec containing the expected inputs. Can be 250 `None`, in which case the signature will be inferred from the model. 251 252 Raises: 253 ValueError: If the train/eval mode is being exported, but the model does 254 not have an optimizer. 255 """ 256 compile_clone = (mode != mode_keys.ModeKeys.PREDICT) 257 if compile_clone and not model.optimizer: 258 raise ValueError( 259 'Model does not have an optimizer. Cannot export mode %s' % mode) 260 261 model_graph = ops.get_default_graph() 262 with ops.Graph().as_default() as g, K.learning_phase_scope( 263 mode == mode_keys.ModeKeys.TRAIN): 264 265 if input_signature is None: 266 input_tensors = None 267 else: 268 input_tensors = nest.map_structure(create_placeholder, input_signature) 269 270 # Clone the model into blank graph. This will create placeholders for inputs 271 # and targets. 272 clone = models_lib.clone_and_build_model( 273 model, input_tensors=input_tensors, custom_objects=custom_objects, 274 compile_clone=compile_clone) 275 276 # Make sure that iterations variable is added to the global step collection, 277 # to ensure that, when the SavedModel graph is loaded, the iterations 278 # variable is returned by `tf.compat.v1.train.get_global_step()`. This is 279 # required for compatibility with the SavedModelEstimator. 280 if compile_clone: 281 g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations) 282 283 # Extract update and train ops from train/test/predict functions. 284 train_op = None 285 if mode == mode_keys.ModeKeys.TRAIN: 286 clone._make_train_function() # pylint: disable=protected-access 287 train_op = clone.train_function.updates_op 288 elif mode == mode_keys.ModeKeys.TEST: 289 clone._make_test_function() # pylint: disable=protected-access 290 else: 291 clone._make_predict_function() # pylint: disable=protected-access 292 g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates) 293 294 with session.Session().as_default(): 295 clone_var_list = _get_var_list(clone) 296 if has_saved_vars: 297 # Confirm all variables in the clone have an entry in the checkpoint. 298 status = clone.load_weights(checkpoint_path) 299 status.assert_existing_objects_matched() 300 else: 301 # Confirm that variables between the clone and model match up exactly, 302 # not counting optimizer objects. Optimizer objects are ignored because 303 # if the model has not trained, the slot variables will not have been 304 # created yet. 305 # TODO(b/113179535): Replace with trackable equivalence. 306 _assert_same_non_optimizer_objects(model, model_graph, clone, g) 307 308 # TODO(b/113178242): Use value transfer for trackable objects. 309 clone.load_weights(checkpoint_path) 310 311 # Add graph and variables to SavedModel. 312 # TODO(b/113134168): Switch to add_meta_graph_and_variables. 313 clone.save_weights(checkpoint_path, save_format='tf', overwrite=True) 314 builder._has_saved_variables = True # pylint: disable=protected-access 315 316 # Add graph to the SavedModel builder. 317 builder.add_meta_graph( 318 model_utils.EXPORT_TAG_MAP[mode], 319 signature_def_map=_create_signature_def_map(clone, mode), 320 saver=saver_lib.Saver( 321 clone_var_list, 322 # Allow saving Models with no variables. This is somewhat odd, but 323 # it's not necessarily a bug. 324 allow_empty=True), 325 init_op=variables.local_variables_initializer(), 326 train_op=train_op) 327 return None 328 329 330def _create_signature_def_map(model, mode): 331 """Creates a SignatureDef map from a Keras model.""" 332 inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)} 333 if model.optimizer: 334 targets_dict = {x.name.split(':')[0]: x 335 for x in model._targets if x is not None} # pylint: disable=protected-access 336 inputs_dict.update(targets_dict) 337 outputs_dict = {name: x 338 for name, x in zip(model.output_names, model.outputs)} 339 metrics = saving_utils.extract_model_metrics(model) 340 341 # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables 342 # are by default not added to any collections. We are doing this here, so 343 # that metric variables get initialized. 344 local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) 345 vars_to_add = set() 346 if metrics is not None: 347 for key, value in six.iteritems(metrics): 348 if isinstance(value, metrics_lib.Metric): 349 vars_to_add.update(value.variables) 350 # Convert Metric instances to (value_tensor, update_op) tuple. 351 metrics[key] = (value.result(), value.updates[0]) 352 # Remove variables that are in the local variables collection already. 353 vars_to_add = vars_to_add.difference(local_vars) 354 for v in vars_to_add: 355 ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v) 356 357 export_outputs = model_utils.export_outputs_for_mode( 358 mode, 359 predictions=outputs_dict, 360 loss=model.total_loss if model.optimizer else None, 361 metrics=metrics) 362 return model_utils.build_all_signature_defs( 363 inputs_dict, 364 export_outputs=export_outputs, 365 serving_only=(mode == mode_keys.ModeKeys.PREDICT)) 366 367 368def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument 369 """Asserts model and clone contain the same trackable objects.""" 370 371 # TODO(fchollet, kathywu): make sure this works in eager mode. 372 return True 373 374 375@keras_export(v1=['keras.experimental.load_from_saved_model']) 376def load_from_saved_model(saved_model_path, custom_objects=None): 377 """Loads a keras Model from a SavedModel created by `export_saved_model()`. 378 379 This function reinstantiates model state by: 380 1) loading model topology from json (this will eventually come 381 from metagraph). 382 2) loading model weights from checkpoint. 383 384 Example: 385 386 ```python 387 import tensorflow as tf 388 389 # Create a tf.keras model. 390 model = tf.keras.Sequential() 391 model.add(tf.keras.layers.Dense(1, input_shape=[10])) 392 model.summary() 393 394 # Save the tf.keras model in the SavedModel format. 395 path = '/tmp/simple_keras_model' 396 tf.keras.experimental.export_saved_model(model, path) 397 398 # Load the saved keras model back. 399 new_model = tf.keras.experimental.load_from_saved_model(path) 400 new_model.summary() 401 ``` 402 403 Args: 404 saved_model_path: a string specifying the path to an existing SavedModel. 405 custom_objects: Optional dictionary mapping names 406 (strings) to custom classes or functions to be 407 considered during deserialization. 408 409 Returns: 410 a keras.Model instance. 411 """ 412 warnings.warn('`tf.keras.experimental.load_from_saved_model` is deprecated' 413 'and will be removed in a future version. ' 414 'Please switch to `tf.keras.models.load_model`.') 415 # restore model topology from json string 416 model_json_filepath = os.path.join( 417 compat.as_bytes(saved_model_path), 418 compat.as_bytes(constants.ASSETS_DIRECTORY), 419 compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON)) 420 with gfile.Open(model_json_filepath, 'r') as f: 421 model_json = f.read() 422 model = model_config.model_from_json( 423 model_json, custom_objects=custom_objects) 424 425 # restore model weights 426 checkpoint_prefix = os.path.join( 427 compat.as_text(saved_model_path), 428 compat.as_text(constants.VARIABLES_DIRECTORY), 429 compat.as_text(constants.VARIABLES_FILENAME)) 430 model.load_weights(checkpoint_prefix) 431 return model 432