1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Tests for saving and loading Keras models and layers from SavedModel. 17 18These should ensure that all layer properties are correctly assigned after 19loading from the SavedModel. 20 21Tests that focus on the model structure should go in revive_test.py 22""" 23from __future__ import absolute_import 24from __future__ import division 25from __future__ import print_function 26 27import os 28import shutil 29 30from absl.testing import parameterized 31import numpy as np 32 33from tensorflow.core.example import example_pb2 34from tensorflow.core.example import feature_pb2 35from tensorflow.python import keras 36from tensorflow.python import tf2 37from tensorflow.python.data.ops import dataset_ops 38from tensorflow.python.distribute import mirrored_strategy 39from tensorflow.python.eager import context 40from tensorflow.python.eager import def_function 41from tensorflow.python.feature_column import feature_column_v2 as fc 42from tensorflow.python.framework import constant_op 43from tensorflow.python.framework import dtypes 44from tensorflow.python.framework import ops 45from tensorflow.python.framework import tensor_spec 46from tensorflow.python.keras import combinations 47from tensorflow.python.keras import keras_parameterized 48from 
tensorflow.python.keras import regularizers 49from tensorflow.python.keras import testing_utils 50from tensorflow.python.keras.feature_column.dense_features import DenseFeatures 51from tensorflow.python.keras.saving.saved_model import load as keras_load 52from tensorflow.python.keras.saving.saved_model import save_impl as keras_save 53from tensorflow.python.keras.utils import control_flow_util 54from tensorflow.python.keras.utils import generic_utils 55from tensorflow.python.keras.utils import tf_contextlib 56from tensorflow.python.keras.utils import tf_inspect 57from tensorflow.python.ops import array_ops 58from tensorflow.python.ops import init_ops 59from tensorflow.python.ops import math_ops 60from tensorflow.python.ops import parsing_ops 61from tensorflow.python.ops import variables 62from tensorflow.python.ops.ragged import ragged_factory_ops 63from tensorflow.python.platform import test 64from tensorflow.python.saved_model import load as tf_load 65from tensorflow.python.saved_model import save as tf_save 66 67 68class LayerWithLearningPhase(keras.engine.base_layer.Layer): 69 70 def build(self, input_shape): 71 self.input_spec = keras.layers.InputSpec(shape=[None] * len(input_shape)) 72 self.built = True 73 74 def call(self, x, training=None): 75 if training is None: 76 training = keras.backend.learning_phase() 77 output = control_flow_util.smart_cond(training, lambda: x * 0, 78 lambda: array_ops.identity(x)) 79 if not context.executing_eagerly(): 80 output._uses_learning_phase = True # pylint: disable=protected-access 81 return output 82 83 def compute_output_shape(self, input_shape): 84 return input_shape 85 86 @property 87 def _use_input_spec_as_call_signature(self): 88 return True 89 90 91class LayerWithLoss(keras.layers.Layer): 92 93 def call(self, inputs): 94 self.add_loss(math_ops.reduce_sum(inputs), inputs=inputs) 95 return inputs * 2 96 97 98class LayerWithUpdate(keras.layers.Layer): 99 100 def build(self, _): 101 self.v = self.add_weight( 102 'v', 
103 shape=[], 104 initializer=keras.initializers.zeros, 105 trainable=False, 106 dtype=dtypes.float32) 107 108 def call(self, inputs, training=True): 109 if training: 110 self.add_update(self.v.assign_add(1.)) 111 return inputs * 2. 112 113 114@generic_utils.register_keras_serializable('Testing') 115class GlobalLayerThatShouldFailIfNotAdded(keras.layers.Layer): 116 _must_restore_from_config = True 117 118 119@keras_parameterized.run_all_keras_modes 120class TestSavedModelFormatAllModes(keras_parameterized.TestCase): 121 122 def _save_model_dir(self, dirname='saved_model'): 123 temp_dir = self.get_temp_dir() 124 self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) 125 return os.path.join(temp_dir, dirname) 126 127 def _test_save_and_load(self, use_dataset=False): 128 model = testing_utils.get_small_mlp(1, 4, input_dim=3) 129 model.layers[-1].activity_regularizer = regularizers.get('l2') 130 model.activity_regularizer = regularizers.get('l2') 131 model.compile( 132 loss='mse', 133 optimizer='rmsprop') 134 def callable_loss(): 135 return math_ops.reduce_sum(model.weights[0]) 136 model.add_loss(callable_loss) 137 138 x = np.random.random((1, 3)) 139 y = np.random.random((1, 4)) 140 141 if not tf2.enabled(): 142 # The layer autocast behavior only runs when autocast is enabled, so 143 # in V1, the numpy inputs still need to be cast to float32. 
144 x = x.astype(np.float32) 145 y = y.astype(np.float32) 146 147 if use_dataset: 148 dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(1) 149 model.fit(dataset) 150 else: 151 model.train_on_batch(x, y) 152 153 saved_model_dir = self._save_model_dir() 154 tf_save.save(model, saved_model_dir) 155 loaded = keras_load.load(saved_model_dir) 156 self.evaluate(variables.variables_initializer(loaded.variables)) 157 self.assertAllClose(self.evaluate(model.weights), 158 self.evaluate(loaded.weights)) 159 160 input_arr = constant_op.constant( 161 np.random.random((1, 3)).astype(np.float32)) 162 self.assertAllClose(self.evaluate(model(input_arr)), 163 self.evaluate(loaded(input_arr))) 164 # Validate losses. The order of conditional losses may change between the 165 # model and loaded model, so sort the losses first. 166 if context.executing_eagerly(): 167 self.assertAllClose(sorted(self.evaluate(model.losses)), 168 sorted(self.evaluate(loaded.losses))) 169 else: 170 self.assertAllClose(self.evaluate(model.get_losses_for(None)), 171 self.evaluate(loaded.get_losses_for(None))) 172 self.assertAllClose( 173 sorted(self.evaluate(model.get_losses_for(input_arr))), 174 sorted(self.evaluate(loaded.get_losses_for(input_arr)))) 175 176 @keras_parameterized.run_with_all_model_types 177 def test_model_save_and_load(self): 178 self._test_save_and_load(use_dataset=True) 179 180 @keras_parameterized.run_with_all_model_types 181 def test_model_save_and_load_dataset(self): 182 self._test_save_and_load(use_dataset=True) 183 184 def test_trainable_weights(self): 185 layer = keras.layers.Dense(4, name='custom_layer') 186 layer.build([3,]) 187 layer.add_weight( 188 'extra_weight', shape=[], 189 initializer=init_ops.constant_initializer(11), 190 trainable=True) 191 layer.add_weight( 192 'extra_weight_2', shape=[], 193 initializer=init_ops.constant_initializer(12), 194 trainable=False) 195 196 saved_model_dir = self._save_model_dir() 197 
self.evaluate(variables.variables_initializer(layer.variables)) 198 tf_save.save(layer, saved_model_dir) 199 loaded = keras_load.load(saved_model_dir) 200 self.evaluate(variables.variables_initializer(loaded.variables)) 201 202 equal_attrs = ['name', '_expects_training_arg', 'trainable'] 203 for attr in equal_attrs: 204 self.assertEqual(getattr(layer, attr), getattr(loaded, attr)) 205 206 all_close = ['weights', 'trainable_weights', 'non_trainable_weights'] 207 for attr in all_close: 208 self.assertAllClose(self.evaluate(getattr(layer, attr)), 209 self.evaluate(getattr(loaded, attr))) 210 211 def test_maintains_losses(self): 212 """Tests that the layer losses do not change before and after export.""" 213 model = keras.models.Sequential([LayerWithLoss()]) 214 model.compile( 215 loss='mse', 216 optimizer='rmsprop') 217 input_arr = np.random.random((1, 3)) 218 target_arr = np.random.random((1, 3)) 219 220 # Test that symbolic losses are maintained (train_on_batch saves symbolic 221 # losses.) 222 model.train_on_batch(input_arr, target_arr) 223 previous_losses = model.losses[:] 224 225 saved_model_dir = self._save_model_dir() 226 tf_save.save(model, saved_model_dir) 227 228 with previous_losses[0].graph.as_default(): 229 # If we try to compare symbolic Tensors in eager mode assertAllEqual will 230 # return False even if they are the same Tensor. 231 self.assertAllEqual(previous_losses, model.losses) 232 233 if context.executing_eagerly(): 234 # Test that eager losses are maintained. 235 model(input_arr) # Calls model eagerly, creating eager losses. 
236 previous_losses = model.losses[:] 237 tf_save.save(model, saved_model_dir) 238 self.assertAllEqual(previous_losses, model.losses) 239 240 def test_layer_with_learning_phase(self): 241 layer = LayerWithLearningPhase() 242 layer.build([None, None]) 243 saved_model_dir = self._save_model_dir() 244 tf_save.save(layer, saved_model_dir) 245 loaded = keras_load.load(saved_model_dir) 246 input_arr = array_ops.ones((4, 3)) 247 248 # Run the layer, and use the keras backend learning phase 249 keras.backend.set_learning_phase(0) 250 self.assertAllEqual(input_arr, loaded(input_arr)) 251 keras.backend.set_learning_phase(1) 252 self.assertAllEqual(array_ops.zeros((4, 3)), loaded(input_arr)) 253 254 # Run the layer while explicitly setting the training argument 255 self.assertAllEqual( 256 input_arr, loaded(input_arr, training=constant_op.constant(False))) 257 self.assertAllEqual( 258 array_ops.zeros((4, 3)), 259 loaded(input_arr, training=constant_op.constant(True))) 260 261 @keras_parameterized.run_with_all_model_types 262 def test_standard_loader(self): 263 model = testing_utils.get_small_mlp(1, 4, input_dim=3) 264 model.activity_regularizer = regularizers.get('l2') 265 def eager_loss(): 266 return math_ops.reduce_sum(model.weights[0]) 267 model.add_loss(eager_loss) 268 269 # Call predict to ensure that all layers are built and inputs are set. 
270 model.predict(np.random.random((1, 3)).astype(np.float32)) 271 saved_model_dir = self._save_model_dir() 272 273 tf_save.save(model, saved_model_dir) 274 275 loaded = tf_load.load(saved_model_dir) 276 self.evaluate(variables.variables_initializer(loaded.variables)) 277 all_close = ['variables', 'trainable_variables', 278 'non_trainable_variables'] 279 for attr in all_close: 280 self.assertAllClose(self.evaluate(getattr(model, attr)), 281 self.evaluate(getattr(loaded.keras_api, attr))) 282 self.assertLen(loaded.regularization_losses, 1) 283 expected_layers = len(model.layers) 284 self.assertEqual(expected_layers, len(loaded.keras_api.layers)) 285 input_arr = array_ops.ones((4, 3)) 286 self.assertAllClose(self.evaluate(model(input_arr)), 287 self.evaluate(loaded(input_arr, training=False))) 288 289 @keras_parameterized.run_with_all_model_types 290 def test_compiled_model(self): 291 # TODO(b/134519980): Issue with model.fit if the model call function uses 292 # a tf.function (Graph mode only). 293 if not context.executing_eagerly(): 294 return 295 296 input_arr = np.random.random((1, 3)) 297 target_arr = np.random.random((1, 4)) 298 299 model = testing_utils.get_small_mlp(1, 4, input_dim=3) 300 expected_predict = model.predict(input_arr) 301 302 # Compile and save model. 303 model.compile('rmsprop', 'mse') 304 saved_model_dir = self._save_model_dir() 305 tf_save.save(model, saved_model_dir) 306 307 loaded = keras_load.load(saved_model_dir) 308 actual_predict = loaded.predict(input_arr) 309 self.assertAllClose(expected_predict, actual_predict) 310 311 loss_before = loaded.evaluate(input_arr, target_arr) 312 loaded.fit(input_arr, target_arr) 313 loss_after = loaded.evaluate(input_arr, target_arr) 314 self.assertLess(loss_after, loss_before) 315 predict = loaded.predict(input_arr) 316 317 ckpt_path = os.path.join(self.get_temp_dir(), 'weights') 318 loaded.save_weights(ckpt_path) 319 320 # Ensure that the checkpoint is compatible with the original model. 
321 model.load_weights(ckpt_path) 322 self.assertAllClose(predict, model.predict(input_arr)) 323 324 def test_metadata_input_spec(self): 325 class LayerWithNestedSpec(keras.layers.Layer): 326 327 def __init__(self): 328 super(LayerWithNestedSpec, self).__init__() 329 self.input_spec = { 330 'a': keras.layers.InputSpec(max_ndim=3, axes={-1: 2}), 331 'b': keras.layers.InputSpec(shape=(None, 2, 3), dtype='float16')} 332 333 @property 334 def _use_input_spec_as_call_signature(self): 335 return True 336 337 layer = LayerWithNestedSpec() 338 saved_model_dir = self._save_model_dir() 339 tf_save.save(layer, saved_model_dir) 340 loaded = keras_load.load(saved_model_dir) 341 self.assertEqual(3, loaded.input_spec['a'].max_ndim) 342 self.assertEqual({-1: 2}, loaded.input_spec['a'].axes) 343 self.assertAllEqual([None, 2, 3], loaded.input_spec['b'].shape) 344 self.assertEqual('float16', loaded.input_spec['b'].dtype) 345 346 def test_must_restore_from_config_fails_if_layer_is_not_in_scope(self): 347 348 class LayerThatShouldFailIfNotAdded(keras.layers.Layer): 349 _must_restore_from_config = True 350 351 layer = LayerThatShouldFailIfNotAdded() 352 saved_model_dir = self._save_model_dir() 353 tf_save.save(layer, saved_model_dir) 354 with self.assertRaisesRegex(RuntimeError, 'Unable to restore a layer of'): 355 _ = keras_load.load(saved_model_dir) 356 357 def test_must_restore_from_config_custom_object_scope(self): 358 359 class LayerThatShouldFailIfNotAdded(keras.layers.Layer): 360 _must_restore_from_config = True 361 362 layer = LayerThatShouldFailIfNotAdded() 363 saved_model_dir = self._save_model_dir() 364 tf_save.save(layer, saved_model_dir) 365 with generic_utils.CustomObjectScope( 366 {'LayerThatShouldFailIfNotAdded': LayerThatShouldFailIfNotAdded}): 367 _ = keras_load.load(saved_model_dir) 368 369 def test_must_restore_from_config_registration(self): 370 layer = GlobalLayerThatShouldFailIfNotAdded() 371 saved_model_dir = self._save_model_dir() 372 tf_save.save(layer, 
saved_model_dir) 373 _ = keras_load.load(saved_model_dir) 374 375 def test_multi_input_model(self): 376 input_1 = keras.layers.Input(shape=(3,)) 377 input_2 = keras.layers.Input(shape=(5,)) 378 model = keras.Model([input_1, input_2], [input_1, input_2]) 379 saved_model_dir = self._save_model_dir() 380 381 model.save(saved_model_dir, save_format='tf') 382 loaded = keras_load.load(saved_model_dir) 383 input_arr_1 = np.random.random((1, 3)).astype('float32') 384 input_arr_2 = np.random.random((1, 5)).astype('float32') 385 386 outputs = loaded([input_arr_1, input_arr_2]) 387 self.assertAllEqual(input_arr_1, outputs[0]) 388 self.assertAllEqual(input_arr_2, outputs[1]) 389 390 def test_revived_sequential(self): 391 model = keras.models.Sequential() 392 model.add(keras.layers.Dense(5, input_shape=(3,), 393 kernel_regularizer=regularizers.get('l2'))) 394 model.add(keras.layers.Dense(2, kernel_regularizer=regularizers.get('l2'))) 395 396 self.evaluate(variables.variables_initializer(model.variables)) 397 398 saved_model_dir = self._save_model_dir() 399 model.save(saved_model_dir, save_format='tf') 400 loaded = keras_load.load(saved_model_dir) 401 402 self.assertLen(loaded.layers, 2) 403 self.assertLen(loaded.losses, 2) 404 405 loaded.pop() 406 407 self.assertLen(loaded.layers, 1) 408 self.assertLen(loaded.losses, 1) 409 410 loaded.add(keras.layers.Dense(2, kernel_regularizer=regularizers.get('l2'))) 411 412 self.assertLen(loaded.layers, 2) 413 self.assertLen(loaded.losses, 2) 414 415 def testBatchNormUpdates(self): 416 model = keras.models.Sequential( 417 keras.layers.BatchNormalization(input_shape=(1,))) 418 self.evaluate(variables.variables_initializer(model.variables)) 419 saved_model_dir = self._save_model_dir() 420 421 # TODO(kathywu): Re-enable this check after removing the tf.saved_model.save 422 # metadata warning. 
423 # with self.captureWritesToStream(sys.stderr) as captured_logs: 424 model.save(saved_model_dir, save_format='tf') 425 loaded = keras_load.load(saved_model_dir) 426 427 # Assert that saving does not log deprecation warnings 428 # (even if it needs to set learning phase for compat reasons) 429 # if context.executing_eagerly(): 430 # self.assertNotIn('deprecated', captured_logs.contents()) 431 432 input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32) 433 input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32) 434 self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0]) 435 436 self.evaluate(loaded(input_arr, training=True)) 437 if not context.executing_eagerly(): 438 self.evaluate(loaded.get_updates_for(input_arr)) 439 self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.12]) 440 441 self.evaluate(loaded(input_arr2, training=False)) 442 if not context.executing_eagerly(): 443 self.evaluate(loaded.get_updates_for(input_arr2)) 444 self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.12]) 445 446 def testDisablingBatchNormTrainableBeforeSaving(self): 447 # We disable trainable on the batchnorm layers before saving 448 model = keras.models.Sequential( 449 keras.layers.BatchNormalization(input_shape=(1,))) 450 model.trainable = False 451 self.evaluate(variables.variables_initializer(model.variables)) 452 saved_model_dir = self._save_model_dir() 453 model.save(saved_model_dir, save_format='tf') 454 loaded = keras_load.load(saved_model_dir) 455 self.evaluate(variables.variables_initializer(loaded.variables)) 456 input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32) 457 input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32) 458 self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0]) 459 460 # Trainable should still be disabled after loading 461 self.evaluate(loaded(input_arr, training=True)) 462 if not context.executing_eagerly(): 463 
self.evaluate(loaded.get_updates_for(input_arr)) 464 self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.0]) 465 466 # Re-enabling trainable on the loaded model should cause the batchnorm 467 # layer to start training again. 468 # Note: this only works in v2. 469 if context.executing_eagerly(): 470 loaded.trainable = True 471 self.evaluate(loaded(input_arr, training=True)) 472 self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.12]) 473 474 self.evaluate(loaded(input_arr2, training=False)) 475 self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.12]) 476 477 def testSaveWithSignatures(self): 478 model = keras.models.Sequential() 479 model.add(keras.layers.Dense(5, input_shape=(3,), 480 kernel_regularizer=regularizers.get('l2'))) 481 model.add(keras.layers.Dropout(0.5)) 482 model.add(keras.layers.Dense(4, kernel_regularizer=regularizers.get('l2'))) 483 484 input_arr = np.random.random((2, 3)) 485 target_arr = np.random.random((2, 4)) 486 487 model.compile( 488 loss='mse', 489 optimizer='rmsprop') 490 model.train_on_batch(input_arr, target_arr) 491 492 @def_function.function(input_signature=[tensor_spec.TensorSpec((None, 3))]) 493 def predict(inputs): 494 return {'predictions': model(inputs)} 495 496 feature_configs = { 497 'inputs': parsing_ops.FixedLenFeature( 498 shape=[2, 3], dtype=dtypes.float32)} 499 500 @def_function.function( 501 input_signature=[tensor_spec.TensorSpec([None], dtypes.string)]) 502 def parse_and_predict(examples): 503 features = parsing_ops.parse_single_example(examples[0], feature_configs) 504 return {'predictions': model(features['inputs']), 505 'layer_1_outputs': model.layers[0](features['inputs'])} 506 507 saved_model_dir = self._save_model_dir() 508 model.save(saved_model_dir, save_format='tf', signatures={ 509 'predict': predict, 510 'parse_and_predict': parse_and_predict}) 511 model.save('/tmp/saved', save_format='tf', signatures={ 512 'predict': predict, 513 'parse_and_predict': 
parse_and_predict}) 514 515 loaded = keras_load.load(saved_model_dir) 516 517 self.assertAllClose( 518 model.predict(input_arr), 519 loaded.signatures['predict'](ops.convert_to_tensor_v2_with_dispatch( 520 input_arr.astype('float32')))['predictions']) 521 522 feature = { 523 'inputs': feature_pb2.Feature( 524 float_list=feature_pb2.FloatList( 525 value=input_arr.astype('float32').flatten()))} 526 example = example_pb2.Example( 527 features=feature_pb2.Features(feature=feature)) 528 outputs = loaded.signatures['parse_and_predict']( 529 ops.convert_to_tensor_v2_with_dispatch([example.SerializeToString()])) 530 self.assertAllClose(model.predict(input_arr), outputs['predictions']) 531 self.assertAllClose(model.layers[0](input_arr), outputs['layer_1_outputs']) 532 533 def testTrainingDefaults(self): 534 def assert_training_default(fn, default_value): 535 arg_spec = tf_inspect.getfullargspec(fn) 536 index = len(arg_spec.args) - arg_spec.args.index('training') 537 self.assertEqual(arg_spec.defaults[-index], default_value) 538 539 class LayerWithTrainingRequiredArg(keras.engine.base_layer.Layer): 540 541 def call(self, inputs, training): 542 return control_flow_util.smart_cond(training, lambda: inputs * 0, 543 lambda: array_ops.identity(inputs)) 544 545 class LayerWithTrainingDefaultTrue(keras.engine.base_layer.Layer): 546 547 def call(self, inputs, training=True): 548 return control_flow_util.smart_cond(training, lambda: inputs * 0, 549 lambda: array_ops.identity(inputs)) 550 551 class Model(keras.models.Model): 552 553 def __init__(self): 554 super(Model, self).__init__() 555 self.layer_with_training_default_none = LayerWithLearningPhase() 556 self.layer_with_training_default_true = LayerWithTrainingDefaultTrue() 557 self.layer_with_required_training_arg = LayerWithTrainingRequiredArg() 558 559 def call(self, inputs): 560 x = self.layer_with_training_default_none(inputs) 561 x += self.layer_with_training_default_true(inputs) 562 x += 
self.layer_with_required_training_arg(inputs, False) 563 return x 564 565 model = Model() 566 # Build and set model inputs 567 model.predict(np.ones([1, 3]).astype('float32')) 568 saved_model_dir = self._save_model_dir() 569 model.save(saved_model_dir, save_format='tf') 570 load = tf_load.load(saved_model_dir) 571 572 # Ensure that the Keras loader is able to load and build the model. 573 _ = keras_load.load(saved_model_dir) 574 575 assert_training_default(load.__call__, False) 576 assert_training_default( 577 load.layer_with_training_default_none.__call__, False) 578 assert_training_default( 579 load.layer_with_training_default_true.__call__, True) 580 581 # Assert that there are no defaults for layer with required training arg 582 arg_spec = tf_inspect.getfullargspec( 583 load.layer_with_required_training_arg.__call__) 584 self.assertFalse(arg_spec.defaults) # defaults is None or empty 585 586 def testTraceModelWithKwarg(self): 587 class Model(keras.models.Model): 588 589 def call(self, inputs, keyword=None): 590 return array_ops.identity(inputs) 591 592 model = Model() 593 prediction = model.predict(np.ones([1, 3]).astype('float32')) 594 saved_model_dir = self._save_model_dir() 595 model.save(saved_model_dir, save_format='tf') 596 597 loaded = keras_load.load(saved_model_dir) 598 self.assertAllClose(prediction, 599 loaded.predict(np.ones([1, 3]).astype('float32'))) 600 601 def testFeatureColumns(self): 602 # TODO(b/120099662): Error with table initialization with Keras models in 603 # graph mode. 
604 if context.executing_eagerly(): 605 numeric = fc.numeric_column('a') 606 bucketized = fc.bucketized_column(numeric, boundaries=[5, 10, 15]) 607 cat_vocab = fc.categorical_column_with_vocabulary_list( 608 'b', ['1', '2', '3']) 609 one_hot = fc.indicator_column(cat_vocab) 610 embedding = fc.embedding_column(cat_vocab, dimension=8) 611 feature_layer = DenseFeatures([bucketized, one_hot, embedding]) 612 model = keras.models.Sequential(feature_layer) 613 614 features = {'a': np.array([13, 15]), 'b': np.array(['1', '2'])} 615 predictions = model.predict(features) 616 617 saved_model_dir = self._save_model_dir() 618 model.save(saved_model_dir, save_format='tf') 619 loaded = keras_load.load(saved_model_dir) 620 loaded_predictions = loaded.predict(features) 621 self.assertAllClose(predictions, loaded_predictions) 622 623 def testSaveTensorKwarg(self): 624 625 class LayerWithTensorKwarg(keras.layers.Layer): 626 627 def call(self, inputs, tensor=None): 628 if tensor is not None: 629 return inputs * math_ops.cast(tensor, dtypes.float32) 630 else: 631 return inputs 632 633 t = self.evaluate(array_ops.sequence_mask(1)) 634 inputs = keras.layers.Input(shape=(3)) 635 model = keras.models.Model(inputs, LayerWithTensorKwarg()(inputs, t)) 636 637 input_arr = np.random.random((1, 3)) 638 predictions = model.predict(input_arr) 639 640 saved_model_dir = self._save_model_dir() 641 model.save(saved_model_dir, save_format='tf') 642 loaded = keras_load.load(saved_model_dir) 643 loaded_predictions = loaded.predict(input_arr) 644 self.assertAllClose(predictions, loaded_predictions) 645 646 def testModelWithTfFunctionCall(self): 647 class Subclass(keras.models.Model): 648 649 @def_function.function 650 def call(self, inputs, training=False): 651 return inputs * math_ops.cast(training, dtypes.float32) 652 653 model = Subclass() 654 model.predict(array_ops.ones((1, 2)), steps=1) 655 saved_model_dir = self._save_model_dir() 656 model.save(saved_model_dir, save_format='tf') 657 loaded = 
keras_load.load(saved_model_dir) 658 self.assertAllEqual( 659 [[1, 5]], 660 self.evaluate(loaded(array_ops.constant([[1, 5.]]), training=True))) 661 self.assertAllEqual( 662 [[0, 0]], 663 self.evaluate(loaded(array_ops.constant([[1, 5.]]), training=False))) 664 665 def testReviveFunctionalModel(self): 666 667 class CustomAdd(keras.layers.Add): 668 669 def build(self, input_shape): 670 self.w = self.add_weight('w', shape=[]) 671 super(CustomAdd, self).build(input_shape) 672 673 def call(self, inputs): 674 outputs = super(CustomAdd, self).call(inputs) 675 return outputs * self.w 676 677 input1 = keras.layers.Input(shape=(None, 3), name='input_1') 678 input2 = keras.layers.Input(shape=(None, 3), name='input_2') 679 680 d = keras.layers.Dense(4, name='dense_with_two_inbound_nodes') 681 output1 = d(input1) 682 output2 = d(input2) 683 684 # Use a custom layer in this model to ensure that layers aren't being 685 # recreated directly from the config. 686 outputs = CustomAdd(name='custom')([output1, output2]) 687 model = keras.models.Model([input1, input2], outputs, name='save_model') 688 689 self.evaluate(variables.variables_initializer(model.variables)) 690 saved_model_dir = self._save_model_dir() 691 model.save(saved_model_dir, save_format='tf') 692 693 loaded = keras_load.load(saved_model_dir) 694 self.assertEqual('save_model', loaded.name) 695 self.assertLen( 696 loaded.get_layer('dense_with_two_inbound_nodes')._inbound_nodes, 2) 697 self.assertEqual('CustomAdd', type(loaded.get_layer('custom')).__name__) 698 self.assertLen(loaded.get_layer('custom').weights, 1) 699 700 def _testAddUpdate(self, scope): 701 with scope: 702 layer_with_update = LayerWithUpdate() 703 model = testing_utils.get_model_from_layers([layer_with_update], 704 input_shape=(3,)) 705 706 x = np.ones((10, 3)) 707 if testing_utils.get_model_type() == 'subclass': 708 model.predict(x, batch_size=10) 709 self.evaluate(variables.variables_initializer(model.variables)) 710 saved_model_dir = 
self._save_model_dir() 711 model.save(saved_model_dir, save_format='tf') 712 713 loaded = keras_load.load(saved_model_dir) 714 loaded_layer = loaded.layers[-1] 715 self.evaluate(variables.variables_initializer(loaded.variables)) 716 self.assertEqual(self.evaluate(loaded_layer.v), 0.) 717 718 loaded.compile('sgd', 'mse') 719 loaded.fit(x, x, batch_size=10) 720 self.assertEqual(self.evaluate(loaded_layer.v), 1.) 721 722 @keras_parameterized.run_with_all_model_types 723 def testSaveLayerWithUpdates(self): 724 @tf_contextlib.contextmanager 725 def nullcontextmanager(): 726 yield 727 self._testAddUpdate(nullcontextmanager()) 728 729 @keras_parameterized.run_with_all_model_types 730 def testSaveInStrategyScope(self): 731 self._testAddUpdate(mirrored_strategy.MirroredStrategy().scope()) 732 733 def testSaveTimeDistributedLayer(self): 734 model = keras.Sequential([ 735 keras.layers.TimeDistributed( 736 keras.layers.Dense(1, kernel_regularizer=regularizers.get('l2')), 737 input_shape=(None, 1))]) 738 predictions = model.predict_on_batch(array_ops.ones((3, 2, 1))) 739 740 saved_model_dir = self._save_model_dir() 741 model.save(saved_model_dir, save_format='tf') 742 743 loaded = keras_load.load(saved_model_dir) 744 self.assertAllClose(loaded.predict_on_batch(array_ops.ones((3, 2, 1))), 745 predictions) 746 747 @parameterized.named_parameters([ 748 ('with_unrolling', True), 749 ('no_unrolling', False) 750 ]) 751 def testSaveStatefulRNN(self, unroll): 752 batch = 12 753 timesteps = 10 754 input_dim = 8 755 input_arr = np.ones((batch, timesteps, input_dim)).astype('float32') 756 757 cells = [keras.layers.LSTMCell(32), keras.layers.LSTMCell(64)] 758 if unroll: 759 x = keras.Input(batch_shape=(batch, timesteps, input_dim)) 760 else: 761 x = keras.Input(batch_shape=(batch, None, input_dim)) 762 layer = keras.layers.RNN(cells, stateful=True, unroll=unroll) 763 y = layer(x) 764 765 model = keras.Model(x, y) 766 model.compile('rmsprop', 'mse', 767 
run_eagerly=testing_utils.should_run_eagerly())
    model.train_on_batch(
        np.zeros((batch, timesteps, input_dim)).astype('float32'),
        np.zeros((batch, 64)).astype('float32'))

    saved_model_dir = self._save_model_dir()
    tf_save.save(model, saved_model_dir)

    loaded = keras_load.load(saved_model_dir)
    loaded_layer = loaded.layers[1]

    if not context.executing_eagerly():
      keras.backend.get_session()  # force variable initialization

    self.assertAllClose(layer.states, loaded_layer.states)
    self.assertAllClose(model(input_arr), loaded(input_arr))

  def testSaveStatelessConvLSTM2D(self):
    """Round-trips a channels_first ConvLSTM2D model and compares outputs.

    The prediction of the original model on an all-ones input is computed
    before saving; the model object is then deleted, reloaded from the
    SavedModel directory, and the reloaded model's prediction on the same
    input must match.
    """
    data_format = 'channels_first'
    batch, timesteps, channels, rows, cols = 12, 10, 8, 4, 4
    input_arr = np.ones(
        (batch, timesteps, channels, rows, cols)).astype('float32')
    layer = keras.layers.ConvLSTM2D(
        filters=16, kernel_size=(1, 1), data_format=data_format)
    # Fixed batch_shape: the batch dimension is part of the model's input.
    x = keras.Input(batch_shape=(batch, timesteps, channels, rows, cols))
    y = layer(x)
    model = keras.Model(x, y)

    # Capture the reference output before the original model is deleted.
    predict_1 = model(input_arr)
    saved_model_dir = self._save_model_dir()
    tf_save.save(model, saved_model_dir)
    del model

    loaded = keras_load.load(saved_model_dir)
    predict_2 = loaded(input_arr)
    self.assertAllClose(predict_1, predict_2)

  def testSaveWithRaggedInputs(self):
    """Round-trips a model whose custom layer consumes a dict of ragged inputs.

    Checks that `predict` on the reloaded model matches the original model on
    the same dict of RaggedTensors.
    """

    class EmbeddingMerger(keras.layers.Layer):
      """Embeds each feature, averages over axis 1, and sums the results."""

      def __init__(self, list_features, **kwargs):
        """Creates one Embedding(10, 3) per name in `list_features`."""
        super().__init__(**kwargs)
        # NOTE(review): presumably opts this layer into accepting RaggedTensor
        # inputs -- confirm against the Keras base Layer implementation.
        self._supports_ragged_inputs = True
        self.embeddings = {
            feature: keras.layers.Embedding(10, 3) for feature in list_features}
        # Mean over axis 1 of each embedded feature.
        self.mean = keras.layers.Lambda(
            math_ops.reduce_mean, arguments=dict(axis=1))

      def call(self, inputs):
        """Embeds `inputs[col]` for each key, reduces, and adds them up."""
        tensors = [self.embeddings[col](inputs[col]) for col in inputs]
        tensors = [self.mean(inp) for inp in tensors]
        return keras.layers.Add()(tensors)

    list_features = ['feature_1', 'feature_2']
    # Ragged rows of different lengths; all values are below the Embedding
    # input_dim of 10.
    feature_1 = ragged_factory_ops.constant([[0.], [1, 3]])
    feature_2 = ragged_factory_ops.constant([[1., 2], [4]])
    f = {'feature_1': feature_1,
         'feature_2': feature_2}
    f_inputs = {
        'feature_1': keras.Input(shape=(None,), name='feature_1', ragged=True),
        'feature_2': keras.Input(shape=(None,), name='feature_2', ragged=True)}

    out = EmbeddingMerger(list_features)(f_inputs)
    model = keras.Model(f_inputs, out)
    # Explicit initialization; presumably needed when this runs in graph mode.
    self.evaluate(variables.variables_initializer(model.variables))
    saved_model_dir = self._save_model_dir()
    tf_save.save(model, saved_model_dir)

    loaded = keras_load.load(saved_model_dir)
    self.evaluate(variables.variables_initializer(loaded.variables))
    self.assertAllClose(model.predict(f), loaded.predict(f))


class TestSavedModelFormat(test.TestCase):
  """SavedModel save/load tests for layers whose call functions are not saved.

  Covers a layer whose call function fails to serialize, and saving with
  `save_traces=False` so call functions are never traced.
  """

  def _save_model_dir(self, dirname='saved_model'):
    """Returns `<temp_dir>/<dirname>` and schedules the temp dir for removal.

    Args:
      dirname: Name of the subdirectory under the test temp directory.

    Returns:
      The joined path. The directory itself is not created here.
    """
    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    return os.path.join(temp_dir, dirname)

  def test_load_with_partially_failed_serialization(self):
    """A model still loads when one sublayer's call function is not saved.

    The reloaded top-level model must run, while calling the reloaded
    sublayer must raise a ValueError mentioning that its call function was
    not serialized.
    """

    class BadCustomLayer(keras.layers.Layer):
      """Overrides `__call__` directly instead of implementing `call`."""

      def __call__(self, inputs):
        return inputs

    class Model(keras.models.Model):
      """Model whose `call` is a tf.function with a fixed input signature."""

      def __init__(self):
        super(Model, self).__init__()
        self.layer = BadCustomLayer()

      @def_function.function(
          input_signature=[tensor_spec.TensorSpec([None, 1])])
      def call(self, inputs):
        return self.layer(inputs)

    model = Model()
    inp = constant_op.constant([[1.0]])
    model(inp)
    saved_model_dir = self._save_model_dir()
    tf_save.save(model, saved_model_dir)

    loaded = keras_load.load(saved_model_dir)
    self.assertAllEqual([[1.0]], self.evaluate(loaded(inp)))
    with self.assertRaisesRegex(ValueError, 'call function was not serialized'):
      loaded.layer(inp)

  def test_save_without_tracing(self):
    """`save_traces=False` saves a layer without ever tracing its `call`.

    With default settings, saving traces `call`, which raises. With
    `save_traces=False` the save succeeds. The reloaded layer can only be
    called when its class is supplied via a custom object scope; without it,
    loading through Keras yields a layer that refuses to be called.
    """

    class DoNotTrace(keras.layers.Layer):
      """Layer whose `call` always raises, so any trace of it fails."""

      def __init__(self):
        super(DoNotTrace, self).__init__()
        self.input_spec = keras.layers.InputSpec(shape=[None])
        self.built = True

      def call(self, inputs):
        raise ValueError('I said do not trace')

      def get_config(self):
        return {}

      @property
      def _use_input_spec_as_call_signature(self):
        return True

    root = keras.models.Sequential()
    root.add(keras.layers.Input(shape=(3,)))
    root.attached_layer = DoNotTrace()

    saved_model_dir = self._save_model_dir()

    # With the default settings, the call function is traced.
    with self.assertRaisesRegex(ValueError, 'do not trace'):
      root.save(saved_model_dir, save_format='tf')

    # When saving the config only, the layer call function should not be
    # traced.
    root.save(saved_model_dir, save_format='tf', save_traces=False)
    loaded = tf_load.load(saved_model_dir)
    self.assertTrue(hasattr(loaded, 'attached_layer'))

    # This should raise an error when loaded without the custom object
    loaded = keras_load.load(saved_model_dir)
    with self.assertRaisesRegex(ValueError, 'Cannot call custom layer'):
      loaded.attached_layer(constant_op.constant([1.]))

    # Try loading with the custom objects
    with generic_utils.CustomObjectScope({'DoNotTrace': DoNotTrace}):
      loaded = keras_load.load(saved_model_dir)
    # The revived layer now runs the real `call`, which raises on purpose.
    with self.assertRaisesRegex(ValueError, 'I said do not trace'):
      loaded.attached_layer(constant_op.constant([1.]))


class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
  """Tests tracing behavior of `keras_save.LayerCallCollection`."""

  def test_functions_have_same_trace(self):
    """Calling one collected function also traces its sibling functions.

    Only `fn` is called, with two input shapes. `fn2`, added to the same
    collection, must end up with concrete functions for both shapes.
    """

    class Layer(keras.engine.base_layer.Layer):

      def call(self, inputs):
        return inputs

      def call2(self, inputs):
        return inputs * 2

    layer = Layer()
    call_collection = keras_save.LayerCallCollection(layer)
    fn = call_collection.add_function(layer.call, 'call')
    fn2 = call_collection.add_function(layer.call2, 'call2')

    fn(np.ones((2, 3)))
    fn(np.ones((4, 5)))

    self.assertLen(fn._list_all_concrete_functions_for_serialization(), 2)
    self.assertLen(fn2._list_all_concrete_functions_for_serialization(), 2)

    # Check that the shapes are correct
    self.assertEqual(
        {(2, 3), (4, 5)},
        set(tuple(c.structured_input_signature[0][0].shape.as_list())
            for c in fn2._list_all_concrete_functions_for_serialization()))

  def test_training_arg_replacement(self):
    """Each new input shape adds two concrete functions to the trace.

    Exercised for a `training` keyword, a `**kwargs` call signature, and a
    layer that forwards to a `**kwargs` child layer.
    """

    def assert_num_traces(layer_cls, training_keyword):
      """Calls a collected `call` with new shapes and checks trace counts.

      Args:
        layer_cls: Layer class to instantiate (no constructor arguments).
        training_keyword: Whether `call` declares `training` as an explicit
          argument; if so, positional and omitted `training` are also tried.
      """
      layer = layer_cls()
      call_collection = keras_save.LayerCallCollection(layer)
      fn = call_collection.add_function(layer.call, 'call')

      # Counts grow by 2 per new shape -- presumably one concrete function
      # per `training` value; confirm against LayerCallCollection.
      fn(np.ones((2, 3)), training=True)
      self.assertLen(fn._list_all_concrete_functions_for_serialization(), 2)

      fn(np.ones((2, 4)), training=False)
      self.assertLen(fn._list_all_concrete_functions_for_serialization(), 4)

      if training_keyword:
        fn(np.ones((2, 5)), True)
        self.assertLen(fn._list_all_concrete_functions_for_serialization(), 6)
        fn(np.ones((2, 6)))
        self.assertLen(fn._list_all_concrete_functions_for_serialization(), 8)

    class LayerWithTrainingKeyword(keras.engine.base_layer.Layer):

      def call(self, inputs, training=False):
        return inputs * training

    assert_num_traces(LayerWithTrainingKeyword, training_keyword=True)

    class LayerWithKwargs(keras.engine.base_layer.Layer):

      def call(self, inputs, **kwargs):
        return inputs * kwargs['training']

    assert_num_traces(LayerWithKwargs, training_keyword=False)

    class LayerWithChildLayer(keras.engine.base_layer.Layer):

      def __init__(self):
        # NOTE(review): `child` is assigned before the base Layer.__init__
        # runs; confirm this ordering is intentional.
        self.child = LayerWithKwargs()
        super(LayerWithChildLayer, self).__init__()

      def call(self, inputs):
        # `training` is not forwarded explicitly to the child here.
        return self.child(inputs)

    assert_num_traces(LayerWithChildLayer, training_keyword=False)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_maintains_losses(self):
    """Tracing through a LayerCallCollection must not alter `layer.losses`."""
    layer = LayerWithLoss()
    layer(np.ones((2, 3)))
    # Snapshot (shallow copy) of the losses recorded by the eager call.
    previous_losses = layer.losses[:]

    call_collection = keras_save.LayerCallCollection(layer)
    fn = call_collection.add_function(layer.call, 'call')
    fn(np.ones((2, 3)))

    self.assertAllEqual(previous_losses, layer.losses)


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MetricTest(test.TestCase, parameterized.TestCase):
  """SavedModel round-trip tests for built-in and custom Keras metrics.

  Every test runs in both graph and eager modes.
  """

  def _save_model_dir(self, dirname='saved_model'):
    """Returns `<temp_dir>/<dirname>` and schedules the temp dir for removal."""
    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    return os.path.join(temp_dir, dirname)

  def generate_inputs(self, num_tensor_args, shape=(1, 5)):
    """Returns `num_tensor_args` float32 arrays of `shape`, uniform in [0, 1)."""
    return [
        np.random.uniform(0, 1, shape).astype('float32')
        for _ in range(num_tensor_args)
    ]

  def _test_metric_save_and_load(self,
                                 metric,
                                 save_dir,
                                 num_tensor_args,
                                 shape=(1, 5),
                                 test_sample_weight=True):
    """Saves and reloads `metric`, then checks both objects stay in sync.

    Compares name and dtype, `__call__` outputs, variables, and separate
    `update_state`/`result` calls on the same random inputs. Optionally
    repeats the `__call__` comparison with a scalar sample weight appended.

    Args:
      metric: The metric to save. Its variables must already be initialized.
      save_dir: Directory passed to `tf_save.save`.
      num_tensor_args: Number of positional tensor inputs the metric takes.
      shape: Shape of each generated input array.
      test_sample_weight: Whether to also test a trailing sample-weight arg.

    Returns:
      The metric reloaded with `keras_load.load`.
    """
    with self.cached_session():
      tf_save.save(metric, save_dir)
      loaded = keras_load.load(save_dir)
      self.evaluate([v.initializer for v in loaded.variables])
      self.assertEqual(metric.name, loaded.name)
      self.assertEqual(metric.dtype, loaded.dtype)

      inputs = self.generate_inputs(num_tensor_args, shape)
      actual = self.evaluate(metric(*inputs))
      self.assertAllClose(actual, loaded(*inputs))
      self.assertAllClose(metric.variables, loaded.variables)

      # Test with separate calls to update state and result.
      inputs = self.generate_inputs(num_tensor_args, shape)
      self.evaluate(metric.update_state(*inputs))
      self.evaluate(loaded.update_state(*inputs))
      actual = self.evaluate(metric.result())
      self.assertAllClose(actual, loaded.result())

      if test_sample_weight:
        # Test with sample weights input.
        inputs = self.generate_inputs(num_tensor_args, shape)
        # A single scalar (shape []) weight appended as the last argument.
        sample_weight = self.generate_inputs(1, [])[0]
        inputs.append(sample_weight)

        actual = self.evaluate(metric(*inputs))
        self.assertAllClose(actual, loaded(*inputs))
      return loaded

  @parameterized.named_parameters([
      ('mean', keras.metrics.Mean, 1, (1, 5)),
      ('false_positives', keras.metrics.FalsePositives, 2, (1, 5)),
      ('precision_at_top_k', keras.metrics.Precision, 2, (2, 3, 4), {
          'top_k': 2,
          'class_id': 1
      }),
      ('precision_at_recall', keras.metrics.PrecisionAtRecall, 2, (1, 5), {
          'recall': .8
      }), ('auc', keras.metrics.AUC, 2, (1, 5), {
          'multi_label': True
      }), ('cosine_similarity', keras.metrics.CosineSimilarity, 2, (2, 3, 1))
  ])
  def test_metric(self, metric_cls, num_tensor_args, shape, init_kwargs=None):
    """Built-in metrics round-trip and reload as the same Python class."""
    init_kwargs = init_kwargs or {}
    metric = metric_cls(**init_kwargs)
    # Call once before initializing, so any variables created on first call
    # exist when the initializers are collected.
    metric(*self.generate_inputs(num_tensor_args, shape))
    self.evaluate([v.initializer for v in metric.variables])
    loaded = self._test_metric_save_and_load(metric, self._save_model_dir(),
                                             num_tensor_args, shape)
    self.assertEqual(type(loaded), type(metric))

  @parameterized.named_parameters([
      ('mean', keras.metrics.Mean, 1, False),
      ('auc', keras.metrics.AUC, 2, False),
      ('mean_tensor', keras.metrics.MeanTensor, 1, True)])
  def test_custom_metric(self, base_cls, num_tensor_args, requires_build):
    """Unregistered custom metrics need a CustomObjectScope to reload.

    Loading without the scope must fail. With the scope, the metric must
    round-trip, and the reloaded metric must survive a second save/load.
    """

    class CustomMetric(base_cls):

      def update_state(self, *args):  # pylint: disable=useless-super-delegation
        # Sometimes built-in metrics return an op in update_state. Custom
        # metrics don't support returning ops, so wrap the update_state method
        # while returning nothing.
        super(CustomMetric, self).update_state(*args)

    with self.cached_session():
      metric = CustomMetric()
      save_dir = self._save_model_dir('first_save')

      if requires_build:
        # Presumably because this base metric creates its variables lazily
        # on the first call -- confirm for MeanTensor.
        metric(*self.generate_inputs(num_tensor_args))  # pylint: disable=not-callable

      self.evaluate([v.initializer for v in metric.variables])

      with self.assertRaisesRegex(ValueError,
                                  'Unable to restore custom object'):
        self._test_metric_save_and_load(metric, save_dir, num_tensor_args)
      with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}):
        loaded = self._test_metric_save_and_load(
            metric,
            save_dir,
            num_tensor_args,
            test_sample_weight=False)

        # Re-save the already-revived metric; still inside the scope so the
        # second load can also resolve CustomMetric.
        self._test_metric_save_and_load(
            loaded,
            self._save_model_dir('second_save'),
            num_tensor_args,
            test_sample_weight=False)

  def test_registered_custom_metric(self):
    """A metric registered via register_keras_serializable needs no scope.

    Unlike `test_custom_metric`, no CustomObjectScope is used for either the
    first or the second save/load round trip.
    """

    @generic_utils.register_keras_serializable('Testing')
    class CustomMeanMetric(keras.metrics.Mean):

      def update_state(self, *args):  # pylint: disable=useless-super-delegation
        # Sometimes built-in metrics return an op in update_state. Custom
        # metrics don't support returning ops, so wrap the update_state method
        # while returning nothing.
        super(CustomMeanMetric, self).update_state(*args)

    with self.cached_session():
      metric = CustomMeanMetric()
      save_dir = self._save_model_dir('first_save')
      self.evaluate([v.initializer for v in metric.variables])
      loaded = self._test_metric_save_and_load(
          metric,
          save_dir,
          num_tensor_args=1,
          test_sample_weight=False)

      self._test_metric_save_and_load(
          loaded,
          self._save_model_dir('second_save'),
          num_tensor_args=1,
          test_sample_weight=False)

  def test_custom_metric_wrapped_call(self):
    """A custom metric whose update_state is a tf.function round-trips.

    The wrapped update_state has a float32 input signature of unknown shape
    and negates its input before delegating to `Mean.update_state`.
    """

    class NegativeMean(keras.metrics.Mean):

      @def_function.function(
          input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
      def update_state(self, value):
        super(NegativeMean, self).update_state(-value)

    metric = NegativeMean()
    self.evaluate([v.initializer for v in metric.variables])
    with generic_utils.CustomObjectScope({'NegativeMean': NegativeMean}):
      self._test_metric_save_and_load(
          metric, self._save_model_dir(), 1, test_sample_weight=False)

  @keras_parameterized.run_with_all_model_types
  def test_custom_metric_model(self):
    """A model compiled with an unregistered custom metric.

    Loading with compilation must fail with an error mentioning
    `custom_objects`. Loading with `compile=False` must succeed.
    """

    class CustomMetric(keras.metrics.MeanSquaredError):
      pass

    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    model.compile(loss='mse', optimizer='rmsprop', metrics=[CustomMetric()])

    saved_model_dir = self._save_model_dir()
    tf_save.save(model, saved_model_dir)
    with self.assertRaisesRegex(ValueError, 'custom_objects'):
      keras_load.load(saved_model_dir)

    keras_load.load(saved_model_dir, compile=False)


if __name__ == '__main__':
  test.main()