1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
"""Tests for tensorflow.python.training.checkpoint_management.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22import os
23import shutil
24import tempfile
25
26from google.protobuf import text_format
27
28from tensorflow.core.protobuf import saver_pb2
29from tensorflow.python.eager import context
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops as ops_lib
32from tensorflow.python.framework import test_util
33from tensorflow.python.lib.io import file_io
34from tensorflow.python.ops import variables
35from tensorflow.python.platform import gfile
36from tensorflow.python.platform import test
37from tensorflow.python.platform import tf_logging as logging
38from tensorflow.python.training import checkpoint_management
39from tensorflow.python.training import saver as saver_module
40from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
41from tensorflow.python.training.tracking import util
42
43
class LatestCheckpointWithRelativePaths(test.TestCase):
  """Checks saving/locating checkpoints whose prefix is a relative path."""

  @staticmethod
  @contextlib.contextmanager
  def tempWorkingDir(temppath):
    """Makes `temppath` the working directory for the duration of the block.

    The previous working directory is restored on exit, even if the body
    raises.
    """
    previous_dir = os.getcwd()
    os.chdir(temppath)
    try:
      yield
    finally:
      os.chdir(previous_dir)

  @staticmethod
  @contextlib.contextmanager
  def tempDir():
    """Yields a freshly created temporary directory, deleted on exit."""
    scratch_dir = tempfile.mkdtemp()
    try:
      yield scratch_dir
    finally:
      shutil.rmtree(scratch_dir)

  @test_util.run_deprecated_v1
  def testNameCollision(self):
    # Enter a brand-new directory as the cwd so that relative paths below
    # resolve somewhere clean; both are undone when the block exits.
    with self.tempDir() as scratch_dir, self.tempWorkingDir(scratch_dir):
      train_dir = "train/"
      os.mkdir(train_dir)
      # "checkpoint" is also the default name of the checkpoint state file.
      colliding_prefix = os.path.join(train_dir, "checkpoint")

      with self.cached_session() as sess:
        # The Saver needs at least one variable to write.
        unused_a = variables.Variable(0.0)
        variables.global_variables_initializer().run()

        # An unsharded save with no step would overwrite the state file.
        saver = saver_module.Saver(sharded=False)
        with self.assertRaisesRegexp(ValueError, "collides with"):
          saver.save(sess, colliding_prefix)

        # Adding a step gives "checkpoint-<step>", which does not collide.
        saver.save(sess, colliding_prefix, global_step=1)
        self.assertIsNotNone(checkpoint_management.latest_checkpoint(train_dir))

        # Sharded saves produce "checkpoint[-<step>]-<i>-of-<n>" files, with
        # or without a step, so neither collides.
        for step in (None, 1):
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, colliding_prefix, global_step=step)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(train_dir))

  @test_util.run_deprecated_v1
  def testRelativePath(self):
    # Enter a brand-new directory as the cwd so that relative paths below
    # resolve somewhere clean; both are undone when the block exits.
    with self.tempDir() as scratch_dir, self.tempWorkingDir(scratch_dir):
      train_dir = "train/"
      os.mkdir(train_dir)
      snapshot_prefix = os.path.join(train_dir, "snapshot")

      with self.cached_session() as sess:
        counter = variables.Variable(0.0)
        increment = counter.assign_add(1.0)
        save = saver_module.Saver({"v0": counter})
        variables.global_variables_initializer().run()

        # Save at steps 0, 1 and 2, incrementing the variable between
        # consecutive saves so the step-N checkpoint holds the value N.
        for step in range(3):
          if step:
            self.evaluate(increment)
          save.save(sess, snapshot_prefix, global_step=step)

      with self.cached_session() as sess:
        # A second variable, initialized differently, saved under the same
        # checkpoint key "v0".
        counter = variables.Variable(-1.0)
        save = saver_module.Saver({"v0": counter})
        variables.global_variables_initializer().run()

        # The relative state file must still resolve to the newest snapshot.
        latest = checkpoint_management.latest_checkpoint(train_dir)
        self.assertIsNotNone(latest)

        # Restoring from it yields the value written at step 2.
        save.restore(sess, latest)
        self.assertEqual(counter.eval(), 2.0)
148
149
class CheckpointStateTest(test.TestCase):
  """Tests for generating, writing and reading CheckpointState protos."""

  def _get_test_dir(self, dirname):
    """Creates `dirname` under the test temp dir (if needed) and returns it."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def _chdir(self, path):
    """Changes the working directory to `path` until the test finishes.

    The original working directory is restored through `addCleanup`. Without
    this, tests relying on relative paths would leak their cwd into every test
    that runs after them.
    """
    self.addCleanup(os.chdir, os.getcwd())
    os.chdir(path)

  def _write_state_file(self, save_dir, content):
    """Writes `content` verbatim to the "checkpoint" file inside `save_dir`."""
    ckpt_path = os.path.join(save_dir, "checkpoint")
    # `with` guarantees the handle is closed even if the write fails.
    with open(ckpt_path, "w") as ckpt_file:
      ckpt_file.write(content)

  def testAbsPath(self):
    save_dir = self._get_test_dir("abs_paths")
    abs_path = os.path.join(save_dir, "model-0")
    ckpt = checkpoint_management.generate_checkpoint_state_proto(
        save_dir, abs_path)
    self.assertEqual(ckpt.model_checkpoint_path, abs_path)
    self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)

  def testRelPath(self):
    train_dir = "train"
    model = os.path.join(train_dir, "model-0")
    # model_checkpoint_path should have no "train" directory part.
    new_rel_path = "model-0"
    ckpt = checkpoint_management.generate_checkpoint_state_proto(
        train_dir, model)
    self.assertEqual(ckpt.model_checkpoint_path, new_rel_path)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 1)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], new_rel_path)

  def testAllModelCheckpointPaths(self):
    save_dir = self._get_test_dir("all_models_test")
    abs_path = os.path.join(save_dir, "model-0")
    # The model path is expected to end up as the last entry, appended after
    # any paths supplied by the caller. Expected lengths are spelled out, and
    # each call gets a copy of the list, so the check does not silently depend
    # on whether the function mutates its argument.
    for paths, expected_len in [(None, 1), ([], 1), (["model-2"], 2)]:
      ckpt = checkpoint_management.generate_checkpoint_state_proto(
          save_dir,
          abs_path,
          all_model_checkpoint_paths=None if paths is None else list(paths))
      self.assertEqual(ckpt.model_checkpoint_path, abs_path)
      self.assertTrue(os.path.isabs(ckpt.model_checkpoint_path))
      self.assertEqual(len(ckpt.all_model_checkpoint_paths), expected_len)
      self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path)

  def testUpdateCheckpointState(self):
    save_dir = self._get_test_dir("update_checkpoint_state")
    self._chdir(save_dir)
    # Make a temporary train directory, relative to the new cwd.
    train_dir = "train"
    os.mkdir(train_dir)
    abs_path = os.path.join(save_dir, "model-0")
    rel_path = os.path.join("train", "model-2")
    checkpoint_management.update_checkpoint_state(
        train_dir, rel_path, all_model_checkpoint_paths=[abs_path, rel_path])
    ckpt = checkpoint_management.get_checkpoint_state(train_dir)
    self.assertEqual(ckpt.model_checkpoint_path, rel_path)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path)

  def testUpdateCheckpointStateSaveRelativePaths(self):
    save_dir = self._get_test_dir("update_checkpoint_state")
    self._chdir(save_dir)
    abs_path2 = os.path.join(save_dir, "model-2")
    rel_path2 = "model-2"
    abs_path0 = os.path.join(save_dir, "model-0")
    rel_path0 = "model-0"
    checkpoint_management.update_checkpoint_state_internal(
        save_dir=save_dir,
        model_checkpoint_path=abs_path2,
        all_model_checkpoint_paths=[rel_path0, abs_path2],
        save_relative_paths=True)

    # File should contain relative paths.
    file_content = file_io.read_file_to_string(
        os.path.join(save_dir, "checkpoint"))
    ckpt = CheckpointState()
    text_format.Merge(file_content, ckpt)
    self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)

    # get_checkpoint_state should return absolute paths.
    ckpt = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
    self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
    self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)

  def testCheckPointStateFailsWhenIncomplete(self):
    save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
    self._chdir(save_dir)
    # An empty state file lacks model_checkpoint_path and must be rejected.
    self._write_state_file(save_dir, "")
    with self.assertRaises(ValueError):
      checkpoint_management.get_checkpoint_state(save_dir)

  def testCheckPointCompletesRelativePaths(self):
    save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
    self._chdir(save_dir)
    self._write_state_file(save_dir, """
        model_checkpoint_path: "./model.ckpt-687529"
        all_model_checkpoint_paths: "./model.ckpt-687500"
        all_model_checkpoint_paths: "./model.ckpt-687529"
        """)
    # Relative entries in the file are resolved against save_dir on read.
    ckpt = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(ckpt.model_checkpoint_path,
                     os.path.join(save_dir, "./model.ckpt-687529"))
    self.assertEqual(ckpt.all_model_checkpoint_paths[0],
                     os.path.join(save_dir, "./model.ckpt-687500"))
    self.assertEqual(ckpt.all_model_checkpoint_paths[1],
                     os.path.join(save_dir, "./model.ckpt-687529"))
