1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14#,============================================================================ 15"""Tests for model saving in the HDF5 format.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import shutil 23import uuid 24 25from absl.testing import parameterized 26import numpy as np 27 28from tensorflow.python import keras 29from tensorflow.python.eager import context 30from tensorflow.python.framework import constant_op 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import ops 33from tensorflow.python.keras import combinations 34from tensorflow.python.keras import keras_parameterized 35from tensorflow.python.keras import optimizer_v1 36from tensorflow.python.keras import testing_utils 37from tensorflow.python.keras.engine import training 38from tensorflow.python.keras.saving import hdf5_format 39from tensorflow.python.lib.io import file_io 40from tensorflow.python.ops import array_ops 41from tensorflow.python.ops import random_ops 42from tensorflow.python.platform import test 43from tensorflow.python.training import checkpoint_management 44from tensorflow.python.training import training as training_module 45from tensorflow.python.training.tracking import util as trackable 46 47try: 48 import h5py # pylint:disable=g-import-not-at-top 49except ImportError: 50 h5py = None 51 52 53@combinations.generate(combinations.combine(mode=['graph', 'eager'])) 54class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): 55 56 def _save_model_dir(self, dirname='saved_model'): 57 temp_dir = self.get_temp_dir() 58 self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True) 59 return os.path.join(temp_dir, dirname) 60 61 @keras_parameterized.run_with_all_weight_formats 62 def test_weight_loading(self): 63 saved_model_dir = self._save_model_dir() 64 save_format = testing_utils.get_save_format() 65 with self.cached_session(): 66 a = keras.layers.Input(shape=(2,)) 67 x = keras.layers.Dense(3)(a) 68 b = keras.layers.Dense(1)(x) 69 model = keras.models.Model(a, b) 70 71 x = np.random.random((3, 2)) 72 ref_y = model.predict(x) 73 weights = model.get_weights() 74 model.set_weights(weights) 75 y = model.predict(x) 76 self.assertAllClose(ref_y, y) 77 78 with self.assertRaises(ValueError): 79 model.set_weights(weights[1:]) 80 with self.assertRaises(ValueError): 81 model.set_weights(weights[::-1]) 82 83 model.save_weights(saved_model_dir, save_format=save_format) 84 model.load_weights(saved_model_dir) 85 y = model.predict(x) 86 self.assertAllClose(ref_y, y) 87 88 def test_weight_preprocessing(self): 89 input_dim = 3 90 output_dim = 3 91 size = 2 92 cases = [ 93 [ 94 (keras.layers.Bidirectional(keras.layers.SimpleRNN(2))), 95 [np.random.random((2, 1)), np.random.random((2, 1))], 96 (None, 3, 2), 97 ], 98 [ 99 (keras.layers.TimeDistributed(keras.layers.Dense(1))), 100 [np.random.random((2, 1)), np.random.random((1,))], 101 (None, 3, 2), 102 ], 103 [ 104 (keras.layers.Conv1D(output_dim, size, use_bias=False)), 105 [np.random.random((output_dim, input_dim, size, 1))], 106 (None, 4, input_dim), 107 ], 108 [ 109 (keras.layers.Conv2D(output_dim, size, 110 use_bias=False, data_format='channels_first')), 111 [np.random.random((output_dim, input_dim, size, size))], 112 (None, input_dim, 4, 4), 113 ], 114 [ 115 (keras.layers.Conv2DTranspose(output_dim, size, 116 use_bias=False, 117 data_format='channels_first')), 118 [np.random.random((output_dim, input_dim, size, size))], 119 (None, input_dim, 4, 4), 120 ], 121 [ 122 (keras.layers.Conv2DTranspose(output_dim, size, 123 use_bias=False, 124 data_format='channels_last')), 125 [np.random.random((size, size, input_dim, output_dim))], 126 (None, 4, 4, input_dim), 127 ], 128 [ 129 (keras.layers.Conv3D(output_dim, size, 130 use_bias=False, data_format='channels_first')), 131 [np.random.random((output_dim, input_dim, size, size, size))], 132 (None, input_dim, 4, 4, 4), 133 ], 134 [ 135 (keras.layers.GRUV1(output_dim)), 136 [np.random.random((input_dim, output_dim)), 137 np.random.random((output_dim, output_dim)), 138 np.random.random((output_dim,)), 139 np.random.random((input_dim, output_dim)), 140 np.random.random((output_dim, output_dim)), 141 np.random.random((output_dim,)), 142 np.random.random((input_dim, output_dim)), 143 np.random.random((output_dim, output_dim)), 144 np.random.random((output_dim,))], 145 (None, 4, input_dim), 146 ], 147 [ 148 (keras.layers.LSTMV1(output_dim)), 149 [np.random.random((input_dim, output_dim)), 150 np.random.random((output_dim, output_dim)), 151 np.random.random((output_dim,)), 152 np.random.random((input_dim, output_dim)), 153 np.random.random((output_dim, output_dim)), 154 np.random.random((output_dim,)), 155 np.random.random((input_dim, output_dim)), 156 np.random.random((output_dim, output_dim)), 157 np.random.random((output_dim,)), 158 np.random.random((input_dim, output_dim)), 159 np.random.random((output_dim, output_dim)), 160 np.random.random((output_dim,))], 161 (None, 4, input_dim), 162 ], 163 ] 164 for layer, weights, input_shape in cases: 165 layer.build(input_shape) 166 _ = hdf5_format.preprocess_weights_for_loading( 167 layer, weights, original_keras_version='1') 168 169 model = keras.models.Sequential([keras.layers.Dense(2, input_dim=2)]) 170 _ = hdf5_format.preprocess_weights_for_loading( 171 model, model.weights, original_keras_version='1') 172 173 x = keras.Input((2,)) 174 y = keras.layers.Dense(2)(x) 175 model = keras.models.Model(x, y) 176 _ = hdf5_format.preprocess_weights_for_loading( 177 model, model.weights, original_keras_version='1') 178 179 @parameterized.named_parameters( 180 ('gru', keras.layers.GRU, { 181 'units': 2, 182 'input_shape': (3, 5) 183 }), 184 ('gru_with_reset_after', keras.layers.GRU, { 185 'units': 2, 186 'input_shape': (3, 5), 187 'reset_after': True 188 }), 189 ('lstm', keras.layers.LSTM, { 190 'units': 2, 191 'input_shape': (3, 5) 192 }), 193 ('cudnngru', keras.layers.CuDNNGRU, { 194 'units': 2, 195 'input_shape': (3, 5) 196 }), 197 ('cudnnlstm', keras.layers.CuDNNLSTM, { 198 'units': 2, 199 'input_shape': (3, 5) 200 })) 201 def test_preprocess_weights_for_loading_rnn_should_be_idempotent( 202 self, layer_class, layer_args): 203 with self.cached_session(): 204 layer = layer_class(**layer_args) 205 layer.build(input_shape=layer_args.get('input_shape')) 206 weights1 = layer.get_weights() 207 weights2 = hdf5_format.preprocess_weights_for_loading( 208 layer, weights1) 209 _ = [ 210 self.assertAllClose(x, y, rtol=1e-05) 211 for (x, y) in zip(weights1, weights2) 212 ] 213 214 def test_sequential_weight_loading(self): 215 if h5py is None: 216 return 217 218 h5_path = self._save_model_dir('test.h5') 219 220 num_hidden = 5 221 input_dim = 3 222 batch_size = 5 223 num_classes = 2 224 225 with self.cached_session(): 226 model = keras.models.Sequential() 227 model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) 228 model.add(keras.layers.Dense(num_classes)) 229 230 x = np.random.random((batch_size, input_dim)) 231 ref_y = model.predict(x) 232 233 model.save_weights(h5_path) 234 235 model = keras.models.Sequential() 236 model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) 237 model.add(keras.layers.Dense(num_classes)) 238 model.load_weights(h5_path) 239 y = model.predict(x) 240 241 self.assertAllClose(y, ref_y) 242 243 @keras_parameterized.run_with_all_saved_model_formats( 244 exclude_formats=['tf_no_traces']) 245 def test_nested_model_weight_loading(self): 246 save_format = testing_utils.get_save_format() 247 saved_model_dir = self._save_model_dir() 248 249 batch_size = 5 250 shape = (None, None, 3) 251 252 with self.cached_session(): 253 def gen_model(): 254 255 def seq_model(): 256 model = keras.models.Sequential([ 257 keras.layers.Conv2D(3, 1, input_shape=shape), 258 keras.layers.BatchNormalization()]) 259 return model 260 261 x = inner_inputs = keras.layers.Input((None, None, 3)) 262 x = seq_model()(x) 263 x = seq_model()(x) 264 inner_model = keras.models.Model(inner_inputs, x) 265 266 inputs = keras.layers.Input(shape) 267 return keras.models.Model(inputs, inner_model(inputs)) 268 269 model = gen_model() 270 x = np.random.random((batch_size, 1, 1, 3)) 271 ref_y = model.predict(x) 272 273 model.save_weights(saved_model_dir, save_format=save_format) 274 275 model = gen_model() 276 model.load_weights(saved_model_dir) 277 y = model.predict(x) 278 279 self.assertAllClose(y, ref_y) 280 281 def test_sequential_weight_loading_group_name_with_incorrect_length(self): 282 if h5py is None: 283 return 284 285 h5_path = self._save_model_dir('test.h5') 286 287 num_hidden = 5 288 input_dim = 3 289 num_classes = 2 290 with self.cached_session(): 291 ref_model = keras.models.Sequential() 292 ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, 293 name='d1')) 294 ref_model.add(keras.layers.Dense(num_classes, name='d2')) 295 ref_model.compile(loss=keras.losses.MSE, 296 optimizer='rmsprop', 297 metrics=[keras.metrics.categorical_accuracy]) 298 299 f_ref_model = h5py.File(h5_path, 'w') 300 hdf5_format.save_weights_to_hdf5_group(f_ref_model, ref_model.layers) 301 302 f_model = h5py.File(h5_path, 'r') 303 model = keras.models.Sequential() 304 model.add(keras.layers.Dense(num_hidden, use_bias=False, 305 input_dim=input_dim, name='d1')) 306 model.add(keras.layers.Dense(num_classes, name='d2')) 307 model.compile(loss=keras.losses.MSE, 308 optimizer='rmsprop', 309 metrics=[keras.metrics.categorical_accuracy]) 310 with self.assertRaisesRegex( 311 ValueError, r'Layer #0 \(named \"d1\"\) expects 1 ' 312 r'weight\(s\), but the saved weights have 2 ' 313 r'element\(s\)\.'): 314 hdf5_format.load_weights_from_hdf5_group_by_name(f_model, model.layers) 315 316 hdf5_format.load_weights_from_hdf5_group_by_name( 317 f_model, model.layers, skip_mismatch=True) 318 self.assertAllClose(keras.backend.get_value(ref_model.layers[1].kernel), 319 keras.backend.get_value(model.layers[1].kernel)) 320 321 def test_sequential_weight_loading_group_name_with_incorrect_shape(self): 322 if h5py is None: 323 return 324 325 h5_path = self._save_model_dir('test.h5') 326 327 num_hidden = 5 328 input_dim = 3 329 num_classes = 2 330 with ops.Graph().as_default(), self.cached_session(): 331 ref_model = keras.models.Sequential() 332 ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim, 333 name='d1')) 334 ref_model.add(keras.layers.Dense(num_classes, name='d2')) 335 ref_model.compile(loss=keras.losses.MSE, 336 optimizer=optimizer_v1.RMSprop(lr=0.0001), 337 metrics=[keras.metrics.categorical_accuracy]) 338 339 f_ref_model = h5py.File(h5_path, 'w') 340 keras.backend.set_value(ref_model.layers[1].bias, [3.5] * num_classes) 341 hdf5_format.save_weights_to_hdf5_group(f_ref_model, ref_model.layers) 342 343 f_model = h5py.File(h5_path, 'r') 344 model = keras.models.Sequential() 345 model.add(keras.layers.Dense(num_hidden + 5, input_dim=input_dim, 346 name='d1')) 347 model.add(keras.layers.Dense(num_classes, name='d2')) 348 model.compile(loss=keras.losses.MSE, 349 optimizer=optimizer_v1.RMSprop(lr=0.0001), 350 metrics=[keras.metrics.categorical_accuracy]) 351 with self.assertRaisesRegex( 352 ValueError, r'Layer #0 \(named "d1"\), weight ' 353 r'<tf\.Variable \'d1_1\/kernel:0\' ' 354 r'shape=\(3, 10\) dtype=float32> has ' 355 r'shape \(3, 10\), but the saved weight has ' 356 r'shape \(3, 5\)\.'): 357 hdf5_format.load_weights_from_hdf5_group_by_name(f_model, model.layers) 358 359 hdf5_format.load_weights_from_hdf5_group_by_name( 360 f_model, model.layers, skip_mismatch=True) 361 self.assertAllClose([3.5] * num_classes, 362 keras.backend.get_value(model.layers[1].bias)) 363 364 @keras_parameterized.run_with_all_saved_model_formats( 365 exclude_formats=['tf_no_traces']) 366 @keras_parameterized.run_with_all_model_types 367 def test_load_weights_from_saved_model(self): 368 save_path = self._save_model_dir() 369 save_format = testing_utils.get_save_format() 370 371 if save_format == 'h5' and testing_utils.get_model_type() == 'subclass': 372 # TODO(b/173646281): HDF5 format currently does not allow saving 373 # subclassed models. 374 return 375 376 with self.cached_session(): 377 model = testing_utils.get_small_mlp(1, 4, input_dim=3) 378 data = np.random.random((1, 3)) 379 labels = np.random.random((1, 4)) 380 model.compile(loss='mse', optimizer='rmsprop') 381 model.fit(data, labels) 382 model.save(save_path, save_format=save_format) 383 new_model = testing_utils.get_small_mlp(1, 4, input_dim=3) 384 if testing_utils.get_model_type() == 'subclass': 385 # Call on test data to build the model. 386 new_model.predict(data) 387 new_model.load_weights(save_path) 388 self.assertAllClose(model.weights, new_model.weights) 389 390 391class SubclassedModel(training.Model): 392 393 def __init__(self): 394 super(SubclassedModel, self).__init__() 395 self.x_layer = keras.layers.Dense(3) 396 self.b_layer = keras.layers.Dense(1) 397 398 def call(self, a): 399 return self.b_layer(self.x_layer(a)) 400 401 402class TestWeightSavingAndLoadingTFFormat(test.TestCase, parameterized.TestCase): 403 404 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) 405 def test_tensorflow_format_overwrite(self): 406 with self.cached_session() as session: 407 model = SubclassedModel() 408 temp_dir = self.get_temp_dir() 409 prefix = os.path.join(temp_dir, 'ckpt') 410 411 x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32) 412 executing_eagerly = context.executing_eagerly() 413 model(x) # pylint: disable=not-callable 414 if not executing_eagerly: 415 session.run([v.initializer for v in model.variables]) 416 model.save_weights(prefix, save_format='tensorflow') 417 model.save_weights(prefix, save_format='tensorflow', overwrite=True) 418 with self.assertRaises(EOFError): 419 # Indirectly tests that the user is prompted 420 model.save_weights(prefix, save_format='tensorflow', overwrite=False) 421 422 def test_no_default_session(self): 423 with ops.Graph().as_default(): 424 self.assertFalse(ops.get_default_session()) 425 data = np.random.random((1000, 32)).astype(np.float32) 426 labels = np.random.random((1000, 10)).astype(np.float32) 427 428 model = keras.models.Sequential([ 429 keras.layers.Dense(10, activation='softmax'), 430 keras.layers.Dense(10, activation='softmax')]) 431 432 model.compile(optimizer=training_module.RMSPropOptimizer(0.001), 433 loss='categorical_crossentropy', 434 metrics=['accuracy']) 435 436 model.fit(data, labels) 437 fname = os.path.join(self.get_temp_dir(), 'weights', 'ckpt') 438 model.save_weights(fname) 439 model.load_weights(fname) 440 441 def test_no_graph_pollution(self): 442 with ops.get_default_graph().as_default(): 443 graph = ops.Graph() 444 with graph.as_default(), self.session(graph) as session: 445 model = SubclassedModel() 446 temp_dir = self.get_temp_dir() 447 prefix = os.path.join(temp_dir, 'ckpt') 448 449 x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32) 450 model(x) # pylint: disable=not-callable 451 session.run([v.initializer for v in model.variables]) 452 model.save_weights(prefix, save_format='tensorflow') 453 op_count = len(graph.get_operations()) 454 model.save_weights(prefix, save_format='tensorflow') 455 self.assertLen(graph.get_operations(), op_count) 456 457 model.load_weights(prefix) 458 op_count = len(graph.get_operations()) 459 model.load_weights(prefix) 460 self.assertLen(graph.get_operations(), op_count) 461 462 def _weight_loading_test_template(self, make_model_fn): 463 with self.cached_session(): 464 model = make_model_fn() 465 model.compile( 466 loss='mse', 467 optimizer=training_module.RMSPropOptimizer(0.1), 468 metrics=['acc', keras.metrics.CategoricalAccuracy()]) 469 temp_dir = self.get_temp_dir() 470 prefix = os.path.join(temp_dir, 'ckpt') 471 train_x = np.random.random((3, 2)) 472 train_y = np.random.random((3,)) 473 x = constant_op.constant(train_x, dtype=dtypes.float32) 474 475 model.train_on_batch(train_x, train_y) 476 model.save_weights(prefix, save_format='tf') 477 ref_y_before_train = model.predict(train_x) 478 model.train_on_batch(train_x, train_y) 479 ref_y_after_train = model.predict(train_x) 480 for v in model.variables: 481 self.evaluate( 482 v.assign(random_ops.random_normal(shape=array_ops.shape(v)))) 483 484 self.addCleanup(shutil.rmtree, temp_dir) 485 486 model.load_weights(prefix) 487 self.assertAllClose(ref_y_before_train, self.evaluate(model(x))) 488 489 # Test restore-on-create if this is a subclassed Model (graph Networks 490 # will have already created their variables). 491 load_model = make_model_fn() 492 load_model.load_weights(prefix) 493 self.assertAllClose( 494 ref_y_before_train, 495 self.evaluate(load_model(x))) 496 load_model = make_model_fn() 497 load_model.load_weights(prefix) 498 # We need to run some of the restore ops for predict(), but not all 499 # variables have been created yet (optimizer slot variables). Tests 500 # incremental restore. 501 load_model.predict(train_x) 502 load_model.compile( 503 loss='mse', 504 optimizer=training_module.RMSPropOptimizer(0.1), 505 metrics=['acc', keras.metrics.CategoricalAccuracy()]) 506 load_model.train_on_batch(train_x, train_y) 507 self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x))) 508 509 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) 510 def test_weight_loading_graph_model(self): 511 def _make_graph_model(): 512 a = keras.layers.Input(shape=(2,)) 513 x = keras.layers.Dense(3)(a) 514 b = keras.layers.Dense(1)(x) 515 return keras.models.Model(a, b) 516 517 self._weight_loading_test_template(_make_graph_model) 518 519 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) 520 def test_weight_loading_subclassed_model(self): 521 self._weight_loading_test_template(SubclassedModel) 522 523 def _new_layer_weight_loading_test_template( 524 self, first_model_fn, second_model_fn): 525 with self.cached_session() as session: 526 model = first_model_fn() 527 temp_dir = self.get_temp_dir() 528 prefix = os.path.join(temp_dir, 'ckpt') 529 530 x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32) 531 executing_eagerly = context.executing_eagerly() 532 ref_y_tensor = model(x) 533 if not executing_eagerly: 534 session.run([v.initializer for v in model.variables]) 535 ref_y = self.evaluate(ref_y_tensor) 536 model.save_weights(prefix) 537 self.assertEqual( 538 prefix, 539 checkpoint_management.latest_checkpoint(temp_dir)) 540 for v in model.variables: 541 self.evaluate( 542 v.assign(random_ops.random_normal(shape=array_ops.shape(v)))) 543 544 self.addCleanup(shutil.rmtree, temp_dir) 545 546 second_model = second_model_fn() 547 status = second_model.load_weights(prefix) 548 second_model(x) 549 status.run_restore_ops() 550 second_model.save_weights(prefix) 551 # Check that the second model's checkpoint loads into the original model 552 status = model.load_weights(prefix) 553 status.run_restore_ops(session) 554 y = self.evaluate(model(x)) 555 self.assertAllClose(ref_y, y) 556 557 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) 558 def test_weight_loading_graph_model_added_layer(self): 559 def _save_graph_model(): 560 a = keras.layers.Input(shape=(2,)) 561 x = keras.layers.Dense(3, name='first')(a) 562 b = keras.layers.Dense(1, name='second')(x) 563 return keras.models.Model(a, b) 564 def _restore_graph_model(): 565 a = keras.layers.Input(shape=(2,)) 566 x = keras.layers.Dense(3, name='first')(a) 567 y = keras.layers.Dense(1, name='second')(x) 568 b = keras.layers.Dense(3, name='secondjr')(y) 569 return keras.models.Model(a, b) 570 571 self._new_layer_weight_loading_test_template( 572 _save_graph_model, _restore_graph_model) 573 574 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) 575 def test_weight_loading_graph_model_added_no_weight_layer(self): 576 def _save_graph_model(): 577 a = keras.layers.Input(shape=(2,)) 578 x = keras.layers.Dense(3, name='first')(a) 579 b = keras.layers.Dense(1, name='second')(x) 580 return keras.models.Model(a, b) 581 def _restore_graph_model(): 582 a = keras.layers.Input(shape=(2,)) 583 x = keras.layers.Dense(3, name='first')(a) 584 b = keras.layers.Dense(1, name='second')(x) 585 y = keras.layers.Dropout(rate=0.1)(b) 586 return keras.models.Model(a, y) 587 588 self._new_layer_weight_loading_test_template( 589 _save_graph_model, _restore_graph_model) 590 591 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) 592 def test_weight_loading_subclassed_model_added_layer(self): 593 594 class SubclassedModelRestore(training.Model): 595 596 def __init__(self): 597 super(SubclassedModelRestore, self).__init__() 598 self.x_layer = keras.layers.Dense(3) 599 self.y_layer = keras.layers.Dense(3) 600 self.b_layer = keras.layers.Dense(1) 601 602 def call(self, a): 603 return self.b_layer(self.y_layer(self.x_layer(a))) 604 605 self._new_layer_weight_loading_test_template( 606 SubclassedModel, SubclassedModelRestore) 607 608 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) 609 def test_incompatible_checkpoint(self): 610 save_path = trackable.Checkpoint().save( 611 os.path.join(self.get_temp_dir(), 'ckpt')) 612 m = DummySubclassModel() 613 with self.assertRaisesRegex(AssertionError, 'Nothing to load'): 614 m.load_weights(save_path) 615 m.dense = keras.layers.Dense(2) 616 m.dense(constant_op.constant([[1.]])) 617 with self.assertRaisesRegex(AssertionError, 618 'Nothing except the root object matched'): 619 m.load_weights(save_path) 620 621 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) 622 def test_directory_passed(self): 623 with self.cached_session(): 624 m = DummySubclassModel() 625 v = m.add_weight(name='v', shape=[]) 626 self.evaluate(v.assign(42.)) 627 prefix = os.path.join(self.get_temp_dir(), str(uuid.uuid4()), 'ckpt/') 628 m.save_weights(prefix) 629 self.evaluate(v.assign(2.)) 630 m.load_weights(prefix) 631 self.assertEqual(42., self.evaluate(v)) 632 633 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) 634 def test_relative_path(self): 635 with self.cached_session(): 636 m = DummySubclassModel() 637 v = m.add_weight(name='v', shape=[]) 638 os.chdir(self.get_temp_dir()) 639 640 prefix = 'ackpt' 641 self.evaluate(v.assign(42.)) 642 m.save_weights(prefix) 643 self.assertTrue(file_io.file_exists_v2('ackpt.index')) 644 self.evaluate(v.assign(1.)) 645 m.load_weights(prefix) 646 self.assertEqual(42., self.evaluate(v)) 647 648 prefix = 'subdir/ackpt' 649 self.evaluate(v.assign(43.)) 650 m.save_weights(prefix) 651 self.assertTrue(file_io.file_exists_v2('subdir/ackpt.index')) 652 self.evaluate(v.assign(2.)) 653 m.load_weights(prefix) 654 self.assertEqual(43., self.evaluate(v)) 655 656 prefix = 'ackpt/' 657 self.evaluate(v.assign(44.)) 658 m.save_weights(prefix) 659 self.assertTrue(file_io.file_exists_v2('ackpt/.index')) 660 self.evaluate(v.assign(3.)) 661 m.load_weights(prefix) 662 self.assertEqual(44., self.evaluate(v)) 663 664 @combinations.generate(combinations.combine(mode=['graph', 'eager'])) 665 def test_nonexistent_prefix_directory(self): 666 with self.cached_session(): 667 m = DummySubclassModel() 668 v = m.add_weight(name='v', shape=[]) 669 self.evaluate(v.assign(42.)) 670 prefix = os.path.join(self.get_temp_dir(), str(uuid.uuid4()), 'bckpt') 671 m.save_weights(prefix) 672 self.evaluate(v.assign(2.)) 673 m.load_weights(prefix) 674 self.assertEqual(42., self.evaluate(v)) 675 676 677class DummySubclassModel(training.Model): 678 pass 679 680 681if __name__ == '__main__': 682 test.main() 683