1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#,============================================================================
15"""Tests for model saving in the HDF5 format."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import shutil
23import uuid
24
25from absl.testing import parameterized
26import numpy as np
27
28from tensorflow.python import keras
29from tensorflow.python.eager import context
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import ops
33from tensorflow.python.keras import combinations
34from tensorflow.python.keras import keras_parameterized
35from tensorflow.python.keras import optimizer_v1
36from tensorflow.python.keras import testing_utils
37from tensorflow.python.keras.engine import training
38from tensorflow.python.keras.saving import hdf5_format
39from tensorflow.python.lib.io import file_io
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import random_ops
42from tensorflow.python.platform import test
43from tensorflow.python.training import checkpoint_management
44from tensorflow.python.training import training as training_module
45from tensorflow.python.training.tracking import util as trackable
46
47try:
48  import h5py  # pylint:disable=g-import-not-at-top
49except ImportError:
50  h5py = None
51
52
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
  """Weight save/load round-trip tests, mostly exercising the HDF5 format.

  The class-level `combinations.generate` runs every test method in both
  graph and eager mode.
  """

  def _save_model_dir(self, dirname='saved_model'):
    """Returns `<temp dir>/<dirname>` and schedules removal of the temp dir.

    Args:
      dirname: Basename to join onto this test's temporary directory.

    Returns:
      The joined path (nothing is created at that path).
    """
    temp_dir = self.get_temp_dir()
    # Best-effort cleanup: the directory may already be gone.
    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    return os.path.join(temp_dir, dirname)

  @keras_parameterized.run_with_all_weight_formats
  def test_weight_loading(self):
    """Checks set_weights validation and a save/load_weights round trip."""
    saved_model_dir = self._save_model_dir()
    save_format = testing_utils.get_save_format()
    with self.cached_session():
      a = keras.layers.Input(shape=(2,))
      x = keras.layers.Dense(3)(a)
      b = keras.layers.Dense(1)(x)
      model = keras.models.Model(a, b)

      x = np.random.random((3, 2))
      ref_y = model.predict(x)
      weights = model.get_weights()
      model.set_weights(weights)
      y = model.predict(x)
      self.assertAllClose(ref_y, y)

      # Too few weights, or weights whose shapes arrive in the wrong order,
      # must be rejected.
      with self.assertRaises(ValueError):
        model.set_weights(weights[1:])
      with self.assertRaises(ValueError):
        model.set_weights(weights[::-1])

      model.save_weights(saved_model_dir, save_format=save_format)
      model.load_weights(saved_model_dir)
      y = model.predict(x)
      self.assertAllClose(ref_y, y)

  def test_weight_preprocessing(self):
    """Smoke-tests preprocess_weights_for_loading with Keras 1 weights.

    Only checks that preprocessing does not raise; results are discarded.
    """
    input_dim = 3
    output_dim = 3
    size = 2
    # Each case is [layer, weights in the Keras 1 layout, build input shape].
    cases = [
        [
            (keras.layers.Bidirectional(keras.layers.SimpleRNN(2))),
            [np.random.random((2, 1)), np.random.random((2, 1))],
            (None, 3, 2),
        ],
        [
            (keras.layers.TimeDistributed(keras.layers.Dense(1))),
            [np.random.random((2, 1)), np.random.random((1,))],
            (None, 3, 2),
        ],
        [
            (keras.layers.Conv1D(output_dim, size, use_bias=False)),
            [np.random.random((output_dim, input_dim, size, 1))],
            (None, 4, input_dim),
        ],
        [
            (keras.layers.Conv2D(output_dim, size,
                                 use_bias=False, data_format='channels_first')),
            [np.random.random((output_dim, input_dim, size, size))],
            (None, input_dim, 4, 4),
        ],
        [
            (keras.layers.Conv2DTranspose(output_dim, size,
                                          use_bias=False,
                                          data_format='channels_first')),
            [np.random.random((output_dim, input_dim, size, size))],
            (None, input_dim, 4, 4),
        ],
        [
            (keras.layers.Conv2DTranspose(output_dim, size,
                                          use_bias=False,
                                          data_format='channels_last')),
            [np.random.random((size, size, input_dim, output_dim))],
            (None, 4, 4, input_dim),
        ],
        [
            (keras.layers.Conv3D(output_dim, size,
                                 use_bias=False, data_format='channels_first')),
            [np.random.random((output_dim, input_dim, size, size, size))],
            (None, input_dim, 4, 4, 4),
        ],
        [
            (keras.layers.GRUV1(output_dim)),
            [np.random.random((input_dim, output_dim)),
             np.random.random((output_dim, output_dim)),
             np.random.random((output_dim,)),
             np.random.random((input_dim, output_dim)),
             np.random.random((output_dim, output_dim)),
             np.random.random((output_dim,)),
             np.random.random((input_dim, output_dim)),
             np.random.random((output_dim, output_dim)),
             np.random.random((output_dim,))],
            (None, 4, input_dim),
        ],
        [
            (keras.layers.LSTMV1(output_dim)),
            [np.random.random((input_dim, output_dim)),
             np.random.random((output_dim, output_dim)),
             np.random.random((output_dim,)),
             np.random.random((input_dim, output_dim)),
             np.random.random((output_dim, output_dim)),
             np.random.random((output_dim,)),
             np.random.random((input_dim, output_dim)),
             np.random.random((output_dim, output_dim)),
             np.random.random((output_dim,)),
             np.random.random((input_dim, output_dim)),
             np.random.random((output_dim, output_dim)),
             np.random.random((output_dim,))],
            (None, 4, input_dim),
        ],
    ]
    for layer, weights, input_shape in cases:
      layer.build(input_shape)
      _ = hdf5_format.preprocess_weights_for_loading(
          layer, weights, original_keras_version='1')

    # Whole models (Sequential and functional) are accepted as well.
    model = keras.models.Sequential([keras.layers.Dense(2, input_dim=2)])
    _ = hdf5_format.preprocess_weights_for_loading(
        model, model.weights, original_keras_version='1')

    x = keras.Input((2,))
    y = keras.layers.Dense(2)(x)
    model = keras.models.Model(x, y)
    _ = hdf5_format.preprocess_weights_for_loading(
        model, model.weights, original_keras_version='1')

  @parameterized.named_parameters(
      ('gru', keras.layers.GRU, {
          'units': 2,
          'input_shape': (3, 5)
      }),
      ('gru_with_reset_after', keras.layers.GRU, {
          'units': 2,
          'input_shape': (3, 5),
          'reset_after': True
      }),
      ('lstm', keras.layers.LSTM, {
          'units': 2,
          'input_shape': (3, 5)
      }),
      ('cudnngru', keras.layers.CuDNNGRU, {
          'units': 2,
          'input_shape': (3, 5)
      }),
      ('cudnnlstm', keras.layers.CuDNNLSTM, {
          'units': 2,
          'input_shape': (3, 5)
      }))
  def test_preprocess_weights_for_loading_rnn_should_be_idempotent(
      self, layer_class, layer_args):
    """Preprocessing a layer's own current weights must leave them unchanged."""
    with self.cached_session():
      layer = layer_class(**layer_args)
      layer.build(input_shape=layer_args.get('input_shape'))
      weights1 = layer.get_weights()
      weights2 = hdf5_format.preprocess_weights_for_loading(
          layer, weights1)
      # zip() would silently ignore extra or missing entries, so pin the
      # length first.
      self.assertLen(weights2, len(weights1))
      for x, y in zip(weights1, weights2):
        self.assertAllClose(x, y, rtol=1e-05)

  def test_sequential_weight_loading(self):
    """Round-trips Sequential model weights through an .h5 file."""
    if h5py is None:
      self.skipTest('h5py is required for this test.')

    h5_path = self._save_model_dir('test.h5')

    num_hidden = 5
    input_dim = 3
    batch_size = 5
    num_classes = 2

    with self.cached_session():
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
      model.add(keras.layers.Dense(num_classes))

      x = np.random.random((batch_size, input_dim))
      ref_y = model.predict(x)

      model.save_weights(h5_path)

      model = keras.models.Sequential()
      model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
      model.add(keras.layers.Dense(num_classes))
      model.load_weights(h5_path)
      y = model.predict(x)

      self.assertAllClose(y, ref_y)

  @keras_parameterized.run_with_all_saved_model_formats(
      exclude_formats=['tf_no_traces'])
  def test_nested_model_weight_loading(self):
    """Round-trips weights of a model nesting Sequential models two deep."""
    save_format = testing_utils.get_save_format()
    saved_model_dir = self._save_model_dir()

    batch_size = 5
    shape = (None, None, 3)

    with self.cached_session():
      def gen_model():

        def seq_model():
          model = keras.models.Sequential([
              keras.layers.Conv2D(3, 1, input_shape=shape),
              keras.layers.BatchNormalization()])
          return model

        x = inner_inputs = keras.layers.Input((None, None, 3))
        x = seq_model()(x)
        x = seq_model()(x)
        inner_model = keras.models.Model(inner_inputs, x)

        inputs = keras.layers.Input(shape)
        return keras.models.Model(inputs, inner_model(inputs))

      model = gen_model()
      x = np.random.random((batch_size, 1, 1, 3))
      ref_y = model.predict(x)

      model.save_weights(saved_model_dir, save_format=save_format)

      model = gen_model()
      model.load_weights(saved_model_dir)
      y = model.predict(x)

      self.assertAllClose(y, ref_y)

  def test_sequential_weight_loading_group_name_with_incorrect_length(self):
    """Loading by name reports, or with skip_mismatch skips, a count mismatch.

    Layer 'd1' is saved with kernel and bias but reloaded with use_bias=False.
    """
    if h5py is None:
      self.skipTest('h5py is required for this test.')

    h5_path = self._save_model_dir('test.h5')

    num_hidden = 5
    input_dim = 3
    num_classes = 2
    with self.cached_session():
      ref_model = keras.models.Sequential()
      ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
                                       name='d1'))
      ref_model.add(keras.layers.Dense(num_classes, name='d2'))
      ref_model.compile(loss=keras.losses.MSE,
                        optimizer='rmsprop',
                        metrics=[keras.metrics.categorical_accuracy])

      # Close the writer before reading so the data is flushed and no
      # handle is left open when the temp dir is removed.
      with h5py.File(h5_path, 'w') as f_ref_model:
        hdf5_format.save_weights_to_hdf5_group(f_ref_model, ref_model.layers)

      model = keras.models.Sequential()
      model.add(keras.layers.Dense(num_hidden, use_bias=False,
                                   input_dim=input_dim, name='d1'))
      model.add(keras.layers.Dense(num_classes, name='d2'))
      model.compile(loss=keras.losses.MSE,
                    optimizer='rmsprop',
                    metrics=[keras.metrics.categorical_accuracy])
      with h5py.File(h5_path, 'r') as f_model:
        with self.assertRaisesRegex(
            ValueError, r'Layer #0 \(named \"d1\"\) expects 1 '
            r'weight\(s\), but the saved weights have 2 '
            r'element\(s\)\.'):
          hdf5_format.load_weights_from_hdf5_group_by_name(
              f_model, model.layers)

        hdf5_format.load_weights_from_hdf5_group_by_name(
            f_model, model.layers, skip_mismatch=True)
      # The matching layer 'd2' must still have been loaded.
      self.assertAllClose(keras.backend.get_value(ref_model.layers[1].kernel),
                          keras.backend.get_value(model.layers[1].kernel))

  def test_sequential_weight_loading_group_name_with_incorrect_shape(self):
    """Loading by name reports, or with skip_mismatch skips, a shape mismatch.

    Layer 'd1' is saved with 5 units but reloaded with 10.
    """
    if h5py is None:
      self.skipTest('h5py is required for this test.')

    h5_path = self._save_model_dir('test.h5')

    num_hidden = 5
    input_dim = 3
    num_classes = 2
    with ops.Graph().as_default(), self.cached_session():
      ref_model = keras.models.Sequential()
      ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
                                       name='d1'))
      ref_model.add(keras.layers.Dense(num_classes, name='d2'))
      ref_model.compile(loss=keras.losses.MSE,
                        optimizer=optimizer_v1.RMSprop(lr=0.0001),
                        metrics=[keras.metrics.categorical_accuracy])

      # A recognizable bias value lets us verify that 'd2' was loaded.
      keras.backend.set_value(ref_model.layers[1].bias, [3.5] * num_classes)
      # Close the writer before reading so the data is flushed and no
      # handle is left open when the temp dir is removed.
      with h5py.File(h5_path, 'w') as f_ref_model:
        hdf5_format.save_weights_to_hdf5_group(f_ref_model, ref_model.layers)

      model = keras.models.Sequential()
      model.add(keras.layers.Dense(num_hidden + 5, input_dim=input_dim,
                                   name='d1'))
      model.add(keras.layers.Dense(num_classes, name='d2'))
      model.compile(loss=keras.losses.MSE,
                    optimizer=optimizer_v1.RMSprop(lr=0.0001),
                    metrics=[keras.metrics.categorical_accuracy])
      with h5py.File(h5_path, 'r') as f_model:
        with self.assertRaisesRegex(
            ValueError, r'Layer #0 \(named "d1"\), weight '
            r'<tf\.Variable \'d1_1\/kernel:0\' '
            r'shape=\(3, 10\) dtype=float32> has '
            r'shape \(3, 10\), but the saved weight has '
            r'shape \(3, 5\)\.'):
          hdf5_format.load_weights_from_hdf5_group_by_name(
              f_model, model.layers)

        hdf5_format.load_weights_from_hdf5_group_by_name(
            f_model, model.layers, skip_mismatch=True)
      self.assertAllClose([3.5] * num_classes,
                          keras.backend.get_value(model.layers[1].bias))

  @keras_parameterized.run_with_all_saved_model_formats(
      exclude_formats=['tf_no_traces'])
  @keras_parameterized.run_with_all_model_types
  def test_load_weights_from_saved_model(self):
    """load_weights accepts the output of model.save() for each format."""
    save_path = self._save_model_dir()
    save_format = testing_utils.get_save_format()

    if save_format == 'h5' and testing_utils.get_model_type() == 'subclass':
      # TODO(b/173646281): HDF5 format currently does not allow saving
      # subclassed models.
      self.skipTest('HDF5 format does not support saving subclassed models.')

    with self.cached_session():
      model = testing_utils.get_small_mlp(1, 4, input_dim=3)
      data = np.random.random((1, 3))
      labels = np.random.random((1, 4))
      model.compile(loss='mse', optimizer='rmsprop')
      model.fit(data, labels)
      model.save(save_path, save_format=save_format)
      new_model = testing_utils.get_small_mlp(1, 4, input_dim=3)
      if testing_utils.get_model_type() == 'subclass':
        # Call on test data to build the model.
        new_model.predict(data)
      new_model.load_weights(save_path)
      self.assertAllClose(model.weights, new_model.weights)