264
265
class SaverUtilsTest(test.TestCase):
  """Tests for checkpoint file helpers: existence, mtimes and removal."""

  def setUp(self):
    # Each test writes its checkpoints under a dedicated base directory that
    # tearDown removes again.
    self._base_dir = os.path.join(self.get_temp_dir(), "saver_utils_test")
    gfile.MakeDirs(self._base_dir)

  def tearDown(self):
    gfile.DeleteRecursively(self._base_dir)

  @test_util.run_deprecated_v1
  def testCheckpointExists(self):
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        with self.session(graph=ops_lib.Graph()) as sess:
          unused_v = variables.Variable(1.0, name="v")
          variables.global_variables_initializer().run()
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          self.assertFalse(
              checkpoint_management.checkpoint_exists(path))  # Not saved yet.

          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))

          # The prefix recorded in the state file must also be recognized.
          ckpt_prefix = checkpoint_management.latest_checkpoint(self._base_dir)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))

  @test_util.run_deprecated_v1
  def testGetCheckpointMtimes(self):
    prefixes = []
    for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
      with self.session(graph=ops_lib.Graph()) as sess:
        unused_v = variables.Variable(1.0, name="v")
        variables.global_variables_initializer().run()
        saver = saver_module.Saver(write_version=version)
        prefixes.append(
            saver.save(sess, os.path.join(self._base_dir, str(version))))

    mtimes = checkpoint_management.get_checkpoint_mtimes(prefixes)
    self.assertEqual(2, len(mtimes))
    # The second checkpoint was written after the first one. Use
    # assertGreaterEqual so a failure reports both timestamps.
    self.assertGreaterEqual(mtimes[1], mtimes[0])

  @test_util.run_deprecated_v1
  def testRemoveCheckpoint(self):
    for sharded in (False, True):
      for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
        with self.session(graph=ops_lib.Graph()) as sess:
          unused_v = variables.Variable(1.0, name="v")
          variables.global_variables_initializer().run()
          saver = saver_module.Saver(sharded=sharded, write_version=version)

          path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
          ckpt_prefix = saver.save(sess, path)
          self.assertTrue(checkpoint_management.checkpoint_exists(ckpt_prefix))
          checkpoint_management.remove_checkpoint(ckpt_prefix, version)
          self.assertFalse(checkpoint_management.checkpoint_exists(ckpt_prefix))
