1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
# ==============================================================================
15"""Tests for model saving in the HDF5 format."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import shutil
23import tempfile
24from absl.testing import parameterized
25import numpy as np
26
27from tensorflow.python import keras
28from tensorflow.python.eager import context
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import test_util
33from tensorflow.python.keras import optimizers
34from tensorflow.python.keras.engine import training
35from tensorflow.python.keras.saving import hdf5_format
36from tensorflow.python.lib.io import file_io
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import random_ops
39from tensorflow.python.platform import test
40from tensorflow.python.platform import tf_logging as logging
41from tensorflow.python.training import checkpoint_management
42from tensorflow.python.training import training as training_module
43from tensorflow.python.training.tracking import util as trackable
44
45try:
46  import h5py  # pylint:disable=g-import-not-at-top
47except ImportError:
48  h5py = None
49
50
51class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
52
53  @test_util.run_in_graph_and_eager_modes
54  def test_weight_loading(self):
55    with self.cached_session():
56      a = keras.layers.Input(shape=(2,))
57      x = keras.layers.Dense(3)(a)
58      b = keras.layers.Dense(1)(x)
59      model = keras.models.Model(a, b)
60
61      x = np.random.random((3, 2))
62      ref_y = model.predict(x)
63      weights = model.get_weights()
64      model.set_weights(weights)
65      y = model.predict(x)
66      self.assertAllClose(ref_y, y)
67
68      with self.assertRaises(ValueError):
69        model.set_weights(weights[1:])
70      with self.assertRaises(ValueError):
71        model.set_weights(weights[::-1])
72
73      temp_dir = self.get_temp_dir()
74      self.addCleanup(shutil.rmtree, temp_dir)
75
76      no_extension_path = os.path.join(temp_dir, 'test')
77      model.save_weights(no_extension_path, save_format='tf')
78      model.load_weights(no_extension_path)
79      y = model.predict(x)
80      self.assertAllClose(ref_y, y)
81
82      if h5py is None:
83        return  # Skip rest of test if H5py isn't available.
84
85      h5_path = os.path.join(temp_dir, 'test.h5')
86      model.save_weights(h5_path)
87      model.load_weights(h5_path)
88      y = model.predict(x)
89      self.assertAllClose(ref_y, y)
90
91      model.load_weights(h5_path, by_name=True)
92      y = model.predict(x)
93      self.assertAllClose(ref_y, y)
94
95      model.save_weights(no_extension_path, save_format='hdf5')
96      model.load_weights(no_extension_path)
97      y = model.predict(x)
98      self.assertAllClose(ref_y, y)
99
100  @test_util.run_in_graph_and_eager_modes
101  def test_weight_preprocessing(self):
102    input_dim = 3
103    output_dim = 3
104    size = 2
105    cases = [
106        [
107            (keras.layers.Bidirectional(keras.layers.SimpleRNN(2))),
108            [np.random.random((2, 1)), np.random.random((2, 1))],
109            (None, 3, 2),
110        ],
111        [
112            (keras.layers.TimeDistributed(keras.layers.Dense(1))),
113            [np.random.random((2, 1)), np.random.random((1,))],
114            (None, 3, 2),
115        ],
116        [
117            (keras.layers.Conv1D(output_dim, size, use_bias=False)),
118            [np.random.random((output_dim, input_dim, size, 1))],
119            (None, 4, input_dim),
120        ],
121        [
122            (keras.layers.Conv2D(output_dim, size,
123                                 use_bias=False, data_format='channels_first')),
124            [np.random.random((output_dim, input_dim, size, size))],
125            (None, input_dim, 4, 4),
126        ],
127        [
128            (keras.layers.Conv2DTranspose(output_dim, size,
129                                          use_bias=False,
130                                          data_format='channels_first')),
131            [np.random.random((output_dim, input_dim, size, size))],
132            (None, input_dim, 4, 4),
133        ],
134        [
135            (keras.layers.Conv2DTranspose(output_dim, size,
136                                          use_bias=False,
137                                          data_format='channels_last')),
138            [np.random.random((size, size, input_dim, output_dim))],
139            (None, 4, 4, input_dim),
140        ],
141        [
142            (keras.layers.Conv3D(output_dim, size,
143                                 use_bias=False, data_format='channels_first')),
144            [np.random.random((output_dim, input_dim, size, size, size))],
145            (None, input_dim, 4, 4, 4),
146        ],
147        [
148            (keras.layers.GRU(output_dim)),
149            [np.random.random((input_dim, output_dim)),
150             np.random.random((output_dim, output_dim)),
151             np.random.random((output_dim,)),
152             np.random.random((input_dim, output_dim)),
153             np.random.random((output_dim, output_dim)),
154             np.random.random((output_dim,)),
155             np.random.random((input_dim, output_dim)),
156             np.random.random((output_dim, output_dim)),
157             np.random.random((output_dim,))],
158            (None, 4, input_dim),
159        ],
160        [
161            (keras.layers.LSTM(output_dim)),
162            [np.random.random((input_dim, output_dim)),
163             np.random.random((output_dim, output_dim)),
164             np.random.random((output_dim,)),
165             np.random.random((input_dim, output_dim)),
166             np.random.random((output_dim, output_dim)),
167             np.random.random((output_dim,)),
168             np.random.random((input_dim, output_dim)),
169             np.random.random((output_dim, output_dim)),
170             np.random.random((output_dim,)),
171             np.random.random((input_dim, output_dim)),
172             np.random.random((output_dim, output_dim)),
173             np.random.random((output_dim,))],
174            (None, 4, input_dim),
175        ],
176    ]
177    for layer, weights, input_shape in cases:
178      layer.build(input_shape)
179      _ = hdf5_format.preprocess_weights_for_loading(
180          layer, weights, original_keras_version='1')
181
182    model = keras.models.Sequential([keras.layers.Dense(2, input_dim=2)])
183    _ = hdf5_format.preprocess_weights_for_loading(
184        model, model.weights, original_keras_version='1')
185
186    x = keras.Input((2,))
187    y = keras.layers.Dense(2)(x)
188    model = keras.models.Model(x, y)
189    _ = hdf5_format.preprocess_weights_for_loading(
190        model, model.weights, original_keras_version='1')
191
192  @parameterized.named_parameters(
193      ('gru', keras.layers.GRU, {
194          'units': 2,
195          'input_shape': (3, 5)
196      }),
197      ('gru_with_reset_after', keras.layers.GRU, {
198          'units': 2,
199          'input_shape': (3, 5),
200          'reset_after': True
201      }),
202      ('lstm', keras.layers.LSTM, {
203          'units': 2,
204          'input_shape': (3, 5)
205      }),
206      ('cudnngru', keras.layers.CuDNNGRU, {
207          'units': 2,
208          'input_shape': (3, 5)
209      }),
210      ('cudnnlstm', keras.layers.CuDNNLSTM, {
211          'units': 2,
212          'input_shape': (3, 5)
213      }))
214  def test_preprocess_weights_for_loading_rnn_should_be_idempotent(
215      self, layer_class, layer_args):
216    with self.cached_session():
217      layer = layer_class(**layer_args)
218      layer.build(input_shape=layer_args.get('input_shape'))
219      weights1 = layer.get_weights()
220      weights2 = hdf5_format.preprocess_weights_for_loading(
221          layer, weights1)
222      _ = [
223          self.assertAllClose(x, y, rtol=1e-05)
224          for (x, y) in zip(weights1, weights2)
225      ]
226
227  @test_util.run_in_graph_and_eager_modes
228  def test_sequential_weight_loading(self):
229    if h5py is None:
230      return
231
232    temp_dir = self.get_temp_dir()
233    self.addCleanup(shutil.rmtree, temp_dir)
234    h5_path = os.path.join(temp_dir, 'test.h5')
235
236    num_hidden = 5
237    input_dim = 3
238    batch_size = 5
239    num_classes = 2
240
241    with self.cached_session():
242      model = keras.models.Sequential()
243      model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
244      model.add(keras.layers.Dense(num_classes))
245
246      x = np.random.random((batch_size, input_dim))
247      ref_y = model.predict(x)
248
249      model.save_weights(h5_path)
250
251      model = keras.models.Sequential()
252      model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
253      model.add(keras.layers.Dense(num_classes))
254      model.load_weights(h5_path)
255      y = model.predict(x)
256
257      self.assertAllClose(y, ref_y)
258
259  @test_util.run_in_graph_and_eager_modes
260  def test_sequential_weight_loading_group_name_with_incorrect_length(self):
261    if h5py is None:
262      return
263
264    temp_dir = self.get_temp_dir()
265    self.addCleanup(shutil.rmtree, temp_dir)
266    h5_path = os.path.join(temp_dir, 'test.h5')
267
268    num_hidden = 5
269    input_dim = 3
270    num_classes = 2
271    with self.cached_session():
272      ref_model = keras.models.Sequential()
273      ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
274                                       name='d1'))
275      ref_model.add(keras.layers.Dense(num_classes, name='d2'))
276      ref_model.compile(loss=keras.losses.MSE,
277                        optimizer=keras.optimizers.RMSprop(lr=0.0001),
278                        metrics=[keras.metrics.categorical_accuracy])
279
280      f_ref_model = h5py.File(h5_path, 'w')
281      hdf5_format.save_weights_to_hdf5_group(f_ref_model, ref_model.layers)
282
283      f_model = h5py.File(h5_path, 'r')
284      model = keras.models.Sequential()
285      model.add(keras.layers.Dense(num_hidden, use_bias=False,
286                                   input_dim=input_dim, name='d1'))
287      model.add(keras.layers.Dense(num_classes, name='d2'))
288      model.compile(loss=keras.losses.MSE,
289                    optimizer=keras.optimizers.RMSprop(lr=0.0001),
290                    metrics=[keras.metrics.categorical_accuracy])
291    with self.assertRaisesRegexp(ValueError,
292                                 r'Layer #0 \(named \"d1\"\) expects 1 '
293                                 r'weight\(s\), but the saved weights have 2 '
294                                 r'element\(s\)\.'):
295      hdf5_format.load_weights_from_hdf5_group_by_name(f_model, model.layers)
296
297  @test_util.run_deprecated_v1
298  def test_sequential_weight_loading_group_name_with_incorrect_shape(self):
299    if h5py is None:
300      return
301
302    temp_dir = self.get_temp_dir()
303    self.addCleanup(shutil.rmtree, temp_dir)
304    h5_path = os.path.join(temp_dir, 'test.h5')
305
306    num_hidden = 5
307    input_dim = 3
308    num_classes = 2
309    with self.cached_session():
310      ref_model = keras.models.Sequential()
311      ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
312                                       name='d1'))
313      ref_model.add(keras.layers.Dense(num_classes, name='d2'))
314      ref_model.compile(loss=keras.losses.MSE,
315                        optimizer=keras.optimizers.RMSprop(lr=0.0001),
316                        metrics=[keras.metrics.categorical_accuracy])
317
318      f_ref_model = h5py.File(h5_path, 'w')
319      hdf5_format.save_weights_to_hdf5_group(f_ref_model, ref_model.layers)
320
321      f_model = h5py.File(h5_path, 'r')
322      model = keras.models.Sequential()
323      model.add(keras.layers.Dense(num_hidden + 5, input_dim=input_dim,
324                                   name='d1'))
325      model.add(keras.layers.Dense(num_classes, name='d2'))
326      model.compile(loss=keras.losses.MSE,
327                    optimizer=keras.optimizers.RMSprop(lr=0.0001),
328                    metrics=[keras.metrics.categorical_accuracy])
329      with self.assertRaisesRegexp(ValueError,
330                                   r'Layer #0 \(named "d1"\), weight '
331                                   r'<tf\.Variable \'d1_1\/kernel:0\' '
332                                   r'shape=\(3, 10\) dtype=float32> has '
333                                   r'shape \(3, 10\), but the saved weight has '
334                                   r'shape \(3, 5\)\.'):
335        hdf5_format.load_weights_from_hdf5_group_by_name(f_model, model.layers)
336
337
338class TestWholeModelSaving(test.TestCase):
339
340  @test_util.run_v1_only('b/120994067')
341  def test_sequential_model_saving(self):
342    if h5py is None:
343      self.skipTest('h5py required to run this test')
344
345    with self.cached_session():
346      model = keras.models.Sequential()
347      model.add(keras.layers.Dense(2, input_shape=(3,)))
348      model.add(keras.layers.RepeatVector(3))
349      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
350      model.compile(
351          loss=keras.losses.MSE,
352          optimizer=keras.optimizers.RMSprop(lr=0.0001),
353          metrics=[
354              keras.metrics.categorical_accuracy,
355              keras.metrics.CategoricalCrossentropy(
356                  name='cce', label_smoothing=constant_op.constant(0.2)),
357          ],
358          weighted_metrics=[
359              keras.metrics.categorical_crossentropy,
360              keras.metrics.CategoricalCrossentropy(
361                  name='cce', label_smoothing=constant_op.constant(0.2)),
362          ],
363          sample_weight_mode='temporal')
364
365      x = np.random.random((1, 3))
366      y = np.random.random((1, 3, 3))
367      model.train_on_batch(x, y)
368
369      out = model.predict(x)
370      fd, fname = tempfile.mkstemp('.h5')
371      keras.models.save_model(model, fname)
372
373      new_model = keras.models.load_model(fname)
374      os.close(fd)
375      os.remove(fname)
376
377      out2 = new_model.predict(x)
378      self.assertAllClose(out, out2, atol=1e-05)
379
380      # test that new updates are the same with both models
381      x = np.random.random((1, 3))
382      y = np.random.random((1, 3, 3))
383      model.train_on_batch(x, y)
384      new_model.train_on_batch(x, y)
385
386      x = np.random.random((1, 3))
387      y = np.random.random((1, 3, 3))
388      eval_out = model.evaluate(x, y)
389      eval_out2 = new_model.evaluate(x, y)
390      self.assertArrayNear(eval_out, eval_out2, 0.001)
391
392      out = model.predict(x)
393      out2 = new_model.predict(x)
394
395      # TODO(b/120930751) This tolerance should be 1e-05,
396      # very concerning that its not.
397      self.assertAllClose(out, out2, atol=1e-03)
398
399  @test_util.run_deprecated_v1
400  def test_sequential_model_saving_without_input_shape(self):
401    if h5py is None:
402      self.skipTest('h5py required to run this test')
403
404    with self.cached_session():
405      model = keras.models.Sequential()
406      model.add(keras.layers.Dense(2))
407      model.add(keras.layers.RepeatVector(3))
408      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
409      model.compile(
410          loss=keras.losses.MSE,
411          optimizer=keras.optimizers.RMSprop(lr=0.0001),
412          metrics=[
413              keras.metrics.categorical_accuracy,
414              keras.metrics.CategoricalAccuracy()
415          ],
416          weighted_metrics=[
417              keras.metrics.categorical_accuracy,
418              keras.metrics.CategoricalAccuracy()
419          ],
420          sample_weight_mode='temporal')
421      x = np.random.random((1, 3))
422      y = np.random.random((1, 3, 3))
423      model.train_on_batch(x, y)
424
425      out = model.predict(x)
426      fd, fname = tempfile.mkstemp('.h5', dir=self.get_temp_dir())
427      model.save(fname)
428
429      new_model = keras.models.load_model(fname)
430      os.close(fd)
431      os.remove(fname)
432
433      out2 = new_model.predict(x)
434      self.assertAllClose(out, out2, atol=1e-05)
435
436  def test_sequential_model_saving_without_compile(self):
437    if h5py is None:
438      self.skipTest('h5py required to run this test')
439
440    with self.cached_session():
441      model = keras.models.Sequential()
442      model.add(keras.layers.Dense(2, input_shape=(3,)))
443      model.add(keras.layers.RepeatVector(3))
444      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
445
446      x = np.random.random((1, 3))
447      out = model.predict(x)
448      fd, fname = tempfile.mkstemp('.h5')
449
450      # Save the model without any compilation or training.
451      keras.models.save_model(model, fname)
452
453      new_model = keras.models.load_model(fname)
454      os.close(fd)
455      os.remove(fname)
456
457      out2 = new_model.predict(x)
458      self.assertAllClose(out, out2, atol=1e-05)
459
460  @test_util.run_deprecated_v1
461  def test_sequential_model_saving_2(self):
462    if h5py is None:
463      self.skipTest('h5py required to run this test')
464
465    with self.cached_session():
466      # test with custom optimizer, loss
467
468      class CustomOp(keras.optimizers.RMSprop):
469        pass
470
471      def custom_loss(y_true, y_pred):
472        return keras.losses.mse(y_true, y_pred)
473
474      model = keras.models.Sequential()
475      model.add(keras.layers.Dense(2, input_shape=(3,)))
476      model.add(keras.layers.Dense(3))
477      model.compile(loss=custom_loss, optimizer=CustomOp(), metrics=['acc'])
478
479      x = np.random.random((1, 3))
480      y = np.random.random((1, 3))
481      model.train_on_batch(x, y)
482
483      out = model.predict(x)
484      fd, fname = tempfile.mkstemp('.h5')
485      keras.models.save_model(model, fname)
486
487      model = keras.models.load_model(
488          fname,
489          custom_objects={'CustomOp': CustomOp,
490                          'custom_loss': custom_loss})
491      os.close(fd)
492      os.remove(fname)
493
494      out2 = model.predict(x)
495      self.assertAllClose(out, out2, atol=1e-05)
496
497  @test_util.run_deprecated_v1
498  def test_functional_model_saving(self):
499    if h5py is None:
500      self.skipTest('h5py required to run this test')
501
502    with self.cached_session():
503      inputs = keras.layers.Input(shape=(3,))
504      x = keras.layers.Dense(2)(inputs)
505      output = keras.layers.Dense(3)(x)
506
507      model = keras.models.Model(inputs, output)
508      model.compile(
509          loss=keras.losses.MSE,
510          optimizer=keras.optimizers.RMSprop(lr=0.0001),
511          metrics=[
512              keras.metrics.categorical_accuracy,
513              keras.metrics.CategoricalAccuracy()
514          ],
515          weighted_metrics=[
516              keras.metrics.categorical_accuracy,
517              keras.metrics.CategoricalAccuracy()
518          ])
519      x = np.random.random((1, 3))
520      y = np.random.random((1, 3))
521      model.train_on_batch(x, y)
522
523      out = model.predict(x)
524      fd, fname = tempfile.mkstemp('.h5')
525      keras.models.save_model(model, fname)
526
527      model = keras.models.load_model(fname)
528      os.close(fd)
529      os.remove(fname)
530
531      out2 = model.predict(x)
532      self.assertAllClose(out, out2, atol=1e-05)
533
534  def test_saving_without_compilation(self):
535    if h5py is None:
536      self.skipTest('h5py required to run this test')
537
538    with self.cached_session():
539      model = keras.models.Sequential()
540      model.add(keras.layers.Dense(2, input_shape=(3,)))
541      model.add(keras.layers.Dense(3))
542      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
543
544      fd, fname = tempfile.mkstemp('.h5')
545      keras.models.save_model(model, fname)
546      model = keras.models.load_model(fname)
547      os.close(fd)
548      os.remove(fname)
549
550  def test_saving_with_tf_optimizer(self):
551    if h5py is None:
552      self.skipTest('h5py required to run this test')
553
554    with self.cached_session():
555      model = keras.models.Sequential()
556      model.add(keras.layers.Dense(2, input_shape=(3,)))
557      model.add(keras.layers.Dense(3))
558      model.compile(loss='mse',
559                    optimizer=training_module.AdadeltaOptimizer(0.1),
560                    metrics=['acc'])
561
562      fd, fname = tempfile.mkstemp('.h5')
563      keras.models.save_model(model, fname)
564      model = keras.models.load_model(fname)
565      os.close(fd)
566      os.remove(fname)
567
568  def test_saving_right_after_compilation(self):
569    if h5py is None:
570      self.skipTest('h5py required to run this test')
571
572    with self.cached_session():
573      model = keras.models.Sequential()
574      model.add(keras.layers.Dense(2, input_shape=(3,)))
575      model.add(keras.layers.Dense(3))
576      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
577      model._make_train_function()
578
579      fd, fname = tempfile.mkstemp('.h5')
580      keras.models.save_model(model, fname)
581      model = keras.models.load_model(fname)
582      os.close(fd)
583      os.remove(fname)
584
585  def test_saving_lambda_numpy_array_arguments(self):
586    with self.cached_session():
587      if h5py is None:
588        self.skipTest('h5py required to run this test')
589
590      mean = np.random.random((4, 2, 3))
591      std = np.abs(np.random.random((4, 2, 3))) + 1e-5
592      inputs = keras.layers.Input(shape=(4, 2, 3))
593      output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std,
594                                   arguments={'mu': mean, 'std': std})(inputs)
595      model = keras.models.Model(inputs, output)
596      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
597
598      fd, fname = tempfile.mkstemp('.h5')
599      keras.models.save_model(model, fname)
600
601      model = keras.models.load_model(fname)
602      os.close(fd)
603      os.remove(fname)
604
605      self.assertAllClose(mean, model.layers[1].arguments['mu'])
606      self.assertAllClose(std, model.layers[1].arguments['std'])
607
608  def test_saving_model_with_long_layer_names(self):
609    if h5py is None:
610      self.skipTest('h5py required to run this test')
611
612    with self.cached_session():
613      # This layer name will make the `layers_name` HDF5 attribute blow
614      # out of proportion. Note that it fits into the internal HDF5
615      # attribute memory limit on its own but because h5py converts
616      # the list of layer names into numpy array, which uses the same
617      # amout of memory for every item, it increases the memory
618      # requirements substantially.
619      x = keras.Input(shape=(2,), name='input_' + ('x' * (2**15)))
620      f = x
621      for i in range(4):
622        f = keras.layers.Dense(2, name='dense_%d' % (i,))(f)
623      model = keras.Model(inputs=[x], outputs=[f])
624      model.compile(loss='mse', optimizer='adam', metrics=['acc'])
625
626      x = np.random.random((1, 2))
627      y = np.random.random((1, 2))
628      model.train_on_batch(x, y)
629      out = model.predict(x)
630
631      fd, fname = tempfile.mkstemp('.h5')
632      keras.models.save_model(model, fname)
633      model = keras.models.load_model(fname)
634
635      # Check that the HDF5 files contains chunked array
636      # of layer names.
637      with h5py.File(fname, 'r') as h5file:
638        num_names_arrays = len([attr for attr in h5file['model_weights'].attrs
639                                if attr.startswith('layer_names')])
640      # The chunking of layer names array should have happened.
641      self.assertGreater(num_names_arrays, 0)
642      out2 = model.predict(x)
643      self.assertAllClose(out, out2, atol=1e-05)
644
645      # Cleanup
646      os.close(fd)
647      os.remove(fname)
648
649  def test_saving_model_with_long_weights_names(self):
650    if h5py is None:
651      self.skipTest('h5py required to run this test')
652
653    with self.cached_session():
654      x = keras.Input(shape=(2,), name='nested_model_input')
655      f = x
656      for i in range(4):
657        f = keras.layers.Dense(2, name='nested_model_dense_%d' % (i,))(f)
658      # This layer name will make the `weights_name`
659      # HDF5 attribute blow out of proportion.
660      f = keras.layers.Dense(2, name='nested_model_output' + ('x' * (2**14)))(f)
661      nested_model = keras.Model(inputs=[x], outputs=[f], name='nested_model')
662
663      x = keras.Input(shape=(2,), name='outer_model_input')
664      f = nested_model(x)
665      f = keras.layers.Dense(2, name='outer_model_output')(f)
666
667      model = keras.Model(inputs=[x], outputs=[f])
668      model.compile(loss='mse', optimizer='adam', metrics=['acc'])
669
670      x = np.random.random((1, 2))
671      y = np.random.random((1, 2))
672      model.train_on_batch(x, y)
673      out = model.predict(x)
674
675      fd, fname = tempfile.mkstemp('.h5')
676      keras.models.save_model(model, fname)
677      model = keras.models.load_model(fname)
678
679      # Check that the HDF5 files contains chunked array
680      # of weight names.
681      with h5py.File(fname, 'r') as h5file:
682        num_weight_arrays = len(
683            [attr for attr in h5file['model_weights']['nested_model'].attrs
684             if attr.startswith('weight_names')])
685      # The chunking of layer names array should have happened.
686      self.assertGreater(num_weight_arrays, 0)
687      out2 = model.predict(x)
688      self.assertAllClose(out, out2, atol=1e-05)
689
690      # Cleanup
691      os.close(fd)
692      os.remove(fname)
693
694  @test_util.run_deprecated_v1
695  def test_model_saving_to_pre_created_h5py_file(self):
696    if h5py is None:
697      self.skipTest('h5py required to run this test')
698
699    with self.cached_session():
700      inputs = keras.Input(shape=(3,))
701      x = keras.layers.Dense(2)(inputs)
702      outputs = keras.layers.Dense(3)(x)
703
704      model = keras.Model(inputs, outputs)
705      model.compile(
706          loss=keras.losses.MSE,
707          optimizer=keras.optimizers.Adam(),
708          metrics=[
709              keras.metrics.categorical_accuracy,
710              keras.metrics.CategoricalAccuracy()
711          ])
712      x = np.random.random((1, 3))
713      y = np.random.random((1, 3))
714      model.train_on_batch(x, y)
715
716      out = model.predict(x)
717      fd, fname = tempfile.mkstemp('.h5')
718      with h5py.File(fname, mode='r+') as h5file:
719        keras.models.save_model(model, h5file)
720        loaded_model = keras.models.load_model(h5file)
721        out2 = loaded_model.predict(x)
722      self.assertAllClose(out, out2, atol=1e-05)
723
724      # Test non-default options in h5
725      with h5py.File('_', driver='core',
726                     backing_store=False) as h5file:
727        keras.models.save_model(model, h5file)
728        loaded_model = keras.models.load_model(h5file)
729        out2 = loaded_model.predict(x)
730      self.assertAllClose(out, out2, atol=1e-05)
731
732      # Cleanup
733      os.close(fd)
734      os.remove(fname)
735
736  def test_saving_constant_initializer_with_numpy(self):
737    if h5py is None:
738      self.skipTest('h5py required to run this test')
739
740    with self.cached_session():
741      model = keras.models.Sequential()
742      model.add(
743          keras.layers.Dense(
744              2,
745              input_shape=(3,),
746              kernel_initializer=keras.initializers.Constant(np.ones((3, 2)))))
747      model.add(keras.layers.Dense(3))
748      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
749      fd, fname = tempfile.mkstemp('.h5')
750      keras.models.save_model(model, fname)
751      model = keras.models.load_model(fname)
752      os.close(fd)
753      os.remove(fname)
754
755
class SubclassedModel(training.Model):
  """Two stacked Dense layers (3 units, then 1) as a subclassed model."""

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self.x_layer = keras.layers.Dense(3)
    self.b_layer = keras.layers.Dense(1)

  def call(self, a):
    """Applies `x_layer` to `a` and then `b_layer` to the result."""
    hidden = self.x_layer(a)
    return self.b_layer(hidden)
765
766
767class TestWeightSavingAndLoadingTFFormat(test.TestCase):
768
769  def test_keras_optimizer_warning(self):
770    graph = ops.Graph()
771    with graph.as_default(), self.session(graph):
772      model = keras.models.Sequential()
773      model.add(keras.layers.Dense(2, input_shape=(3,)))
774      model.add(keras.layers.Dense(3))
775      model.compile(loss='mse', optimizer=optimizers.Adam(), metrics=['acc'])
776      model._make_train_function()
777      temp_dir = self.get_temp_dir()
778      prefix = os.path.join(temp_dir, 'ckpt')
779      with test.mock.patch.object(logging, 'warning') as mock_log:
780        model.save_weights(prefix)
781        self.assertRegexpMatches(
782            str(mock_log.call_args),
783            'Keras optimizer')
784
785  @test_util.run_in_graph_and_eager_modes
786  def test_tensorflow_format_overwrite(self):
787    with self.cached_session() as session:
788      model = SubclassedModel()
789      temp_dir = self.get_temp_dir()
790      prefix = os.path.join(temp_dir, 'ckpt')
791
792      x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
793      executing_eagerly = context.executing_eagerly()
794      model(x)  # pylint: disable=not-callable
795      if not executing_eagerly:
796        session.run([v.initializer for v in model.variables])
797      model.save_weights(prefix, save_format='tensorflow')
798      model.save_weights(prefix, save_format='tensorflow', overwrite=True)
799      with self.assertRaises(EOFError):
800        # Indirectly tests that the user is prompted
801        model.save_weights(prefix, save_format='tensorflow', overwrite=False)
802
803  def test_no_default_session(self):
804    with ops.Graph().as_default():
805      self.assertFalse(ops.get_default_session())
806      data = np.random.random((1000, 32)).astype(np.float32)
807      labels = np.random.random((1000, 10)).astype(np.float32)
808
809      model = keras.models.Sequential([
810          keras.layers.Dense(10, activation='softmax'),
811          keras.layers.Dense(10, activation='softmax')])
812
813      model.compile(optimizer=training_module.RMSPropOptimizer(0.001),
814                    loss='categorical_crossentropy',
815                    metrics=['accuracy'])
816
817      model.fit(data, labels)
818      fname = os.path.join(self.get_temp_dir(), 'weights', 'ckpt')
819      model.save_weights(fname)
820      model.load_weights(fname)
821
822  def test_no_graph_pollution(self):
823    with context.graph_mode():
824      graph = ops.Graph()
825      with graph.as_default(), self.session(graph) as session:
826        model = SubclassedModel()
827        temp_dir = self.get_temp_dir()
828        prefix = os.path.join(temp_dir, 'ckpt')
829
830        x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
831        model(x)  # pylint: disable=not-callable
832        session.run([v.initializer for v in model.variables])
833        model.save_weights(prefix, save_format='tensorflow')
834        op_count = len(graph.get_operations())
835        model.save_weights(prefix, save_format='tensorflow')
836        self.assertEqual(len(graph.get_operations()), op_count)
837
838        model.load_weights(prefix)
839        op_count = len(graph.get_operations())
840        model.load_weights(prefix)
841        self.assertEqual(len(graph.get_operations()), op_count)
842
843  def _weight_loading_test_template(self, make_model_fn):
844    with self.cached_session():
845      model = make_model_fn()
846      model.compile(
847          loss='mse',
848          optimizer=training_module.RMSPropOptimizer(0.1),
849          metrics=['acc', keras.metrics.CategoricalAccuracy()])
850      temp_dir = self.get_temp_dir()
851      prefix = os.path.join(temp_dir, 'ckpt')
852      train_x = np.random.random((3, 2))
853      train_y = np.random.random((3,))
854      x = constant_op.constant(train_x, dtype=dtypes.float32)
855
856      model.train_on_batch(train_x, train_y)
857      model.save_weights(prefix, save_format='tf')
858      ref_y_before_train = model.predict(train_x)
859      model.train_on_batch(train_x, train_y)
860      ref_y_after_train = model.predict(train_x)
861      for v in model.variables:
862        self.evaluate(
863            v.assign(random_ops.random_normal(shape=array_ops.shape(v))))
864
865      self.addCleanup(shutil.rmtree, temp_dir)
866
867      model.load_weights(prefix)
868      self.assertAllClose(ref_y_before_train, self.evaluate(model(x)))
869
870      # Test restore-on-create if this is a subclassed Model (graph Networks
871      # will have already created their variables).
872      load_model = make_model_fn()
873      load_model.load_weights(prefix)
874      self.assertAllClose(
875          ref_y_before_train,
876          self.evaluate(load_model(x)))
877      load_model = make_model_fn()
878      load_model.load_weights(prefix)
879      # We need to run some of the restore ops for predict(), but not all
880      # variables have been created yet (optimizer slot variables). Tests
881      # incremental restore.
882      load_model.predict(train_x)
883      load_model.compile(
884          loss='mse',
885          optimizer=training_module.RMSPropOptimizer(0.1),
886          metrics=['acc', keras.metrics.CategoricalAccuracy()])
887      load_model.train_on_batch(train_x, train_y)
888      self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x)))
889
890  @test_util.run_in_graph_and_eager_modes
891  def test_weight_loading_graph_model(self):
892    def _make_graph_model():
893      a = keras.layers.Input(shape=(2,))
894      x = keras.layers.Dense(3)(a)
895      b = keras.layers.Dense(1)(x)
896      return keras.models.Model(a, b)
897
898    self._weight_loading_test_template(_make_graph_model)
899
900  @test_util.run_in_graph_and_eager_modes
901  def test_weight_loading_subclassed_model(self):
902    self._weight_loading_test_template(SubclassedModel)
903
904  def _new_layer_weight_loading_test_template(
905      self, first_model_fn, second_model_fn, restore_init_fn):
906    with self.cached_session() as session:
907      model = first_model_fn()
908      temp_dir = self.get_temp_dir()
909      prefix = os.path.join(temp_dir, 'ckpt')
910
911      x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
912      executing_eagerly = context.executing_eagerly()
913      ref_y_tensor = model(x)
914      if not executing_eagerly:
915        session.run([v.initializer for v in model.variables])
916      ref_y = self.evaluate(ref_y_tensor)
917      model.save_weights(prefix)
918      self.assertEqual(
919          prefix,
920          checkpoint_management.latest_checkpoint(temp_dir))
921      for v in model.variables:
922        self.evaluate(
923            v.assign(random_ops.random_normal(shape=array_ops.shape(v))))
924
925      self.addCleanup(shutil.rmtree, temp_dir)
926
927      second_model = second_model_fn()
928      second_model.load_weights(prefix)
929      second_model(x)
930      self.evaluate(restore_init_fn(second_model))
931      second_model.save_weights(prefix)
932      # Check that the second model's checkpoint loads into the original model
933      model.load_weights(prefix)
934      y = self.evaluate(model(x))
935      self.assertAllClose(ref_y, y)
936
937  @test_util.run_in_graph_and_eager_modes
938  def test_weight_loading_graph_model_added_layer(self):
939    def _save_graph_model():
940      a = keras.layers.Input(shape=(2,))
941      x = keras.layers.Dense(3, name='first')(a)
942      b = keras.layers.Dense(1, name='second')(x)
943      return keras.models.Model(a, b)
944    def _restore_graph_model():
945      a = keras.layers.Input(shape=(2,))
946      x = keras.layers.Dense(3, name='first')(a)
947      y = keras.layers.Dense(1, name='second')(x)
948      b = keras.layers.Dense(3, name='secondjr')(y)
949      return keras.models.Model(a, b)
950    def _restore_init_fn(restore_model):
951      return [v.initializer for v in restore_model.layers[-1].variables]
952
953    self._new_layer_weight_loading_test_template(
954        _save_graph_model, _restore_graph_model,
955        _restore_init_fn)
956
957  @test_util.run_in_graph_and_eager_modes
958  def test_weight_loading_graph_model_added_no_weight_layer(self):
959    def _save_graph_model():
960      a = keras.layers.Input(shape=(2,))
961      x = keras.layers.Dense(3, name='first')(a)
962      b = keras.layers.Dense(1, name='second')(x)
963      return keras.models.Model(a, b)
964    def _restore_graph_model():
965      a = keras.layers.Input(shape=(2,))
966      x = keras.layers.Dense(3, name='first')(a)
967      y = keras.layers.Dropout(rate=0.1)(x)
968      b = keras.layers.Dense(1, name='second')(y)
969      return keras.models.Model(a, b)
970    def _restore_init_fn(restore_model):
971      del restore_model  # unused
972      return []
973
974    self._new_layer_weight_loading_test_template(
975        _save_graph_model, _restore_graph_model,
976        _restore_init_fn)
977
978  @test_util.run_in_graph_and_eager_modes
979  def test_weight_loading_subclassed_model_added_layer(self):
980
981    class SubclassedModelRestore(training.Model):
982
983      def __init__(self):
984        super(SubclassedModelRestore, self).__init__()
985        self.x_layer = keras.layers.Dense(3)
986        self.y_layer = keras.layers.Dense(3)
987        self.b_layer = keras.layers.Dense(1)
988
989      def call(self, a):
990        return self.b_layer(self.y_layer(self.x_layer(a)))
991
992    def _restore_init_fn(restore_model):
993      return [v.initializer for v in restore_model.y_layer.variables]
994
995    self._new_layer_weight_loading_test_template(
996        SubclassedModel, SubclassedModelRestore,
997        _restore_init_fn)
998
999  @test_util.run_in_graph_and_eager_modes
1000  def test_incompatible_checkpoint(self):
1001    save_path = trackable.Checkpoint().save(
1002        os.path.join(self.get_temp_dir(), 'ckpt'))
1003    m = keras.Model()
1004    with self.assertRaisesRegexp(AssertionError, 'Nothing to load'):
1005      m.load_weights(save_path)
1006    m.dense = keras.layers.Dense(2)
1007    m.dense(constant_op.constant([[1.]]))
1008    with self.assertRaisesRegexp(
1009        AssertionError, 'Nothing except the root object matched'):
1010      m.load_weights(save_path)
1011
1012  @test_util.run_in_graph_and_eager_modes
1013  def test_directory_passed(self):
1014    m = keras.Model()
1015    v = m.add_weight(name='v', shape=[])
1016    self.evaluate(v.assign(42.))
1017    prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'ckpt/')
1018    m.save_weights(prefix)
1019    self.evaluate(v.assign(2.))
1020    m.load_weights(prefix)
1021    self.assertEqual(42., self.evaluate(v))
1022
1023  @test_util.run_in_graph_and_eager_modes
1024  def test_relative_path(self):
1025    m = keras.Model()
1026    v = m.add_weight(name='v', shape=[])
1027    os.chdir(self.get_temp_dir())
1028
1029    prefix = 'ackpt'
1030    self.evaluate(v.assign(42.))
1031    m.save_weights(prefix)
1032    self.assertTrue(file_io.file_exists('ackpt.index'))
1033    self.evaluate(v.assign(1.))
1034    m.load_weights(prefix)
1035    self.assertEqual(42., self.evaluate(v))
1036
1037    prefix = 'subdir/ackpt'
1038    self.evaluate(v.assign(43.))
1039    m.save_weights(prefix)
1040    self.assertTrue(file_io.file_exists('subdir/ackpt.index'))
1041    self.evaluate(v.assign(2.))
1042    m.load_weights(prefix)
1043    self.assertEqual(43., self.evaluate(v))
1044
1045    prefix = 'ackpt/'
1046    self.evaluate(v.assign(44.))
1047    m.save_weights(prefix)
1048    self.assertTrue(file_io.file_exists('ackpt/.index'))
1049    self.evaluate(v.assign(3.))
1050    m.load_weights(prefix)
1051    self.assertEqual(44., self.evaluate(v))
1052
1053  @test_util.run_in_graph_and_eager_modes
1054  def test_nonexistant_prefix_directory(self):
1055    m = keras.Model()
1056    v = m.add_weight(name='v', shape=[])
1057    self.evaluate(v.assign(42.))
1058    prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'bckpt')
1059    m.save_weights(prefix)
1060    self.evaluate(v.assign(2.))
1061    m.load_weights(prefix)
1062    self.assertEqual(42., self.evaluate(v))
1063
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  test.main()
1066