389
390
class SubclassedModel(training.Model):
  """Minimal subclassed model: Dense(3) followed by Dense(1).

  NOTE(review): SubclassedModelRestore in the TF-format tests reuses the
  attribute names `x_layer` and `b_layer`, presumably so that object-based
  checkpoints line up between the two models. Keep the names in sync.
  """

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self.x_layer = keras.layers.Dense(3)
    self.b_layer = keras.layers.Dense(1)

  def call(self, a):
    hidden = self.x_layer(a)
    return self.b_layer(hidden)
400
401
class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase):
  """Weight save/load tests for the TensorFlow checkpoint format."""

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_tensorflow_format_overwrite(self):
    """overwrite=True replaces a checkpoint; overwrite=False prompts."""
    with self.cached_session() as session:
      model = SubclassedModel()
      temp_dir = self.get_temp_dir()
      prefix = os.path.join(temp_dir, 'ckpt')

      x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
      executing_eagerly = context.executing_eagerly()
      model(x)  # pylint: disable=not-callable
      if not executing_eagerly:
        session.run([v.initializer for v in model.variables])
      model.save_weights(prefix, save_format='tensorflow')
      model.save_weights(prefix, save_format='tensorflow', overwrite=True)
      with self.assertRaises(EOFError):
        # Indirectly tests that the user is prompted
        model.save_weights(prefix, save_format='tensorflow', overwrite=False)

  def test_no_default_session(self):
    """Saves and loads weights in a graph with no default session set."""
    with ops.Graph().as_default():
      self.assertIsNone(ops.get_default_session())
      data = np.random.random((1000, 32)).astype(np.float32)
      labels = np.random.random((1000, 10)).astype(np.float32)

      model = keras.models.Sequential([
          keras.layers.Dense(10, activation='softmax'),
          keras.layers.Dense(10, activation='softmax')])

      model.compile(optimizer=training_module.RMSPropOptimizer(0.001),
                    loss='categorical_crossentropy',
                    metrics=['accuracy'])

      model.fit(data, labels)
      fname = os.path.join(self.get_temp_dir(), 'weights', 'ckpt')
      model.save_weights(fname)
      model.load_weights(fname)

  def test_no_graph_pollution(self):
    """Repeated save_weights/load_weights must not add ops to the graph."""
    with ops.get_default_graph().as_default():
      graph = ops.Graph()
      with graph.as_default(), self.session(graph) as session:
        model = SubclassedModel()
        temp_dir = self.get_temp_dir()
        prefix = os.path.join(temp_dir, 'ckpt')

        x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
        model(x)  # pylint: disable=not-callable
        session.run([v.initializer for v in model.variables])
        # The first call may build save ops; later calls must reuse them.
        model.save_weights(prefix, save_format='tensorflow')
        op_count = len(graph.get_operations())
        model.save_weights(prefix, save_format='tensorflow')
        self.assertLen(graph.get_operations(), op_count)

        model.load_weights(prefix)
        op_count = len(graph.get_operations())
        model.load_weights(prefix)
        self.assertLen(graph.get_operations(), op_count)

  def _weight_loading_test_template(self, make_model_fn):
    """Checks a TF-format weight checkpoint restores model and optimizer state.

    Args:
      make_model_fn: Zero-argument callable returning a fresh, uncompiled
        model that accepts inputs with 2 features.
    """
    with self.cached_session():
      model = make_model_fn()
      model.compile(
          loss='mse',
          optimizer=training_module.RMSPropOptimizer(0.1),
          metrics=['acc', keras.metrics.CategoricalAccuracy()])
      temp_dir = self.get_temp_dir()
      # Register cleanup up front so the directory is removed even if an
      # assertion below fails.
      self.addCleanup(shutil.rmtree, temp_dir)
      prefix = os.path.join(temp_dir, 'ckpt')
      train_x = np.random.random((3, 2))
      train_y = np.random.random((3,))
      x = constant_op.constant(train_x, dtype=dtypes.float32)

      model.train_on_batch(train_x, train_y)
      model.save_weights(prefix, save_format='tf')
      ref_y_before_train = model.predict(train_x)
      model.train_on_batch(train_x, train_y)
      ref_y_after_train = model.predict(train_x)
      # Scramble every variable so a successful restore is observable.
      for v in model.variables:
        self.evaluate(
            v.assign(random_ops.random_normal(shape=array_ops.shape(v))))

      model.load_weights(prefix)
      self.assertAllClose(ref_y_before_train, self.evaluate(model(x)))

      # Test restore-on-create if this is a subclassed Model (graph Networks
      # will have already created their variables).
      load_model = make_model_fn()
      load_model.load_weights(prefix)
      self.assertAllClose(
          ref_y_before_train,
          self.evaluate(load_model(x)))
      load_model = make_model_fn()
      load_model.load_weights(prefix)
      # We need to run some of the restore ops for predict(), but not all
      # variables have been created yet (optimizer slot variables). Tests
      # incremental restore.
      load_model.predict(train_x)
      load_model.compile(
          loss='mse',
          optimizer=training_module.RMSPropOptimizer(0.1),
          metrics=['acc', keras.metrics.CategoricalAccuracy()])
      load_model.train_on_batch(train_x, train_y)
      self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x)))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_weight_loading_graph_model(self):
    """Runs the weight-loading template on a functional model."""
    def _make_graph_model():
      a = keras.layers.Input(shape=(2,))
      x = keras.layers.Dense(3)(a)
      b = keras.layers.Dense(1)(x)
      return keras.models.Model(a, b)

    self._weight_loading_test_template(_make_graph_model)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_weight_loading_subclassed_model(self):
    """Runs the weight-loading template on a subclassed model."""
    self._weight_loading_test_template(SubclassedModel)

  def _new_layer_weight_loading_test_template(
      self, first_model_fn, second_model_fn):
    """Checks checkpoints can be exchanged between two related models.

    A checkpoint saved from `first_model_fn()` is loaded into
    `second_model_fn()`, re-saved from there, and loaded back into the
    first model, whose outputs must match those before saving.

    Args:
      first_model_fn: Zero-argument callable building the original model.
      second_model_fn: Zero-argument callable building a model sharing the
        original's weighted layers (it may add more).
    """
    with self.cached_session() as session:
      model = first_model_fn()
      temp_dir = self.get_temp_dir()
      # Register cleanup up front so the directory is removed even if an
      # assertion below fails.
      self.addCleanup(shutil.rmtree, temp_dir)
      prefix = os.path.join(temp_dir, 'ckpt')

      x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
      executing_eagerly = context.executing_eagerly()
      ref_y_tensor = model(x)
      if not executing_eagerly:
        session.run([v.initializer for v in model.variables])
      ref_y = self.evaluate(ref_y_tensor)
      model.save_weights(prefix)
      self.assertEqual(
          prefix,
          checkpoint_management.latest_checkpoint(temp_dir))
      # Scramble every variable so a successful restore is observable.
      for v in model.variables:
        self.evaluate(
            v.assign(random_ops.random_normal(shape=array_ops.shape(v))))

      second_model = second_model_fn()
      status = second_model.load_weights(prefix)
      second_model(x)
      status.run_restore_ops()
      second_model.save_weights(prefix)
      # Check that the second model's checkpoint loads into the original model
      status = model.load_weights(prefix)
      status.run_restore_ops(session)
      y = self.evaluate(model(x))
      self.assertAllClose(ref_y, y)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_weight_loading_graph_model_added_layer(self):
    """A functional model with an extra Dense layer exchanges checkpoints."""
    def _save_graph_model():
      a = keras.layers.Input(shape=(2,))
      x = keras.layers.Dense(3, name='first')(a)
      b = keras.layers.Dense(1, name='second')(x)
      return keras.models.Model(a, b)
    def _restore_graph_model():
      a = keras.layers.Input(shape=(2,))
      x = keras.layers.Dense(3, name='first')(a)
      y = keras.layers.Dense(1, name='second')(x)
      b = keras.layers.Dense(3, name='secondjr')(y)
      return keras.models.Model(a, b)

    self._new_layer_weight_loading_test_template(
        _save_graph_model, _restore_graph_model)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_weight_loading_graph_model_added_no_weight_layer(self):
    """A functional model with an extra Dropout layer exchanges checkpoints."""
    def _save_graph_model():
      a = keras.layers.Input(shape=(2,))
      x = keras.layers.Dense(3, name='first')(a)
      b = keras.layers.Dense(1, name='second')(x)
      return keras.models.Model(a, b)
    def _restore_graph_model():
      a = keras.layers.Input(shape=(2,))
      x = keras.layers.Dense(3, name='first')(a)
      b = keras.layers.Dense(1, name='second')(x)
      y = keras.layers.Dropout(rate=0.1)(b)
      return keras.models.Model(a, y)

    self._new_layer_weight_loading_test_template(
        _save_graph_model, _restore_graph_model)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_weight_loading_subclassed_model_added_layer(self):
    """A subclassed model with an extra Dense layer exchanges checkpoints."""

    class SubclassedModelRestore(training.Model):

      def __init__(self):
        super(SubclassedModelRestore, self).__init__()
        self.x_layer = keras.layers.Dense(3)
        self.y_layer = keras.layers.Dense(3)
        self.b_layer = keras.layers.Dense(1)

      def call(self, a):
        return self.b_layer(self.y_layer(self.x_layer(a)))

    self._new_layer_weight_loading_test_template(
        SubclassedModel, SubclassedModelRestore)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_incompatible_checkpoint(self):
    """load_weights rejects checkpoints matching nothing, or only the root."""
    save_path = trackable.Checkpoint().save(
        os.path.join(self.get_temp_dir(), 'ckpt'))
    m = DummySubclassModel()
    with self.assertRaisesRegex(AssertionError, 'Nothing to load'):
      m.load_weights(save_path)
    m.dense = keras.layers.Dense(2)
    m.dense(constant_op.constant([[1.]]))
    with self.assertRaisesRegex(AssertionError,
                                'Nothing except the root object matched'):
      m.load_weights(save_path)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_directory_passed(self):
    """A prefix ending in '/' works for both saving and loading."""
    with self.cached_session():
      m = DummySubclassModel()
      v = m.add_weight(name='v', shape=[])
      self.evaluate(v.assign(42.))
      prefix = os.path.join(self.get_temp_dir(), str(uuid.uuid4()), 'ckpt/')
      m.save_weights(prefix)
      self.evaluate(v.assign(2.))
      m.load_weights(prefix)
      self.assertEqual(42., self.evaluate(v))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_relative_path(self):
    """Relative prefixes are resolved against the working directory."""
    with self.cached_session():
      m = DummySubclassModel()
      v = m.add_weight(name='v', shape=[])
      # Restore the original working directory afterwards; otherwise every
      # later test in the process would run from this test's temp dir.
      self.addCleanup(os.chdir, os.getcwd())
      os.chdir(self.get_temp_dir())

      prefix = 'ackpt'
      self.evaluate(v.assign(42.))
      m.save_weights(prefix)
      self.assertTrue(file_io.file_exists_v2('ackpt.index'))
      self.evaluate(v.assign(1.))
      m.load_weights(prefix)
      self.assertEqual(42., self.evaluate(v))

      prefix = 'subdir/ackpt'
      self.evaluate(v.assign(43.))
      m.save_weights(prefix)
      self.assertTrue(file_io.file_exists_v2('subdir/ackpt.index'))
      self.evaluate(v.assign(2.))
      m.load_weights(prefix)
      self.assertEqual(43., self.evaluate(v))

      prefix = 'ackpt/'
      self.evaluate(v.assign(44.))
      m.save_weights(prefix)
      self.assertTrue(file_io.file_exists_v2('ackpt/.index'))
      self.evaluate(v.assign(3.))
      m.load_weights(prefix)
      self.assertEqual(44., self.evaluate(v))

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_nonexistent_prefix_directory(self):
    """Saving to a prefix whose parent directory does not yet exist works."""
    with self.cached_session():
      m = DummySubclassModel()
      v = m.add_weight(name='v', shape=[])
      self.evaluate(v.assign(42.))
      prefix = os.path.join(self.get_temp_dir(), str(uuid.uuid4()), 'bckpt')
      m.save_weights(prefix)
      self.evaluate(v.assign(2.))
      m.load_weights(prefix)
      self.assertEqual(42., self.evaluate(v))
675
676
class DummySubclassModel(training.Model):
  """Subclassed model with no layers of its own.

  Tests attach weights or layers to instances after construction.
  """
679
680
if __name__ == '__main__':
  # Run every test case in this module through TensorFlow's test runner.
  test.main()
683