323
324
class CheckpointManagerTest(test.TestCase):
  """Tests for CheckpointManager retention, restore and numbering behavior."""

  def _get_test_dir(self):
    """Returns a directory specific to the current execution mode.

    `run_in_graph_and_eager_modes` runs each test body twice within the same
    test method, and `get_temp_dir()` hands back the same directory both
    times. Without a per-mode subdirectory, the second run would start from
    the checkpoints and CheckpointState file left behind by the first run.
    """
    directory = os.path.join(self.get_temp_dir(),
                             str(context.executing_eagerly()))
    gfile.MakeDirs(directory)
    return directory

  @test_util.run_in_graph_and_eager_modes
  def testDeletion(self):
    checkpoint = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, self._get_test_dir(), max_to_keep=3)
    first_path = manager.save()
    second_path = manager.save()
    third_path = manager.save()
    fourth_path = manager.save()
    # Only the three most recent checkpoints survive.
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(first_path))

  @test_util.run_in_graph_and_eager_modes
  def testKeepAll(self):
    checkpoint = util.Checkpoint()
    directory = self._get_test_dir()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=None)
    first_path = manager.save()
    second_path = manager.save()
    third_path = manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
    self.assertEqual(third_path, manager.latest_checkpoint)
    self.assertEqual([first_path, second_path, third_path],
                     manager.checkpoints)
    del manager
    # A new keep-all manager picks up the existing checkpoints from disk.
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=None)
    fourth_path = manager.save()
    self.assertEqual([first_path, second_path, third_path, fourth_path],
                     manager.checkpoints)
    del manager
    # Switching to max_to_keep=3 deletes nothing until the next save.
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=3)
    self.assertEqual([first_path, second_path, third_path, fourth_path],
                     manager.checkpoints)
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
    fifth_path = manager.save()
    self.assertEqual([third_path, fourth_path, fifth_path],
                     manager.checkpoints)
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(first_path))

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(checkpoint_management, "time")
  def testSaveRestoreState(self, mock_time):
    directory = self._get_test_dir()
    mock_time.time.return_value = 3.
    checkpoint = util.Checkpoint()
    first_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    first_time = 10000.
    first_name = os.path.join(directory, "ckpt-1")
    mock_time.time.return_value = first_time
    first_manager.save()
    second_time = first_time + 3610.
    second_name = os.path.join(directory, "ckpt-2")
    mock_time.time.return_value = second_time
    first_manager.save()
    # Save timestamps are persisted in the CheckpointState file.
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual([first_time, second_time],
                     state.all_model_checkpoint_timestamps)
    self.assertEqual([first_name, second_name], first_manager.checkpoints)
    self.assertEqual(second_name, first_manager.latest_checkpoint)
    del first_manager

    # A new manager restores the list of checkpoints from the state file.
    second_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual([first_name, second_name], second_manager.checkpoints)
    self.assertEqual(second_name, second_manager.latest_checkpoint)
    third_name = os.path.join(directory, "ckpt-3")
    third_time = second_time + 3600. * 0.2
    mock_time.time.return_value = third_time
    second_manager.save()
    # ckpt-1 leaves the managed list but stays on disk as a preserved
    # checkpoint.
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
    self.assertEqual([second_name, third_name],
                     second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(first_time, state.last_preserved_timestamp)
    fourth_time = third_time + 3600. * 0.5
    mock_time.time.return_value = fourth_time
    fourth_name = os.path.join(directory, "ckpt-4")
    second_manager.save()
    # ckpt-2 leaves the managed list and is deleted.
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
    self.assertEqual([third_name, fourth_name],
                     second_manager.checkpoints)
    fifth_time = fourth_time + 3600. * 0.5
    mock_time.time.return_value = fifth_time
    fifth_name = os.path.join(directory, "ckpt-5")
    second_manager.save()
    self.assertEqual([fourth_name, fifth_name],
                     second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(first_time, state.last_preserved_timestamp)
    del second_manager
    third_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory,
        max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual(fifth_name, third_manager.latest_checkpoint)
    mock_time.time.return_value += 10.
    third_manager.save()
    sixth_name = os.path.join(directory, "ckpt-6")
    # ckpt-4 becomes the newest preserved checkpoint.
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(fourth_time, state.last_preserved_timestamp)
    self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
    self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
    self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
    self.assertEqual([fifth_name, sixth_name],
                     third_manager.checkpoints)

  @test_util.run_in_graph_and_eager_modes
  def testContinueFromUnmanaged(self):
    directory = self._get_test_dir()
    prefix = os.path.join(directory, "unusual_prefix")
    checkpoint = util.Checkpoint()
    # Two checkpoints written without a manager, using a custom prefix.
    first_path = checkpoint.save(prefix)
    second_path = checkpoint.save(prefix)
    del checkpoint
    checkpoint = util.Checkpoint()
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    checkpoint.restore(manager.latest_checkpoint).run_restore_ops()
    self.assertEqual(2, self.evaluate(checkpoint.save_counter))
    third_path = manager.save()
    # Checkpoints the manager did not create are never tracked or deleted.
    self.assertEqual([third_path], manager.checkpoints)
    fourth_path = manager.save()
    self.assertEqual([third_path, fourth_path],
                     manager.checkpoints)
    fifth_path = manager.save()
    self.assertEqual([fourth_path, fifth_path],
                     manager.checkpoints)
    self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertFalse(checkpoint_management.checkpoint_exists(third_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(checkpoint_management, "time")
  def testClockReset(self, mock_time):
    directory = self._get_test_dir()
    mock_time.time.return_value = 10000.
    checkpoint = util.Checkpoint()
    first_manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=1, keep_checkpoint_every_n_hours=1.)
    first_path = first_manager.save()
    mock_time.time.return_value += 3600.
    second_path = first_manager.save()
    mock_time.time.return_value += 3600.
    third_path = first_manager.save()
    self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertEqual([third_path], first_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(13600., state.last_preserved_timestamp)
    # Set the clock back in time
    mock_time.time.return_value = 5000.
    del first_manager
    with test.mock.patch.object(logging, "warning") as mock_log:
      second_manager = checkpoint_management.CheckpointManager(
          checkpoint, directory, max_to_keep=1)
      self.assertRegexpMatches(
          str(mock_log.call_args),
          "behind the last preserved checkpoint timestamp")
    # We should err on the side of keeping checkpoints around when we're not
    # sure whether they were preserved or not due to clock funkiness.
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    # We know about the existing checkpoints, but they'll never be deleted and
    # so won't go in the CheckpointState proto on save.
    self.assertEqual(third_path, second_manager.latest_checkpoint)
    self.assertEqual([], second_manager.checkpoints)
    mock_time.time.return_value += 10.
    fourth_path = second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertEqual(fourth_path, second_manager.latest_checkpoint)
    self.assertEqual([fourth_path], second_manager.checkpoints)
    mock_time.time.return_value += 10.
    fifth_path = second_manager.save()
    self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
    self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
    self.assertEqual([fifth_path], second_manager.checkpoints)
    state = checkpoint_management.get_checkpoint_state(directory)
    self.assertEqual(5000., state.last_preserved_timestamp)
    self.assertEqual([5020.],
                     state.all_model_checkpoint_timestamps)

  @test_util.run_in_graph_and_eager_modes
  def testCustomNumbering(self):
    directory = self._get_test_dir()
    step = variables.Variable(0, dtype=dtypes.int64)
    checkpoint = util.Checkpoint(step=step)
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory, max_to_keep=2)
    self.evaluate(step.initializer)
    # Checkpoint numbers follow the `step` variable (0, 2, 4, ...), not the
    # checkpoint's own save counter.
    for i in range(5):
      path = manager.save(checkpoint_number=step)
      expected_suffix = "-%d" % (2 * i,)
      if not path.endswith(expected_suffix):
        self.fail("%s should have suffix %s" % (path, expected_suffix))
      self.evaluate(step.assign_add(2))
    self.assertEqual(5, self.evaluate(checkpoint.save_counter))
    # Test regular integers
    last_path = manager.save(checkpoint_number=32)
    self.assertIn("-32", last_path)
    self.assertEqual(last_path, manager.latest_checkpoint)
    self.assertEqual(
        last_path, checkpoint_management.latest_checkpoint(directory))
    state = checkpoint_management.get_checkpoint_state(directory)
    # Only the most recent two checkpoints are saved
    self.assertEqual([path, last_path], state.all_model_checkpoint_paths)
560
561
if __name__ == "__main__":
  # Run the tests in this module through TensorFlow's test entry point.
  test.main()
564