1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Tests for tensorflow.python.training.saver.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22import os 23import shutil 24import tempfile 25 26from google.protobuf import text_format 27 28from tensorflow.core.protobuf import saver_pb2 29from tensorflow.python.eager import context 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops as ops_lib 32from tensorflow.python.framework import test_util 33from tensorflow.python.lib.io import file_io 34from tensorflow.python.ops import variables 35from tensorflow.python.platform import gfile 36from tensorflow.python.platform import test 37from tensorflow.python.platform import tf_logging as logging 38from tensorflow.python.training import checkpoint_management 39from tensorflow.python.training import saver as saver_module 40from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState 41from tensorflow.python.training.tracking import util 42 43 44class LatestCheckpointWithRelativePaths(test.TestCase): 45 46 @staticmethod 47 @contextlib.contextmanager 48 def tempWorkingDir(temppath): 49 cwd = os.getcwd() 50 os.chdir(temppath) 51 try: 52 yield 53 finally: 54 os.chdir(cwd) 55 56 
@staticmethod 57 @contextlib.contextmanager 58 def tempDir(): 59 tempdir = tempfile.mkdtemp() 60 try: 61 yield tempdir 62 finally: 63 shutil.rmtree(tempdir) 64 65 @test_util.run_deprecated_v1 66 def testNameCollision(self): 67 # Make sure we have a clean directory to work in. 68 with self.tempDir() as tempdir: 69 # Jump to that directory until this test is done. 70 with self.tempWorkingDir(tempdir): 71 # Save training snapshots to a relative path. 72 traindir = "train/" 73 os.mkdir(traindir) 74 # Collides with the default name of the checkpoint state file. 75 filepath = os.path.join(traindir, "checkpoint") 76 77 with self.cached_session() as sess: 78 unused_a = variables.Variable(0.0) # So that Saver saves something. 79 variables.global_variables_initializer().run() 80 81 # Should fail. 82 saver = saver_module.Saver(sharded=False) 83 with self.assertRaisesRegexp(ValueError, "collides with"): 84 saver.save(sess, filepath) 85 86 # Succeeds: the file will be named "checkpoint-<step>". 87 saver.save(sess, filepath, global_step=1) 88 self.assertIsNotNone( 89 checkpoint_management.latest_checkpoint(traindir)) 90 91 # Succeeds: the file will be named "checkpoint-<i>-of-<n>". 92 saver = saver_module.Saver(sharded=True) 93 saver.save(sess, filepath) 94 self.assertIsNotNone( 95 checkpoint_management.latest_checkpoint(traindir)) 96 97 # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>". 98 saver = saver_module.Saver(sharded=True) 99 saver.save(sess, filepath, global_step=1) 100 self.assertIsNotNone( 101 checkpoint_management.latest_checkpoint(traindir)) 102 103 @test_util.run_deprecated_v1 104 def testRelativePath(self): 105 # Make sure we have a clean directory to work in. 106 with self.tempDir() as tempdir: 107 108 # Jump to that directory until this test is done. 109 with self.tempWorkingDir(tempdir): 110 111 # Save training snapshots to a relative path. 
112 traindir = "train/" 113 os.mkdir(traindir) 114 115 filename = "snapshot" 116 filepath = os.path.join(traindir, filename) 117 118 with self.cached_session() as sess: 119 # Build a simple graph. 120 v0 = variables.Variable(0.0) 121 inc = v0.assign_add(1.0) 122 123 save = saver_module.Saver({"v0": v0}) 124 125 # Record a short training history. 126 variables.global_variables_initializer().run() 127 save.save(sess, filepath, global_step=0) 128 self.evaluate(inc) 129 save.save(sess, filepath, global_step=1) 130 self.evaluate(inc) 131 save.save(sess, filepath, global_step=2) 132 133 with self.cached_session() as sess: 134 # Build a new graph with different initialization. 135 v0 = variables.Variable(-1.0) 136 137 # Create a new saver. 138 save = saver_module.Saver({"v0": v0}) 139 variables.global_variables_initializer().run() 140 141 # Get the most recent checkpoint name from the training history file. 142 name = checkpoint_management.latest_checkpoint(traindir) 143 self.assertIsNotNone(name) 144 145 # Restore "v0" from that checkpoint. 146 save.restore(sess, name) 147 self.assertEqual(v0.eval(), 2.0) 148 149 150class CheckpointStateTest(test.TestCase): 151 152 def _get_test_dir(self, dirname): 153 test_dir = os.path.join(self.get_temp_dir(), dirname) 154 gfile.MakeDirs(test_dir) 155 return test_dir 156 157 def testAbsPath(self): 158 save_dir = self._get_test_dir("abs_paths") 159 abs_path = os.path.join(save_dir, "model-0") 160 ckpt = checkpoint_management.generate_checkpoint_state_proto( 161 save_dir, abs_path) 162 self.assertEqual(ckpt.model_checkpoint_path, abs_path) 163 self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path)) 164 self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1) 165 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) 166 167 def testRelPath(self): 168 train_dir = "train" 169 model = os.path.join(train_dir, "model-0") 170 # model_checkpoint_path should have no "train" directory part. 
171 new_rel_path = "model-0" 172 ckpt = checkpoint_management.generate_checkpoint_state_proto( 173 train_dir, model) 174 self.assertEqual(ckpt.model_checkpoint_path, new_rel_path) 175 self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1) 176 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path) 177 178 def testAllModelCheckpointPaths(self): 179 save_dir = self._get_test_dir("all_models_test") 180 abs_path = os.path.join(save_dir, "model-0") 181 for paths in [None, [], ["model-2"]]: 182 ckpt = checkpoint_management.generate_checkpoint_state_proto( 183 save_dir, abs_path, all_model_checkpoint_paths=paths) 184 self.assertEqual(ckpt.model_checkpoint_path, abs_path) 185 self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path)) 186 self.assertEqual( 187 len(ckpt.all_model_checkpoint_paths), len(paths) if paths else 1) 188 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path) 189 190 def testUpdateCheckpointState(self): 191 save_dir = self._get_test_dir("update_checkpoint_state") 192 os.chdir(save_dir) 193 # Make a temporary train directory. 
194 train_dir = "train" 195 os.mkdir(train_dir) 196 abs_path = os.path.join(save_dir, "model-0") 197 rel_path = os.path.join("train", "model-2") 198 checkpoint_management.update_checkpoint_state( 199 train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path]) 200 ckpt = checkpoint_management.get_checkpoint_state(train_dir) 201 self.assertEqual(ckpt.model_checkpoint_path, rel_path) 202 self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) 203 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path) 204 self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path) 205 206 def testUpdateCheckpointStateSaveRelativePaths(self): 207 save_dir = self._get_test_dir("update_checkpoint_state") 208 os.chdir(save_dir) 209 abs_path2 = os.path.join(save_dir, "model-2") 210 rel_path2 = "model-2" 211 abs_path0 = os.path.join(save_dir, "model-0") 212 rel_path0 = "model-0" 213 checkpoint_management.update_checkpoint_state_internal( 214 save_dir=save_dir, 215 model_checkpoint_path=abs_path2, 216 all_model_checkpoint_paths=[rel_path0, abs_path2], 217 save_relative_paths=True) 218 219 # File should contain relative paths. 220 file_content = file_io.read_file_to_string( 221 os.path.join(save_dir, "checkpoint")) 222 ckpt = CheckpointState() 223 text_format.Merge(file_content, ckpt) 224 self.assertEqual(ckpt.model_checkpoint_path, rel_path2) 225 self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) 226 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2) 227 self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0) 228 229 # get_checkpoint_state should return absolute paths. 
230 ckpt = checkpoint_management.get_checkpoint_state(save_dir) 231 self.assertEqual(ckpt.model_checkpoint_path, abs_path2) 232 self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2) 233 self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2) 234 self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0) 235 236 def testCheckPointStateFailsWhenIncomplete(self): 237 save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete") 238 os.chdir(save_dir) 239 ckpt_path = os.path.join(save_dir, "checkpoint") 240 ckpt_file = open(ckpt_path, "w") 241 ckpt_file.write("") 242 ckpt_file.close() 243 with self.assertRaises(ValueError): 244 checkpoint_management.get_checkpoint_state(save_dir) 245 246 def testCheckPointCompletesRelativePaths(self): 247 save_dir = self._get_test_dir("checkpoint_completes_relative_paths") 248 os.chdir(save_dir) 249 ckpt_path = os.path.join(save_dir, "checkpoint") 250 ckpt_file = open(ckpt_path, "w") 251 ckpt_file.write(""" 252 model_checkpoint_path: "./model.ckpt-687529" 253 all_model_checkpoint_paths: "./model.ckpt-687500" 254 all_model_checkpoint_paths: "./model.ckpt-687529" 255 """) 256 ckpt_file.close() 257 ckpt = checkpoint_management.get_checkpoint_state(save_dir) 258 self.assertEqual(ckpt.model_checkpoint_path, 259 os.path.join(save_dir, "./model.ckpt-687529")) 260 self.assertEqual(ckpt.all_model_checkpoint_paths[0], 261 os.path.join(save_dir, "./model.ckpt-687500")) 262 self.assertEqual(ckpt.all_model_checkpoint_paths[1], 263 os.path.join(save_dir, "./model.ckpt-687529")) 264 265 266class SaverUtilsTest(test.TestCase): 267 268 def setUp(self): 269 self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test") 270 gfile.MakeDirs(self._base_dir) 271 272 def tearDown(self): 273 gfile.DeleteRecursively(self._base_dir) 274 275 @test_util.run_deprecated_v1 276 def testCheckpointExists(self): 277 for sharded in (False, True): 278 for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): 279 with 
self.session(graph=ops_lib.Graph()) as sess: 280 unused_v = variables.Variable(1.0, name="v") 281 variables.global_variables_initializer().run() 282 saver = saver_module.Saver(sharded=sharded, write_version=version) 283 284 path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) 285 self.assertFalse( 286 checkpoint_management.checkpoint_exists(path)) # Not saved yet. 287 288 ckpt_prefix = saver.save(sess, path) 289 self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix)) 290 291 ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir) 292 self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix)) 293 294 @test_util.run_deprecated_v1 295 def testGetCheckpointMtimes(self): 296 prefixes = [] 297 for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): 298 with self.session(graph=ops_lib.Graph()) as sess: 299 unused_v = variables.Variable(1.0, name="v") 300 variables.global_variables_initializer().run() 301 saver = saver_module.Saver(write_version=version) 302 prefixes.append( 303 saver.save(sess, os.path.join(self._base_dir, str(version)))) 304 305 mtimes = checkpoint_management.get_checkpoint_mtimes(prefixes) 306 self.assertEqual(2, len(mtimes)) 307 self.assertTrue(mtimes[1] >= mtimes[0]) 308 309 @test_util.run_deprecated_v1 310 def testRemoveCheckpoint(self): 311 for sharded in (False, True): 312 for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1): 313 with self.session(graph=ops_lib.Graph()) as sess: 314 unused_v = variables.Variable(1.0, name="v") 315 variables.global_variables_initializer().run() 316 saver = saver_module.Saver(sharded=sharded, write_version=version) 317 318 path = os.path.join(self._base_dir, "%s-%s" % (sharded, version)) 319 ckpt_prefix = saver.save(sess, path) 320 self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix)) 321 checkpoint_management.remove_checkpoint(ckpt_prefix, version) 322 self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix)) 323 324 
class CheckpointManagerTest(test.TestCase):
  """Tests for the retention behaviour of `CheckpointManager`."""

  @test_util.run_in_graph_and_eager_modes
  def testDeletion(self):
    """With max_to_keep=3, the fourth save removes the first checkpoint."""
    exists = checkpoint_management.checkpoint_exists
    ckpt = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        ckpt, self.get_temp_dir(), max_to_keep=3)
    saved = [manager.save() for _ in range(4)]
    # Newest three survive; checked from newest to oldest.
    for survivor in reversed(saved[1:]):
      self.assertTrue(exists(survivor))
    self.assertFalse(exists(saved[0]))

  @test_util.run_in_graph_and_eager_modes
  def testKeepAll(self):
    """max_to_keep=None keeps everything; a later limit trims old saves."""
    exists = checkpoint_management.checkpoint_exists
    ckpt = util.Checkpoint()
    # Avoid sharing directories between eager and graph
    # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
    directory = os.path.join(
        self.get_temp_dir(), str(context.executing_eagerly()))

    manager = checkpoint_management.CheckpointManager(
        ckpt, directory, max_to_keep=None)
    history = [manager.save() for _ in range(3)]
    for saved_path in reversed(history):
      self.assertTrue(exists(saved_path))
    self.assertEqual(history[-1], manager.latest_checkpoint)
    self.assertEqual(history, manager.checkpoints)
    del manager

    # A fresh unlimited manager picks up the earlier saves and adds to them.
    manager = checkpoint_management.CheckpointManager(
        ckpt, directory, max_to_keep=None)
    history.append(manager.save())
    self.assertEqual(history, manager.checkpoints)
    del manager

    # A limited manager still lists all four until its next save.
    manager = checkpoint_management.CheckpointManager(
        ckpt, directory, max_to_keep=3)
    self.assertEqual(history, manager.checkpoints)
    for saved_path in reversed(history):
      self.assertTrue(exists(saved_path))

    history.append(manager.save())
    retained, dropped = history[-3:], history[:-3]
    self.assertEqual(retained, manager.checkpoints)
    for saved_path in reversed(retained):
      self.assertTrue(exists(saved_path))
    for saved_path in reversed(dropped):
      self.assertFalse(exists(saved_path))

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(checkpoint_management, "time")
  def testSaveRestoreState(self, mock_time):
    """Timestamps and the preserved-checkpoint state survive new managers."""
    exists = checkpoint_management.checkpoint_exists
    directory = self.get_temp_dir()

    # Expected checkpoint prefixes, in save order.
    ckpt1 = os.path.join(directory, "ckpt-1")
    ckpt2 = os.path.join(directory, "ckpt-2")
    ckpt3 = os.path.join(directory, "ckpt-3")
    ckpt4 = os.path.join(directory, "ckpt-4")
    ckpt5 = os.path.join(directory, "ckpt-5")
    ckpt6 = os.path.join(directory, "ckpt-6")

    # Mocked clock readings used for saves one through five.
    t1 = 10000.
    t2 = t1 + 3610.
    t3 = t2 + 3600. * 0.2
    t4 = t3 + 3600. * 0.5
    t5 = t4 + 3600. * 0.5

    mock_time.time.return_value = 3.
    ckpt = util.Checkpoint()
    first_manager = checkpoint_management.CheckpointManager(
        ckpt, directory, max_to_keep=2)
    mock_time.time.return_value = t1
    first_manager.save()
    state = checkpoint_management.get_checkpoint_state(directory)
    mock_time.time.return_value = t2
    first_manager.save()
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual([t1, t2], state.all_model_checkpoint_timestamps)
    self.assertEqual([ckpt1, ckpt2], first_manager.checkpoints)
    self.assertEqual(ckpt2, first_manager.latest_checkpoint)
    del first_manager

    second_manager = checkpoint_management.CheckpointManager(
        ckpt, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual([ckpt1, ckpt2], second_manager.checkpoints)
    self.assertEqual(ckpt2, second_manager.latest_checkpoint)

    mock_time.time.return_value = t3
    second_manager.save()
    self.assertTrue(exists(ckpt1))
    self.assertTrue(exists(ckpt2))
    self.assertEqual([ckpt2, ckpt3], second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(t1, state.last_preserved_timestamp)

    mock_time.time.return_value = t4
    second_manager.save()
    self.assertTrue(exists(ckpt1))
    self.assertFalse(exists(ckpt2))
    self.assertEqual([ckpt3, ckpt4], second_manager.checkpoints)

    mock_time.time.return_value = t5
    second_manager.save()
    self.assertEqual([ckpt4, ckpt5], second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(t1, state.last_preserved_timestamp)
    del second_manager

    third_manager = checkpoint_management.CheckpointManager(
        ckpt, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual(ckpt5, third_manager.latest_checkpoint)
    mock_time.time.return_value += 10.
    third_manager.save()
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(t4, state.last_preserved_timestamp)
    for present in (ckpt1, ckpt4, ckpt5, ckpt6):
      self.assertTrue(exists(present))
    for absent in (ckpt2, ckpt3):
      self.assertFalse(exists(absent))
    self.assertEqual([ckpt5, ckpt6], third_manager.checkpoints)

  @test_util.run_in_graph_and_eager_modes
  def testContinueFromUnmanaged(self):
    """Checkpoints saved without a manager are never deleted by one."""
    exists = checkpoint_management.checkpoint_exists
    directory = self.get_temp_dir()
    prefix = os.path.join(directory, "unusual_prefix")

    unmanaged = util.Checkpoint()
    first_path = unmanaged.save(prefix)
    second_path = unmanaged.save(prefix)
    del unmanaged

    ckpt = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        ckpt, directory, max_to_keep=2)
    ckpt.restore(manager.latest_checkpoint).run_restore_ops()
    self.assertEqual(2, self.evaluate(ckpt.save_counter))

    managed = [manager.save()]
    self.assertEqual(managed, manager.checkpoints)
    managed.append(manager.save())
    self.assertEqual(managed, manager.checkpoints)
    managed.append(manager.save())
    self.assertEqual(managed[1:], manager.checkpoints)

    third_path, fourth_path, fifth_path = managed
    self.assertTrue(exists(first_path))
    self.assertTrue(exists(second_path))
    self.assertFalse(exists(third_path))
    self.assertTrue(exists(fourth_path))
    self.assertTrue(exists(fifth_path))

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(checkpoint_management, "time")
  def testClockReset(self, mock_time):
    """A clock that moves backwards logs a warning and keeps old checkpoints."""
    exists = checkpoint_management.checkpoint_exists
    directory = self.get_temp_dir()
    mock_time.time.return_value = 10000.
    ckpt = util.Checkpoint()
    first_manager = checkpoint_management.CheckpointManager(
        ckpt, directory, max_to_keep=1, keep_checkpoint_every_n_hours=1.)
    first_path = first_manager.save()
    mock_time.time.return_value += 3600.
    second_path = first_manager.save()
    mock_time.time.return_value += 3600.
    third_path = first_manager.save()
    self.assertFalse(exists(first_path))
    self.assertTrue(exists(second_path))
    self.assertTrue(exists(third_path))
    self.assertEqual([third_path], first_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(13600., state.last_preserved_timestamp)

    # Set the clock back in time
    mock_time.time.return_value = 5000.
    del first_manager
    with test.mock.patch.object(logging, "warning") as mock_log:
      second_manager = checkpoint_management.CheckpointManager(
          ckpt, directory, max_to_keep=1)
      self.assertRegexpMatches(
          str(mock_log.call_args),
          "behind the last preserved checkpoint timestamp")
    # We should err on the side of keeping checkpoints around when we're not
    # sure whether they were preserved or not due to clock funkiness.
    self.assertTrue(exists(second_path))
    # We know about the existing checkpoints, but they'll never be deleted and
    # so won't go in the CheckpointState proto on save.
    self.assertEqual(third_path, second_manager.latest_checkpoint)
    self.assertEqual([], second_manager.checkpoints)

    mock_time.time.return_value += 10.
    fourth_path = second_manager.save()
    for older in (second_path, third_path):
      self.assertTrue(exists(older))
    self.assertEqual(fourth_path, second_manager.latest_checkpoint)
    self.assertEqual([fourth_path], second_manager.checkpoints)

    mock_time.time.return_value += 10.
    fifth_path = second_manager.save()
    for older in (second_path, third_path):
      self.assertTrue(exists(older))
    self.assertEqual([fifth_path], second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(5000., state.last_preserved_timestamp)
    self.assertEqual([5020.], state.all_model_checkpoint_timestamps)

  @test_util.run_in_graph_and_eager_modes
  def testCustomNumbering(self):
    """checkpoint_number accepts a variable or an int and sets the suffix."""
    directory = self.get_temp_dir()
    step = variables.Variable(0, dtype=dtypes.int64)
    ckpt = util.Checkpoint(step=step)
    manager = checkpoint_management.CheckpointManager(
        ckpt, directory, max_to_keep=2)
    self.evaluate(step.initializer)

    # `step` advances by two per save, so the suffixes are 0, 2, 4, 6, 8.
    for expected_step in range(0, 10, 2):
      path = manager.save(checkpoint_number=step)
      expected_suffix = "-%d" % (expected_step,)
      if not path.endswith(expected_suffix):
        self.fail("%s should have suffix %s" % (path, expected_suffix))
      self.evaluate(step.assign_add(2))
    self.assertEqual(5, self.evaluate(ckpt.save_counter))

    # Test regular integers
    last_path = manager.save(checkpoint_number=32)
    self.assertIn("-32", last_path)
    self.assertEqual(last_path, manager.latest_checkpoint)
    self.assertEqual(
        last_path, checkpoint_management.latest_checkpoint(directory))
    state = checkpoint_management.get_checkpoint_state(directory)
    # Only the most recent two checkpoints are saved
    self.assertEqual([path, last_path], state.all_model_checkpoint_paths)


if __name__ == "__main__":
  test.main()