1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Tests for saving and loading Keras models and layers from SavedModel.
17
18These should ensure that all layer properties are correctly assigned after
19loading from the SavedModel.
20
21Tests that focus on the model structure should go in revive_test.py
22"""
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import os
28import shutil
29
30from absl.testing import parameterized
31import numpy as np
32
33from tensorflow.core.example import example_pb2
34from tensorflow.core.example import feature_pb2
35from tensorflow.python import keras
36from tensorflow.python import tf2
37from tensorflow.python.data.ops import dataset_ops
38from tensorflow.python.distribute import mirrored_strategy
39from tensorflow.python.eager import context
40from tensorflow.python.eager import def_function
41from tensorflow.python.feature_column import feature_column_v2 as fc
42from tensorflow.python.framework import constant_op
43from tensorflow.python.framework import dtypes
44from tensorflow.python.framework import ops
45from tensorflow.python.framework import tensor_spec
46from tensorflow.python.keras import combinations
47from tensorflow.python.keras import keras_parameterized
48from tensorflow.python.keras import regularizers
49from tensorflow.python.keras import testing_utils
50from tensorflow.python.keras.feature_column.dense_features import DenseFeatures
51from tensorflow.python.keras.saving.saved_model import load as keras_load
52from tensorflow.python.keras.saving.saved_model import save_impl as keras_save
53from tensorflow.python.keras.utils import control_flow_util
54from tensorflow.python.keras.utils import generic_utils
55from tensorflow.python.keras.utils import tf_contextlib
56from tensorflow.python.keras.utils import tf_inspect
57from tensorflow.python.ops import array_ops
58from tensorflow.python.ops import init_ops
59from tensorflow.python.ops import math_ops
60from tensorflow.python.ops import parsing_ops
61from tensorflow.python.ops import variables
62from tensorflow.python.ops.ragged import ragged_factory_ops
63from tensorflow.python.platform import test
64from tensorflow.python.saved_model import load as tf_load
65from tensorflow.python.saved_model import save as tf_save
66
67
class LayerWithLearningPhase(keras.engine.base_layer.Layer):
  """Layer whose output depends on the training phase.

  Returns zeros when training and the identity of the input otherwise, so
  tests can observe whether the `training` argument / global learning phase
  survives a SavedModel round trip.
  """

  def build(self, input_shape):
    # All-None spec with the input's rank; used as the saved call signature
    # (see `_use_input_spec_as_call_signature` below).
    self.input_spec = keras.layers.InputSpec(shape=[None] * len(input_shape))
    self.built = True

  def call(self, x, training=None):
    # Fall back to the global Keras learning phase when `training` is not
    # passed explicitly.
    if training is None:
      training = keras.backend.learning_phase()
    output = control_flow_util.smart_cond(training, lambda: x * 0,
                                          lambda: array_ops.identity(x))
    if not context.executing_eagerly():
      output._uses_learning_phase = True  # pylint: disable=protected-access
    return output

  def compute_output_shape(self, input_shape):
    # Elementwise op: the output shape equals the input shape.
    return input_shape

  @property
  def _use_input_spec_as_call_signature(self):
    # Serialize the call function using the spec set in build().
    return True
89
90
class LayerWithLoss(keras.layers.Layer):
  """Layer that adds an input-conditional loss in its call function."""

  def call(self, inputs):
    # `inputs=inputs` marks the loss as conditional on the inputs.
    self.add_loss(math_ops.reduce_sum(inputs), inputs=inputs)
    return inputs * 2
96
97
class LayerWithUpdate(keras.layers.Layer):
  """Layer that increments a non-trainable counter each training call."""

  def build(self, _):
    # Scalar non-trainable counter used to observe that `add_update` ops
    # survive saving and loading.
    self.v = self.add_weight(
        'v',
        shape=[],
        initializer=keras.initializers.zeros,
        trainable=False,
        dtype=dtypes.float32)

  def call(self, inputs, training=True):
    # The update only runs when training.
    if training:
      self.add_update(self.v.assign_add(1.))
    return inputs * 2.
112
113
@generic_utils.register_keras_serializable('Testing')
class GlobalLayerThatShouldFailIfNotAdded(keras.layers.Layer):
  """Globally registered layer that must be revived from its config."""
  # Loading fails unless this class is found via registration or a
  # custom object scope.
  _must_restore_from_config = True
117
118
119@keras_parameterized.run_all_keras_modes
120class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
121
122  def _save_model_dir(self, dirname='saved_model'):
123    temp_dir = self.get_temp_dir()
124    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
125    return os.path.join(temp_dir, dirname)
126
  def _test_save_and_load(self, use_dataset=False):
    """Round-trips a trained MLP through SavedModel and compares state.

    Checks that weights, outputs, and (conditional and unconditional)
    losses match between the original and loaded models.

    Args:
      use_dataset: If True, trains via `fit` on a tf.data Dataset; otherwise
        trains with `train_on_batch` on numpy arrays.
    """
    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    model.layers[-1].activity_regularizer = regularizers.get('l2')
    model.activity_regularizer = regularizers.get('l2')
    model.compile(
        loss='mse',
        optimizer='rmsprop')
    def callable_loss():
      # Unconditional loss added as a zero-argument callable.
      return math_ops.reduce_sum(model.weights[0])
    model.add_loss(callable_loss)

    x = np.random.random((1, 3))
    y = np.random.random((1, 4))

    if not tf2.enabled():
      # The layer autocast behavior only runs when autocast is enabled, so
      # in V1, the numpy inputs still need to be cast to float32.
      x = x.astype(np.float32)
      y = y.astype(np.float32)

    if use_dataset:
      dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(1)
      model.fit(dataset)
    else:
      model.train_on_batch(x, y)

    saved_model_dir = self._save_model_dir()
    tf_save.save(model, saved_model_dir)
    loaded = keras_load.load(saved_model_dir)
    self.evaluate(variables.variables_initializer(loaded.variables))
    # Weights must match exactly after the round trip.
    self.assertAllClose(self.evaluate(model.weights),
                        self.evaluate(loaded.weights))

    input_arr = constant_op.constant(
        np.random.random((1, 3)).astype(np.float32))
    self.assertAllClose(self.evaluate(model(input_arr)),
                        self.evaluate(loaded(input_arr)))
    # Validate losses. The order of conditional losses may change between the
    # model and loaded model, so sort the losses first.
    if context.executing_eagerly():
      self.assertAllClose(sorted(self.evaluate(model.losses)),
                          sorted(self.evaluate(loaded.losses)))
    else:
      # Graph mode: unconditional and input-conditional losses are fetched
      # separately via get_losses_for.
      self.assertAllClose(self.evaluate(model.get_losses_for(None)),
                          self.evaluate(loaded.get_losses_for(None)))
      self.assertAllClose(
          sorted(self.evaluate(model.get_losses_for(input_arr))),
          sorted(self.evaluate(loaded.get_losses_for(input_arr))))
175
176  @keras_parameterized.run_with_all_model_types
177  def test_model_save_and_load(self):
178    self._test_save_and_load(use_dataset=True)
179
  @keras_parameterized.run_with_all_model_types
  def test_model_save_and_load_dataset(self):
    """Tests the save/load round trip when training on a tf.data Dataset."""
    self._test_save_and_load(use_dataset=True)
183
184  def test_trainable_weights(self):
185    layer = keras.layers.Dense(4, name='custom_layer')
186    layer.build([3,])
187    layer.add_weight(
188        'extra_weight', shape=[],
189        initializer=init_ops.constant_initializer(11),
190        trainable=True)
191    layer.add_weight(
192        'extra_weight_2', shape=[],
193        initializer=init_ops.constant_initializer(12),
194        trainable=False)
195
196    saved_model_dir = self._save_model_dir()
197    self.evaluate(variables.variables_initializer(layer.variables))
198    tf_save.save(layer, saved_model_dir)
199    loaded = keras_load.load(saved_model_dir)
200    self.evaluate(variables.variables_initializer(loaded.variables))
201
202    equal_attrs = ['name', '_expects_training_arg', 'trainable']
203    for attr in equal_attrs:
204      self.assertEqual(getattr(layer, attr), getattr(loaded, attr))
205
206    all_close = ['weights', 'trainable_weights', 'non_trainable_weights']
207    for attr in all_close:
208      self.assertAllClose(self.evaluate(getattr(layer, attr)),
209                          self.evaluate(getattr(loaded, attr)))
210
  def test_maintains_losses(self):
    """Tests that the layer losses do not change before and after export."""
    model = keras.models.Sequential([LayerWithLoss()])
    model.compile(
        loss='mse',
        optimizer='rmsprop')
    input_arr = np.random.random((1, 3))
    target_arr = np.random.random((1, 3))

    # Test that symbolic losses are maintained (train_on_batch saves symbolic
    # losses.)
    model.train_on_batch(input_arr, target_arr)
    # Copy the list so later changes to model.losses don't alias it.
    previous_losses = model.losses[:]

    saved_model_dir = self._save_model_dir()
    tf_save.save(model, saved_model_dir)

    with previous_losses[0].graph.as_default():
      # If we try to compare symbolic Tensors in eager mode assertAllEqual will
      # return False even if they are the same Tensor.
      self.assertAllEqual(previous_losses, model.losses)

    if context.executing_eagerly():
      # Test that eager losses are maintained.
      model(input_arr)  # Calls model eagerly, creating eager losses.
      previous_losses = model.losses[:]
      tf_save.save(model, saved_model_dir)
      self.assertAllEqual(previous_losses, model.losses)
239
  def test_layer_with_learning_phase(self):
    """Loaded layer obeys the global learning phase and the training arg."""
    layer = LayerWithLearningPhase()
    layer.build([None, None])
    saved_model_dir = self._save_model_dir()
    tf_save.save(layer, saved_model_dir)
    loaded = keras_load.load(saved_model_dir)
    input_arr = array_ops.ones((4, 3))

    # Run the layer, and use the keras backend learning phase
    keras.backend.set_learning_phase(0)
    self.assertAllEqual(input_arr, loaded(input_arr))
    # Training mode zeroes the inputs (see LayerWithLearningPhase.call).
    keras.backend.set_learning_phase(1)
    self.assertAllEqual(array_ops.zeros((4, 3)), loaded(input_arr))

    # Run the layer while explicitly setting the training argument
    self.assertAllEqual(
        input_arr, loaded(input_arr, training=constant_op.constant(False)))
    self.assertAllEqual(
        array_ops.zeros((4, 3)),
        loaded(input_arr, training=constant_op.constant(True)))
260
  @keras_parameterized.run_with_all_model_types
  def test_standard_loader(self):
    """Keras SavedModels stay usable via the generic tf.saved_model loader."""
    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    model.activity_regularizer = regularizers.get('l2')
    def eager_loss():
      # Unconditional loss added as a zero-argument callable.
      return math_ops.reduce_sum(model.weights[0])
    model.add_loss(eager_loss)

    # Call predict to ensure that all layers are built and inputs are set.
    model.predict(np.random.random((1, 3)).astype(np.float32))
    saved_model_dir = self._save_model_dir()

    tf_save.save(model, saved_model_dir)

    loaded = tf_load.load(saved_model_dir)
    self.evaluate(variables.variables_initializer(loaded.variables))
    # Variable collections are exposed through the `keras_api` attribute on
    # the generically-loaded object.
    all_close = ['variables', 'trainable_variables',
                 'non_trainable_variables']
    for attr in all_close:
      self.assertAllClose(self.evaluate(getattr(model, attr)),
                          self.evaluate(getattr(loaded.keras_api, attr)))
    self.assertLen(loaded.regularization_losses, 1)
    expected_layers = len(model.layers)
    self.assertEqual(expected_layers, len(loaded.keras_api.layers))
    input_arr = array_ops.ones((4, 3))
    self.assertAllClose(self.evaluate(model(input_arr)),
                        self.evaluate(loaded(input_arr, training=False)))
288
  @keras_parameterized.run_with_all_model_types
  def test_compiled_model(self):
    """Loaded compiled models support predict/evaluate/fit and checkpoints."""
    # TODO(b/134519980): Issue with model.fit if the model call function uses
    # a tf.function (Graph mode only).
    if not context.executing_eagerly():
      return

    input_arr = np.random.random((1, 3))
    target_arr = np.random.random((1, 4))

    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    expected_predict = model.predict(input_arr)

    # Compile and save model.
    model.compile('rmsprop', 'mse')
    saved_model_dir = self._save_model_dir()
    tf_save.save(model, saved_model_dir)

    loaded = keras_load.load(saved_model_dir)
    actual_predict = loaded.predict(input_arr)
    self.assertAllClose(expected_predict, actual_predict)

    # Training the loaded model should reduce its loss.
    loss_before = loaded.evaluate(input_arr, target_arr)
    loaded.fit(input_arr, target_arr)
    loss_after = loaded.evaluate(input_arr, target_arr)
    self.assertLess(loss_after, loss_before)
    predict = loaded.predict(input_arr)

    ckpt_path = os.path.join(self.get_temp_dir(), 'weights')
    loaded.save_weights(ckpt_path)

    # Ensure that the checkpoint is compatible with the original model.
    model.load_weights(ckpt_path)
    self.assertAllClose(predict, model.predict(input_arr))
323
324  def test_metadata_input_spec(self):
325    class LayerWithNestedSpec(keras.layers.Layer):
326
327      def __init__(self):
328        super(LayerWithNestedSpec, self).__init__()
329        self.input_spec = {
330            'a': keras.layers.InputSpec(max_ndim=3, axes={-1: 2}),
331            'b': keras.layers.InputSpec(shape=(None, 2, 3), dtype='float16')}
332
333      @property
334      def _use_input_spec_as_call_signature(self):
335        return True
336
337    layer = LayerWithNestedSpec()
338    saved_model_dir = self._save_model_dir()
339    tf_save.save(layer, saved_model_dir)
340    loaded = keras_load.load(saved_model_dir)
341    self.assertEqual(3, loaded.input_spec['a'].max_ndim)
342    self.assertEqual({-1: 2}, loaded.input_spec['a'].axes)
343    self.assertAllEqual([None, 2, 3], loaded.input_spec['b'].shape)
344    self.assertEqual('float16', loaded.input_spec['b'].dtype)
345
346  def test_must_restore_from_config_fails_if_layer_is_not_in_scope(self):
347
348    class LayerThatShouldFailIfNotAdded(keras.layers.Layer):
349      _must_restore_from_config = True
350
351    layer = LayerThatShouldFailIfNotAdded()
352    saved_model_dir = self._save_model_dir()
353    tf_save.save(layer, saved_model_dir)
354    with self.assertRaisesRegex(RuntimeError, 'Unable to restore a layer of'):
355      _ = keras_load.load(saved_model_dir)
356
357  def test_must_restore_from_config_custom_object_scope(self):
358
359    class LayerThatShouldFailIfNotAdded(keras.layers.Layer):
360      _must_restore_from_config = True
361
362    layer = LayerThatShouldFailIfNotAdded()
363    saved_model_dir = self._save_model_dir()
364    tf_save.save(layer, saved_model_dir)
365    with generic_utils.CustomObjectScope(
366        {'LayerThatShouldFailIfNotAdded': LayerThatShouldFailIfNotAdded}):
367      _ = keras_load.load(saved_model_dir)
368
369  def test_must_restore_from_config_registration(self):
370    layer = GlobalLayerThatShouldFailIfNotAdded()
371    saved_model_dir = self._save_model_dir()
372    tf_save.save(layer, saved_model_dir)
373    _ = keras_load.load(saved_model_dir)
374
375  def test_multi_input_model(self):
376    input_1 = keras.layers.Input(shape=(3,))
377    input_2 = keras.layers.Input(shape=(5,))
378    model = keras.Model([input_1, input_2], [input_1, input_2])
379    saved_model_dir = self._save_model_dir()
380
381    model.save(saved_model_dir, save_format='tf')
382    loaded = keras_load.load(saved_model_dir)
383    input_arr_1 = np.random.random((1, 3)).astype('float32')
384    input_arr_2 = np.random.random((1, 5)).astype('float32')
385
386    outputs = loaded([input_arr_1, input_arr_2])
387    self.assertAllEqual(input_arr_1, outputs[0])
388    self.assertAllEqual(input_arr_2, outputs[1])
389
  def test_revived_sequential(self):
    """A revived Sequential can still be mutated with pop() and add()."""
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(5, input_shape=(3,),
                                 kernel_regularizer=regularizers.get('l2')))
    model.add(keras.layers.Dense(2, kernel_regularizer=regularizers.get('l2')))

    self.evaluate(variables.variables_initializer(model.variables))

    saved_model_dir = self._save_model_dir()
    model.save(saved_model_dir, save_format='tf')
    loaded = keras_load.load(saved_model_dir)

    # One regularization loss per Dense layer.
    self.assertLen(loaded.layers, 2)
    self.assertLen(loaded.losses, 2)

    # Removing a layer removes its regularization loss...
    loaded.pop()

    self.assertLen(loaded.layers, 1)
    self.assertLen(loaded.losses, 1)

    # ...and adding a new regularized layer adds one back.
    loaded.add(keras.layers.Dense(2, kernel_regularizer=regularizers.get('l2')))

    self.assertLen(loaded.layers, 2)
    self.assertLen(loaded.losses, 2)
414
  def testBatchNormUpdates(self):
    """Moving statistics update when training=True and freeze when False."""
    model = keras.models.Sequential(
        keras.layers.BatchNormalization(input_shape=(1,)))
    self.evaluate(variables.variables_initializer(model.variables))
    saved_model_dir = self._save_model_dir()

    # TODO(kathywu): Re-enable this check after removing the tf.saved_model.save
    # metadata warning.
    # with self.captureWritesToStream(sys.stderr) as captured_logs:
    model.save(saved_model_dir, save_format='tf')
    loaded = keras_load.load(saved_model_dir)

    # Assert that saving does not log deprecation warnings
    # (even if it needs to set learning phase for compat reasons)
    # if context.executing_eagerly():
    #   self.assertNotIn('deprecated', captured_logs.contents())

    input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32)
    input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32)
    self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0])

    # A training call should advance the moving mean.
    self.evaluate(loaded(input_arr, training=True))
    if not context.executing_eagerly():
      # In graph mode the update ops must be fetched explicitly.
      self.evaluate(loaded.get_updates_for(input_arr))
    self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.12])

    # An inference call must leave the moving mean unchanged.
    self.evaluate(loaded(input_arr2, training=False))
    if not context.executing_eagerly():
      self.evaluate(loaded.get_updates_for(input_arr2))
    self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.12])
445
  def testDisablingBatchNormTrainableBeforeSaving(self):
    """`trainable=False` set before saving is preserved after loading."""
    # We disable trainable on the batchnorm layers before saving
    model = keras.models.Sequential(
        keras.layers.BatchNormalization(input_shape=(1,)))
    model.trainable = False
    self.evaluate(variables.variables_initializer(model.variables))
    saved_model_dir = self._save_model_dir()
    model.save(saved_model_dir, save_format='tf')
    loaded = keras_load.load(saved_model_dir)
    self.evaluate(variables.variables_initializer(loaded.variables))
    input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32)
    input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32)
    self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0])

    # Trainable should still be disabled after loading
    self.evaluate(loaded(input_arr, training=True))
    if not context.executing_eagerly():
      # In graph mode the update ops must be fetched explicitly.
      self.evaluate(loaded.get_updates_for(input_arr))
    # The moving mean stays at its initial value: the frozen layer must not
    # update its statistics even with training=True.
    self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.0])

    # Re-enabling trainable on the loaded model should cause the batchnorm
    # layer to start training again.
    # Note: this only works in v2.
    if context.executing_eagerly():
      loaded.trainable = True
      self.evaluate(loaded(input_arr, training=True))
      self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.12])

      self.evaluate(loaded(input_arr2, training=False))
      self.assertAllClose(self.evaluate(loaded.layers[-1].moving_mean), [0.12])
476
477  def testSaveWithSignatures(self):
478    model = keras.models.Sequential()
479    model.add(keras.layers.Dense(5, input_shape=(3,),
480                                 kernel_regularizer=regularizers.get('l2')))
481    model.add(keras.layers.Dropout(0.5))
482    model.add(keras.layers.Dense(4, kernel_regularizer=regularizers.get('l2')))
483
484    input_arr = np.random.random((2, 3))
485    target_arr = np.random.random((2, 4))
486
487    model.compile(
488        loss='mse',
489        optimizer='rmsprop')
490    model.train_on_batch(input_arr, target_arr)
491
492    @def_function.function(input_signature=[tensor_spec.TensorSpec((None, 3))])
493    def predict(inputs):
494      return {'predictions': model(inputs)}
495
496    feature_configs = {
497        'inputs': parsing_ops.FixedLenFeature(
498            shape=[2, 3], dtype=dtypes.float32)}
499
500    @def_function.function(
501        input_signature=[tensor_spec.TensorSpec([None], dtypes.string)])
502    def parse_and_predict(examples):
503      features = parsing_ops.parse_single_example(examples[0], feature_configs)
504      return {'predictions': model(features['inputs']),
505              'layer_1_outputs': model.layers[0](features['inputs'])}
506
507    saved_model_dir = self._save_model_dir()
508    model.save(saved_model_dir, save_format='tf', signatures={
509        'predict': predict,
510        'parse_and_predict': parse_and_predict})
511    model.save('/tmp/saved', save_format='tf', signatures={
512        'predict': predict,
513        'parse_and_predict': parse_and_predict})
514
515    loaded = keras_load.load(saved_model_dir)
516
517    self.assertAllClose(
518        model.predict(input_arr),
519        loaded.signatures['predict'](ops.convert_to_tensor_v2_with_dispatch(
520            input_arr.astype('float32')))['predictions'])
521
522    feature = {
523        'inputs': feature_pb2.Feature(
524            float_list=feature_pb2.FloatList(
525                value=input_arr.astype('float32').flatten()))}
526    example = example_pb2.Example(
527        features=feature_pb2.Features(feature=feature))
528    outputs = loaded.signatures['parse_and_predict'](
529        ops.convert_to_tensor_v2_with_dispatch([example.SerializeToString()]))
530    self.assertAllClose(model.predict(input_arr), outputs['predictions'])
531    self.assertAllClose(model.layers[0](input_arr), outputs['layer_1_outputs'])
532
  def testTrainingDefaults(self):
    """Checks the saved default of `training` for several layer variants."""
    def assert_training_default(fn, default_value):
      # 'training' indexes into the defaults tuple counted from the end of
      # the argument list, hence the negative index below.
      arg_spec = tf_inspect.getfullargspec(fn)
      index = len(arg_spec.args) - arg_spec.args.index('training')
      self.assertEqual(arg_spec.defaults[-index], default_value)

    class LayerWithTrainingRequiredArg(keras.engine.base_layer.Layer):
      # `training` has no default here: it must stay required after loading.

      def call(self, inputs, training):
        return control_flow_util.smart_cond(training, lambda: inputs * 0,
                                            lambda: array_ops.identity(inputs))

    class LayerWithTrainingDefaultTrue(keras.engine.base_layer.Layer):
      # `training` defaults to True here and should keep that default.

      def call(self, inputs, training=True):
        return control_flow_util.smart_cond(training, lambda: inputs * 0,
                                            lambda: array_ops.identity(inputs))

    class Model(keras.models.Model):
      """Model combining all three `training`-default variants."""

      def __init__(self):
        super(Model, self).__init__()
        self.layer_with_training_default_none = LayerWithLearningPhase()
        self.layer_with_training_default_true = LayerWithTrainingDefaultTrue()
        self.layer_with_required_training_arg = LayerWithTrainingRequiredArg()

      def call(self, inputs):
        x = self.layer_with_training_default_none(inputs)
        x += self.layer_with_training_default_true(inputs)
        x += self.layer_with_required_training_arg(inputs, False)
        return x

    model = Model()
    # Build and set model inputs
    model.predict(np.ones([1, 3]).astype('float32'))
    saved_model_dir = self._save_model_dir()
    model.save(saved_model_dir, save_format='tf')
    load = tf_load.load(saved_model_dir)

    # Ensure that the Keras loader is able to load and build the model.
    _ = keras_load.load(saved_model_dir)

    assert_training_default(load.__call__, False)
    assert_training_default(
        load.layer_with_training_default_none.__call__, False)
    assert_training_default(
        load.layer_with_training_default_true.__call__, True)

    # Assert that there are no defaults for layer with required training arg
    arg_spec = tf_inspect.getfullargspec(
        load.layer_with_required_training_arg.__call__)
    self.assertFalse(arg_spec.defaults)  # defaults is None or empty
585
586  def testTraceModelWithKwarg(self):
587    class Model(keras.models.Model):
588
589      def call(self, inputs, keyword=None):
590        return array_ops.identity(inputs)
591
592    model = Model()
593    prediction = model.predict(np.ones([1, 3]).astype('float32'))
594    saved_model_dir = self._save_model_dir()
595    model.save(saved_model_dir, save_format='tf')
596
597    loaded = keras_load.load(saved_model_dir)
598    self.assertAllClose(prediction,
599                        loaded.predict(np.ones([1, 3]).astype('float32')))
600
  def testFeatureColumns(self):
    """DenseFeatures models with several column types survive save/load."""
    # TODO(b/120099662): Error with table initialization with Keras models in
    # graph mode.
    if context.executing_eagerly():
      # Numeric, bucketized, one-hot, and embedding columns together cover
      # the main feature-column code paths.
      numeric = fc.numeric_column('a')
      bucketized = fc.bucketized_column(numeric, boundaries=[5, 10, 15])
      cat_vocab = fc.categorical_column_with_vocabulary_list(
          'b', ['1', '2', '3'])
      one_hot = fc.indicator_column(cat_vocab)
      embedding = fc.embedding_column(cat_vocab, dimension=8)
      feature_layer = DenseFeatures([bucketized, one_hot, embedding])
      model = keras.models.Sequential(feature_layer)

      features = {'a': np.array([13, 15]), 'b': np.array(['1', '2'])}
      predictions = model.predict(features)

      saved_model_dir = self._save_model_dir()
      model.save(saved_model_dir, save_format='tf')
      loaded = keras_load.load(saved_model_dir)
      loaded_predictions = loaded.predict(features)
      self.assertAllClose(predictions, loaded_predictions)
622
  def testSaveTensorKwarg(self):
    """A tensor passed as a call kwarg is captured in the saved model."""

    class LayerWithTensorKwarg(keras.layers.Layer):
      """Multiplies inputs by the (cast) `tensor` kwarg when provided."""

      def call(self, inputs, tensor=None):
        if tensor is not None:
          return inputs * math_ops.cast(tensor, dtypes.float32)
        else:
          return inputs

    # t is a boolean mask tensor; the layer casts it to float32.
    t = self.evaluate(array_ops.sequence_mask(1))
    inputs = keras.layers.Input(shape=(3))
    model = keras.models.Model(inputs, LayerWithTensorKwarg()(inputs, t))

    input_arr = np.random.random((1, 3))
    predictions = model.predict(input_arr)

    saved_model_dir = self._save_model_dir()
    model.save(saved_model_dir, save_format='tf')
    loaded = keras_load.load(saved_model_dir)
    loaded_predictions = loaded.predict(input_arr)
    self.assertAllClose(predictions, loaded_predictions)
645
  def testModelWithTfFunctionCall(self):
    """Subclassed models whose call is wrapped in tf.function save/load."""
    class Subclass(keras.models.Model):

      @def_function.function
      def call(self, inputs, training=False):
        # Returns inputs when training (cast(True) == 1), zeros otherwise.
        return inputs * math_ops.cast(training, dtypes.float32)

    model = Subclass()
    model.predict(array_ops.ones((1, 2)), steps=1)
    saved_model_dir = self._save_model_dir()
    model.save(saved_model_dir, save_format='tf')
    loaded = keras_load.load(saved_model_dir)
    self.assertAllEqual(
        [[1, 5]],
        self.evaluate(loaded(array_ops.constant([[1, 5.]]), training=True)))
    self.assertAllEqual(
        [[0, 0]],
        self.evaluate(loaded(array_ops.constant([[1, 5.]]), training=False)))
664
  def testReviveFunctionalModel(self):
    """Custom layers in a functional model revive with their own class."""

    class CustomAdd(keras.layers.Add):
      """Add layer with an extra scalar weight multiplied into the output."""

      def build(self, input_shape):
        self.w = self.add_weight('w', shape=[])
        super(CustomAdd, self).build(input_shape)

      def call(self, inputs):
        outputs = super(CustomAdd, self).call(inputs)
        return outputs * self.w

    input1 = keras.layers.Input(shape=(None, 3), name='input_1')
    input2 = keras.layers.Input(shape=(None, 3), name='input_2')

    # Shared layer called on two inputs, giving it two inbound nodes.
    d = keras.layers.Dense(4, name='dense_with_two_inbound_nodes')
    output1 = d(input1)
    output2 = d(input2)

    # Use a custom layer in this model to ensure that layers aren't being
    # recreated directly from the config.
    outputs = CustomAdd(name='custom')([output1, output2])
    model = keras.models.Model([input1, input2], outputs, name='save_model')

    self.evaluate(variables.variables_initializer(model.variables))
    saved_model_dir = self._save_model_dir()
    model.save(saved_model_dir, save_format='tf')

    loaded = keras_load.load(saved_model_dir)
    self.assertEqual('save_model', loaded.name)
    # The shared Dense layer should keep both inbound nodes.
    self.assertLen(
        loaded.get_layer('dense_with_two_inbound_nodes')._inbound_nodes, 2)
    self.assertEqual('CustomAdd', type(loaded.get_layer('custom')).__name__)
    self.assertLen(loaded.get_layer('custom').weights, 1)
699
  def _testAddUpdate(self, scope):
    """Saves under `scope` and checks add_update ops run after loading.

    Args:
      scope: A context manager (e.g. a distribution strategy scope) entered
        while building and saving the model.
    """
    with scope:
      layer_with_update = LayerWithUpdate()
      model = testing_utils.get_model_from_layers([layer_with_update],
                                                  input_shape=(3,))

      x = np.ones((10, 3))
      if testing_utils.get_model_type() == 'subclass':
        # Subclassed models need one call to be built before saving.
        model.predict(x, batch_size=10)
      self.evaluate(variables.variables_initializer(model.variables))
      saved_model_dir = self._save_model_dir()
      model.save(saved_model_dir, save_format='tf')

    loaded = keras_load.load(saved_model_dir)
    loaded_layer = loaded.layers[-1]
    self.evaluate(variables.variables_initializer(loaded.variables))
    self.assertEqual(self.evaluate(loaded_layer.v), 0.)

    # One pass over a single batch should trigger exactly one update.
    loaded.compile('sgd', 'mse')
    loaded.fit(x, x, batch_size=10)
    self.assertEqual(self.evaluate(loaded_layer.v), 1.)
721
722  @keras_parameterized.run_with_all_model_types
723  def testSaveLayerWithUpdates(self):
724    @tf_contextlib.contextmanager
725    def nullcontextmanager():
726      yield
727    self._testAddUpdate(nullcontextmanager())
728
  @keras_parameterized.run_with_all_model_types
  def testSaveInStrategyScope(self):
    """Runs the add_update round trip inside a MirroredStrategy scope."""
    self._testAddUpdate(mirrored_strategy.MirroredStrategy().scope())
732
733  def testSaveTimeDistributedLayer(self):
734    model = keras.Sequential([
735        keras.layers.TimeDistributed(
736            keras.layers.Dense(1, kernel_regularizer=regularizers.get('l2')),
737            input_shape=(None, 1))])
738    predictions = model.predict_on_batch(array_ops.ones((3, 2, 1)))
739
740    saved_model_dir = self._save_model_dir()
741    model.save(saved_model_dir, save_format='tf')
742
743    loaded = keras_load.load(saved_model_dir)
744    self.assertAllClose(loaded.predict_on_batch(array_ops.ones((3, 2, 1))),
745                        predictions)
746
  @parameterized.named_parameters([
      ('with_unrolling', True),
      ('no_unrolling', False)
  ])
  def testSaveStatefulRNN(self, unroll):
    """Stateful RNN states and outputs are preserved through save/load.

    Args:
      unroll: Whether to unroll the RNN (requires a static timesteps
        dimension in the input shape).
    """
    batch = 12
    timesteps = 10
    input_dim = 8
    input_arr = np.ones((batch, timesteps, input_dim)).astype('float32')

    cells = [keras.layers.LSTMCell(32), keras.layers.LSTMCell(64)]
    if unroll:
      # Unrolling requires a known number of timesteps.
      x = keras.Input(batch_shape=(batch, timesteps, input_dim))
    else:
      x = keras.Input(batch_shape=(batch, None, input_dim))
    layer = keras.layers.RNN(cells, stateful=True, unroll=unroll)
    y = layer(x)

    model = keras.Model(x, y)
    model.compile('rmsprop', 'mse',
                  run_eagerly=testing_utils.should_run_eagerly())
    # Train once so the layer states carry non-initial values.
    model.train_on_batch(
        np.zeros((batch, timesteps, input_dim)).astype('float32'),
        np.zeros((batch, 64)).astype('float32'))

    saved_model_dir = self._save_model_dir()
    tf_save.save(model, saved_model_dir)

    loaded = keras_load.load(saved_model_dir)
    loaded_layer = loaded.layers[1]

    if not context.executing_eagerly():
      keras.backend.get_session()  # force variable initialization

    self.assertAllClose(layer.states, loaded_layer.states)
    self.assertAllClose(model(input_arr), loaded(input_arr))
783
784  def testSaveStatelessConvLSTM2D(self):
785    data_format = 'channels_first'
786    batch, timesteps, channels, rows, cols = 12, 10, 8, 4, 4
787    input_arr = np.ones(
788        (batch, timesteps, channels, rows, cols)).astype('float32')
789    layer = keras.layers.ConvLSTM2D(
790        filters=16, kernel_size=(1, 1), data_format=data_format)
791    x = keras.Input(batch_shape=(batch, timesteps, channels, rows, cols))
792    y = layer(x)
793    model = keras.Model(x, y)
794
795    predict_1 = model(input_arr)
796    saved_model_dir = self._save_model_dir()
797    tf_save.save(model, saved_model_dir)
798    del model
799
800    loaded = keras_load.load(saved_model_dir)
801    predict_2 = loaded(input_arr)
802    self.assertAllClose(predict_1, predict_2)
803
804  def testSaveWithRaggedInputs(self):
805
806    class EmbeddingMerger(keras.layers.Layer):
807
808      def __init__(self, list_features, **kwargs):
809        super().__init__(**kwargs)
810        self._supports_ragged_inputs = True
811        self.embeddings = {
812            feature: keras.layers.Embedding(10, 3) for feature in list_features}
813        self.mean = keras.layers.Lambda(
814            math_ops.reduce_mean, arguments=dict(axis=1))
815
816      def call(self, inputs):
817        tensors = [self.embeddings[col](inputs[col]) for col in inputs]
818        tensors = [self.mean(inp) for inp in tensors]
819        return keras.layers.Add()(tensors)
820
821    list_features = ['feature_1', 'feature_2']
822    feature_1 = ragged_factory_ops.constant([[0.], [1, 3]])
823    feature_2 = ragged_factory_ops.constant([[1., 2], [4]])
824    f = {'feature_1': feature_1,
825         'feature_2': feature_2}
826    f_inputs = {
827        'feature_1': keras.Input(shape=(None,), name='feature_1', ragged=True),
828        'feature_2': keras.Input(shape=(None,), name='feature_2', ragged=True)}
829
830    out = EmbeddingMerger(list_features)(f_inputs)
831    model = keras.Model(f_inputs, out)
832    self.evaluate(variables.variables_initializer(model.variables))
833    saved_model_dir = self._save_model_dir()
834    tf_save.save(model, saved_model_dir)
835
836    loaded = keras_load.load(saved_model_dir)
837    self.evaluate(variables.variables_initializer(loaded.variables))
838    self.assertAllClose(model.predict(f), loaded.predict(f))
839
840
class TestSavedModelFormat(test.TestCase):
  """Tests edge cases in the Keras SavedModel serialization format."""

  def _save_model_dir(self, dirname='saved_model'):
    """Returns a path in a temporary directory removed after the test."""
    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    return os.path.join(temp_dir, dirname)

  def test_load_with_partially_failed_serialization(self):
    """A model loads even when an inner layer's call fails to serialize."""

    class BadCustomLayer(keras.layers.Layer):
      # Overriding __call__ (rather than call) prevents Keras from
      # serializing this layer's call function.

      def __call__(self, inputs):
        return inputs

    class Model(keras.models.Model):

      def __init__(self):
        super(Model, self).__init__()
        self.layer = BadCustomLayer()

      @def_function.function(
          input_signature=[tensor_spec.TensorSpec([None, 1])])
      def call(self, inputs):
        return self.layer(inputs)

    model = Model()
    inp = constant_op.constant([[1.0]])
    model(inp)
    save_dir = self._save_model_dir()
    tf_save.save(model, save_dir)

    restored = keras_load.load(save_dir)
    # The model-level call was traced via the input signature, so it works.
    self.assertAllEqual([[1.0]], self.evaluate(restored(inp)))
    # Calling the inner layer directly surfaces the serialization failure.
    with self.assertRaisesRegex(ValueError, 'call function was not serialized'):
      restored.layer(inp)

  def test_save_without_tracing(self):
    """`save_traces=False` skips tracing; loading then requires the config."""

    class DoNotTrace(keras.layers.Layer):

      def __init__(self):
        super(DoNotTrace, self).__init__()
        self.input_spec = keras.layers.InputSpec(shape=[None])
        self.built = True

      def call(self, inputs):
        # Any attempt to trace this layer blows up loudly.
        raise ValueError('I said do not trace')

      def get_config(self):
        return {}

      @property
      def _use_input_spec_as_call_signature(self):
        return True

    root = keras.models.Sequential()
    root.add(keras.layers.Input(shape=(3,)))
    root.attached_layer = DoNotTrace()

    save_dir = self._save_model_dir()

    # By default the call function is traced, which triggers the error.
    with self.assertRaisesRegex(ValueError, 'do not trace'):
      root.save(save_dir, save_format='tf')

    # With save_traces=False only the config is saved, so no tracing occurs.
    root.save(save_dir, save_format='tf', save_traces=False)
    restored = tf_load.load(save_dir)
    self.assertTrue(hasattr(restored, 'attached_layer'))

    # Without the custom object, the revived layer cannot be called.
    restored = keras_load.load(save_dir)
    with self.assertRaisesRegex(ValueError, 'Cannot call custom layer'):
      restored.attached_layer(constant_op.constant([1.]))

    # With the custom object registered, calling reaches the original
    # (intentionally raising) call method.
    with generic_utils.CustomObjectScope({'DoNotTrace': DoNotTrace}):
      restored = keras_load.load(save_dir)
    with self.assertRaisesRegex(ValueError, 'I said do not trace'):
      restored.attached_layer(constant_op.constant([1.]))
922
923
class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
  """Tests for tracing layer call functions during SavedModel export."""

  def test_functions_have_same_trace(self):
    """All functions in a call collection are traced with the same shapes."""

    class Layer(keras.engine.base_layer.Layer):

      def call(self, inputs):
        return inputs

      def call2(self, inputs):
        return inputs * 2

    layer = Layer()
    collection = keras_save.LayerCallCollection(layer)
    traced_call = collection.add_function(layer.call, 'call')
    traced_call2 = collection.add_function(layer.call2, 'call2')

    # Only `call` is invoked, but every function in the collection should
    # record one concrete function per input shape.
    traced_call(np.ones((2, 3)))
    traced_call(np.ones((4, 5)))

    self.assertLen(
        traced_call._list_all_concrete_functions_for_serialization(), 2)
    self.assertLen(
        traced_call2._list_all_concrete_functions_for_serialization(), 2)

    # Verify `call2` was traced with exactly the shapes fed to `call`.
    traced_shapes = set()
    for concrete in (
        traced_call2._list_all_concrete_functions_for_serialization()):
      traced_shapes.add(
          tuple(concrete.structured_input_signature[0][0].shape.as_list()))
    self.assertEqual({(2, 3), (4, 5)}, traced_shapes)

  def test_training_arg_replacement(self):
    """Each call is retraced with both training=True and training=False."""

    def assert_num_traces(layer_cls, training_keyword):
      layer = layer_cls()
      collection = keras_save.LayerCallCollection(layer)
      traced = collection.add_function(layer.call, 'call')

      def num_concrete_fns():
        return len(traced._list_all_concrete_functions_for_serialization())

      # Each invocation yields two traces: one per training value.
      traced(np.ones((2, 3)), training=True)
      self.assertEqual(num_concrete_fns(), 2)

      traced(np.ones((2, 4)), training=False)
      self.assertEqual(num_concrete_fns(), 4)

      if training_keyword:
        # Positional and defaulted `training` also double-trace.
        traced(np.ones((2, 5)), True)
        self.assertEqual(num_concrete_fns(), 6)
        traced(np.ones((2, 6)))
        self.assertEqual(num_concrete_fns(), 8)

    class LayerWithTrainingKeyword(keras.engine.base_layer.Layer):

      def call(self, inputs, training=False):
        return inputs * training

    assert_num_traces(LayerWithTrainingKeyword, training_keyword=True)

    class LayerWithKwargs(keras.engine.base_layer.Layer):

      def call(self, inputs, **kwargs):
        return inputs * kwargs['training']

    assert_num_traces(LayerWithKwargs, training_keyword=False)

    class LayerWithChildLayer(keras.engine.base_layer.Layer):

      def __init__(self):
        self.child = LayerWithKwargs()
        super(LayerWithChildLayer, self).__init__()

      def call(self, inputs):
        return self.child(inputs)

    assert_num_traces(LayerWithChildLayer, training_keyword=False)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_maintains_losses(self):
    """Tracing a layer's call must not append duplicate losses."""
    layer = LayerWithLoss()
    layer(np.ones((2, 3)))
    losses_before_tracing = list(layer.losses)

    collection = keras_save.LayerCallCollection(layer)
    traced = collection.add_function(layer.call, 'call')
    traced(np.ones((2, 3)))

    self.assertAllEqual(losses_before_tracing, layer.losses)
1008
1009
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class MetricTest(test.TestCase, parameterized.TestCase):
  """Tests saving and loading Keras metrics to/from SavedModel.

  Runs in both graph and eager modes. The evaluate/call interleaving in
  `_test_metric_save_and_load` is order-sensitive in graph mode because
  metrics are stateful, so the statement order below matters.
  """

  def _save_model_dir(self, dirname='saved_model'):
    """Returns a save path in a temp directory removed after the test."""
    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    return os.path.join(temp_dir, dirname)

  def generate_inputs(self, num_tensor_args, shape=(1, 5)):
    """Returns `num_tensor_args` random float32 arrays of the given shape."""
    return [
        np.random.uniform(0, 1, shape).astype('float32')
        for _ in range(num_tensor_args)
    ]

  def _test_metric_save_and_load(self,
                                 metric,
                                 save_dir,
                                 num_tensor_args,
                                 shape=(1, 5),
                                 test_sample_weight=True):
    """Saves `metric`, reloads it, and checks both copies stay in sync.

    Args:
      metric: Metric instance to save.
      save_dir: Directory to save the SavedModel to.
      num_tensor_args: Number of tensor arguments the metric takes.
      shape: Shape of each generated input array.
      test_sample_weight: Whether to also exercise a sample_weight argument.

    Returns:
      The metric loaded back from the SavedModel.
    """
    with self.cached_session():
      tf_save.save(metric, save_dir)
      loaded = keras_load.load(save_dir)
      self.evaluate([v.initializer for v in loaded.variables])
      # Name and dtype must survive the round trip.
      self.assertEqual(metric.name, loaded.name)
      self.assertEqual(metric.dtype, loaded.dtype)

      # Updating both metrics with the same inputs must produce the same
      # result and leave identical variable values.
      inputs = self.generate_inputs(num_tensor_args, shape)
      actual = self.evaluate(metric(*inputs))
      self.assertAllClose(actual, loaded(*inputs))
      self.assertAllClose(metric.variables, loaded.variables)

      # Test with separate calls to update state and result.
      inputs = self.generate_inputs(num_tensor_args, shape)
      self.evaluate(metric.update_state(*inputs))
      self.evaluate(loaded.update_state(*inputs))
      actual = self.evaluate(metric.result())
      self.assertAllClose(actual, loaded.result())

      if test_sample_weight:
        # Test with sample weights input.
        inputs = self.generate_inputs(num_tensor_args, shape)
        # A scalar sample weight (shape []) appended as the last argument.
        sample_weight = self.generate_inputs(1, [])[0]
        inputs.append(sample_weight)

        actual = self.evaluate(metric(*inputs))
        self.assertAllClose(actual, loaded(*inputs))
      return loaded

  @parameterized.named_parameters([
      ('mean', keras.metrics.Mean, 1, (1, 5)),
      ('false_positives', keras.metrics.FalsePositives, 2, (1, 5)),
      ('precision_at_top_k', keras.metrics.Precision, 2, (2, 3, 4), {
          'top_k': 2,
          'class_id': 1
      }),
      ('precision_at_recall', keras.metrics.PrecisionAtRecall, 2, (1, 5), {
          'recall': .8
      }), ('auc', keras.metrics.AUC, 2, (1, 5), {
          'multi_label': True
      }), ('cosine_similarity', keras.metrics.CosineSimilarity, 2, (2, 3, 1))
  ])
  def test_metric(self, metric_cls, num_tensor_args, shape, init_kwargs=None):
    """Built-in metrics round-trip and revive with the original class."""
    init_kwargs = init_kwargs or {}
    metric = metric_cls(**init_kwargs)
    metric(*self.generate_inputs(num_tensor_args, shape))
    self.evaluate([v.initializer for v in metric.variables])
    loaded = self._test_metric_save_and_load(metric, self._save_model_dir(),
                                             num_tensor_args, shape)
    # Built-in metrics revive to their exact original class.
    self.assertEqual(type(loaded), type(metric))

  @parameterized.named_parameters([
      ('mean', keras.metrics.Mean, 1, False),
      ('auc', keras.metrics.AUC, 2, False),
      ('mean_tensor', keras.metrics.MeanTensor, 1, True)])
  def test_custom_metric(self, base_cls, num_tensor_args, requires_build):
    """Custom metric subclasses need a CustomObjectScope to be restored."""

    class CustomMetric(base_cls):

      def update_state(self, *args):  # pylint: disable=useless-super-delegation
        # Sometimes built-in metrics return an op in update_state. Custom
        # metrics don't support returning ops, so wrap the update_state method
        # while returning nothing.
        super(CustomMetric, self).update_state(*args)

    with self.cached_session():
      metric = CustomMetric()
      save_dir = self._save_model_dir('first_save')

      if requires_build:
        # Some metrics (e.g. MeanTensor) only create variables on first call.
        metric(*self.generate_inputs(num_tensor_args))  # pylint: disable=not-callable

      self.evaluate([v.initializer for v in metric.variables])

      # Loading without the custom object registered must fail loudly.
      with self.assertRaisesRegex(ValueError,
                                  'Unable to restore custom object'):
        self._test_metric_save_and_load(metric, save_dir, num_tensor_args)
      with generic_utils.CustomObjectScope({'CustomMetric': CustomMetric}):
        loaded = self._test_metric_save_and_load(
            metric,
            save_dir,
            num_tensor_args,
            test_sample_weight=False)

        # A revived metric must itself be saveable and loadable again.
        self._test_metric_save_and_load(
            loaded,
            self._save_model_dir('second_save'),
            num_tensor_args,
            test_sample_weight=False)

  def test_registered_custom_metric(self):
    """Metrics registered via register_keras_serializable need no scope."""

    @generic_utils.register_keras_serializable('Testing')
    class CustomMeanMetric(keras.metrics.Mean):

      def update_state(self, *args):  # pylint: disable=useless-super-delegation
        # Sometimes built-in metrics return an op in update_state. Custom
        # metrics don't support returning ops, so wrap the update_state method
        # while returning nothing.
        super(CustomMeanMetric, self).update_state(*args)

    with self.cached_session():
      metric = CustomMeanMetric()
      save_dir = self._save_model_dir('first_save')
      self.evaluate([v.initializer for v in metric.variables])
      loaded = self._test_metric_save_and_load(
          metric,
          save_dir,
          num_tensor_args=1,
          test_sample_weight=False)

      # Round-trip the revived metric a second time.
      self._test_metric_save_and_load(
          loaded,
          self._save_model_dir('second_save'),
          num_tensor_args=1,
          test_sample_weight=False)

  def test_custom_metric_wrapped_call(self):
    """A tf.function-wrapped update_state round-trips correctly."""

    class NegativeMean(keras.metrics.Mean):

      @def_function.function(
          input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
      def update_state(self, value):
        super(NegativeMean, self).update_state(-value)

    metric = NegativeMean()
    self.evaluate([v.initializer for v in metric.variables])
    with generic_utils.CustomObjectScope({'NegativeMean': NegativeMean}):
      self._test_metric_save_and_load(
          metric, self._save_model_dir(), 1, test_sample_weight=False)

  @keras_parameterized.run_with_all_model_types
  def test_custom_metric_model(self):
    """Loading a compiled model with an unregistered custom metric fails.

    Loading with compile=False skips metric restoration and succeeds.
    """

    class CustomMetric(keras.metrics.MeanSquaredError):
      pass

    model = testing_utils.get_small_mlp(1, 4, input_dim=3)
    model.compile(loss='mse', optimizer='rmsprop', metrics=[CustomMetric()])

    saved_model_dir = self._save_model_dir()
    tf_save.save(model, saved_model_dir)
    with self.assertRaisesRegex(ValueError, 'custom_objects'):
      keras_load.load(saved_model_dir)

    keras_load.load(saved_model_dir, compile=False)
1177
1178
# Standard test entry point.
if __name__ == '__main__':
  test.main()
1181