1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Tests for tensorflow.python.training.saver.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import math
23import os
24import random
25import time
26
27import numpy as np
28import six
29
30from google.protobuf.any_pb2 import Any
31
32from tensorflow.core.protobuf import config_pb2
33from tensorflow.core.protobuf import meta_graph_pb2
34from tensorflow.core.protobuf import queue_runner_pb2
35from tensorflow.core.protobuf import rewriter_config_pb2
36from tensorflow.core.protobuf import saver_pb2
37from tensorflow.python import pywrap_tensorflow
38from tensorflow.python.client import session
39from tensorflow.python.data.ops import dataset_ops
40from tensorflow.python.eager import context
41from tensorflow.python.framework import constant_op
42from tensorflow.python.framework import dtypes
43from tensorflow.python.framework import errors
44from tensorflow.python.framework import errors_impl
45from tensorflow.python.framework import function
46from tensorflow.python.framework import graph_io
47from tensorflow.python.framework import meta_graph
48from tensorflow.python.framework import ops as ops_lib
49from tensorflow.python.framework import test_util
50from tensorflow.python.keras.engine import training
51from tensorflow.python.keras.layers import core
52from tensorflow.python.lib.io import file_io
53from tensorflow.python.ops import array_ops
54from tensorflow.python.ops import control_flow_ops
55from tensorflow.python.ops import data_flow_ops
56from tensorflow.python.ops import gradients_impl
57from tensorflow.python.ops import math_ops
58from tensorflow.python.ops import nn_ops
59from tensorflow.python.ops import partitioned_variables
60from tensorflow.python.ops import random_ops
61from tensorflow.python.ops import resource_variable_ops
62from tensorflow.python.ops import sparse_ops
63from tensorflow.python.ops import variable_scope
64from tensorflow.python.ops import variables
65import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
66from tensorflow.python.platform import gfile
67from tensorflow.python.platform import test
68from tensorflow.python.summary import summary
69from tensorflow.python.training import adam
70from tensorflow.python.training import checkpoint_management
71from tensorflow.python.training import gradient_descent
72from tensorflow.python.training import queue_runner_impl
73from tensorflow.python.training import saver as saver_module
74from tensorflow.python.training import saver_test_utils
75from tensorflow.python.training import training_util
76from tensorflow.python.training.tracking import base as trackable_base
77from tensorflow.python.training.tracking import tracking as trackable_tracking
78from tensorflow.python.training.tracking import util as trackable_utils
79from tensorflow.python.util import compat
80
81
82class SaverTest(test.TestCase):
83
  def basicSaveRestore(self, variable_op):
    """Round-trips save/restore for variables created by `variable_op`.

    Saves two variables plus a custom saveable, then restores them into two
    fresh graphs: one with uninitialized variables and one with variables
    initialized to different values.

    Args:
      variable_op: Callable with the `(initial_value, name=...)` signature of
        `variables.Variable`, used to construct the variables under test.
    """
    save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")

    with self.session(graph=ops_lib.Graph()) as sess:
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = variable_op(10.0, name="v0")
      v1 = variable_op(20.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2.insert("k1", 30.0)

      # Initialize all variables (eager variables are initialized on creation).
      if not context.executing_eagerly():
        self.evaluate([variables.global_variables_initializer(), v2_init])

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

      # Save the initialized values in the file at "save_path"
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1,
              "v2": v2.saveable
          }, restore_sequentially=True)
      val = save.save(sess, save_path)
      # save() returns the checkpoint prefix it wrote.
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path, val)

    # Start a second session.  In that session the parameter nodes
    # have not been initialized either.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(-1.0, name="v0")
      v1 = variable_op(-1.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")

      # Assert that the variables are not initialized.
      if not context.executing_eagerly():
        self.assertEqual(
            len(variables.report_uninitialized_variables().eval()), 2)
        self.assertEqual(0, len(self.evaluate(v2.keys())))
        self.assertEqual(0, len(self.evaluate(v2.values())))
      # Restore the saved values in the parameter nodes.
      save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

    # Build another graph with 2 nodes, initialized
    # differently, and a Restore node for them.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0_2 = variable_op(1000.0, name="v0")
      v1_2 = variable_op(2000.0, name="v1")
      v2_2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2_2.insert("k1000", 3000.0)

      # Check that the parameter nodes have been initialized.
      if not context.executing_eagerly():
        init_all_op = [variables.global_variables_initializer(), v2_init]
        self.evaluate(init_all_op)
        # TODO(xpan): Why doesn't _mutable_hash_table_v2 create an empty
        # table in eager mode, as it claims?
        self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))
        self.assertEqual(3000.0, self.evaluate(v2_2.values()))
      self.assertEqual(1000.0, self.evaluate(v0_2))
      self.assertEqual(2000.0, self.evaluate(v1_2))

      # Restore the values saved earlier in the parameter nodes.
      save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable})
      save2.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0_2))
      self.assertEqual(20.0, self.evaluate(v1_2))
      self.assertEqual(b"k1", self.evaluate(v2_2.keys()))
      self.assertEqual(30.0, self.evaluate(v2_2.values()))
165
166  def testBasic(self):
167    self.basicSaveRestore(variables.Variable)
168
169  @test_util.run_in_graph_and_eager_modes
170  def testResourceBasic(self):
171    self.basicSaveRestore(resource_variable_ops.ResourceVariable)
172
173  @test_util.run_deprecated_v1
174  def testResourceColocation(self):
175    partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2)
176    with ops_lib.device("/job:ps/device:GPU:0"):
177      v = variable_scope.get_variable("v0",
178                                      shape=[10, 2],
179                                      partitioner=partitioner,
180                                      use_resource=True)
181    saver_module.Saver({"v0": v}).build()
182    save_op = None
183    for op in ops_lib.get_default_graph().get_operations():
184      if op.type == "SaveV2":
185        save_op = op
186        break
187    assert save_op is not None
188    for save_inp in save_op.inputs[3:]:
189      # Input to SaveV2 op is placed on CPU of the same device as the Variable.
190      self.assertEqual("/job:ps/device:CPU:0", save_inp.device)
191
192  def testResourceVariableReadOpsAddedDeterministically(self):
193    graph_defs = []
194    num_graphs = 10
195    for _ in range(num_graphs):
196      with ops_lib.Graph().as_default() as g:
197        for i in range(20):
198          resource_variable_ops.ResourceVariable(i, name="var%s" % i)
199        saver_module.Saver()
200        graph_defs.append(g.as_graph_def())
201    for i in range(num_graphs - 1):
202      self.assertEqual(graph_defs[i], graph_defs[i + 1])
203
204  def testEagerBasic(self):
205    with context.eager_mode():
206      ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
207
208      v1 = resource_variable_ops.ResourceVariable(3.14, name="v1")
209      v2 = resource_variable_ops.ResourceVariable([1, 2], name="v2")
210      save = saver_module.Saver([v1, v2])
211      save.save(None, ckpt_prefix)
212
213      v1.assign(0.0)
214      v2.assign([0, 0])
215      self.assertNear(0.0, self.evaluate(v1), 1e-5)
216      self.assertAllEqual([0, 0], self.evaluate(v2))
217
218      save.restore(None, ckpt_prefix)
219      self.assertNear(3.14, self.evaluate(v1), 1e-5)
220      self.assertAllEqual([1, 2], self.evaluate(v2))
221
  def testEagerGraphCompatibility(self):
    """Checkpoints written in graph mode restore in eager mode, and vice versa."""
    # Save from graph mode and restore from eager mode.
    graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt")
    with context.graph_mode():
      with self.session(graph=ops_lib.Graph()) as sess:
        # Create a graph model and save the checkpoint.
        w1 = resource_variable_ops.ResourceVariable(1.0, name="w1")
        w2 = resource_variable_ops.ResourceVariable(2.0, name="w2")
        graph_saver = saver_module.Saver([w1, w2])
        self.evaluate(variables.global_variables_initializer())
        graph_saver.save(sess, graph_ckpt_prefix)

    with context.eager_mode():
      # Drop graph state left over from the block above so the eager
      # variables get the same names ("w1"/"w2") as the saved ones.
      ops_lib._default_graph_stack.reset()  # pylint: disable=protected-access
      ops_lib.reset_default_graph()

      w1 = resource_variable_ops.ResourceVariable(0.0, name="w1")
      w2 = resource_variable_ops.ResourceVariable(0.0, name="w2")

      # `sess` argument is None in eager mode.
      graph_saver = saver_module.Saver([w1, w2])
      graph_saver.restore(None, graph_ckpt_prefix)

      self.assertAllEqual(self.evaluate(w1), 1.0)
      self.assertAllEqual(self.evaluate(w2), 2.0)

    # Save from eager mode and restore from graph mode.
    eager_ckpt_prefix = os.path.join(self.get_temp_dir(), "eager_ckpt")
    with context.eager_mode():
      # Same graph-state reset as above, for the reverse direction.
      ops_lib._default_graph_stack.reset()  # pylint: disable=protected-access
      ops_lib.reset_default_graph()

      w3 = resource_variable_ops.ResourceVariable(3.0, name="w3")
      w4 = resource_variable_ops.ResourceVariable(4.0, name="w4")

      graph_saver = saver_module.Saver([w3, w4])
      graph_saver.save(None, eager_ckpt_prefix)

    with context.graph_mode():
      with self.session(graph=ops_lib.Graph()) as sess:
        w3 = resource_variable_ops.ResourceVariable(0.0, name="w3")
        w4 = resource_variable_ops.ResourceVariable(0.0, name="w4")
        graph_saver = saver_module.Saver([w3, w4])
        self.evaluate(variables.global_variables_initializer())
        graph_saver.restore(sess, eager_ckpt_prefix)
        self.assertAllEqual(w3.eval(), 3.0)
        self.assertAllEqual(w4.eval(), 4.0)
268
269  @test_util.run_in_graph_and_eager_modes
270  def testResourceSaveRestoreCachingDevice(self):
271    save_path = os.path.join(self.get_temp_dir(), "resource_cache")
272    with self.session(graph=ops_lib.Graph()) as sess:
273      v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0",
274                                                 name="v")
275      if context.executing_eagerly():
276        sess = None
277      else:
278        self.evaluate(variables.global_variables_initializer())
279      save = saver_module.Saver([v])
280      save.save(sess, save_path)
281
282      save2 = saver_module.Saver([v])
283      save2.restore(sess, save_path)
284      self.assertEquals(self.evaluate(v), [1])
285
286  def testNoAdditionalOpsAddedBySaverForResourceVariablesOutsideSaveScope(self):
287    with ops_lib.Graph().as_default() as g:
288      v = resource_variable_ops.ResourceVariable(1.0, name="v")
289      with ops_lib.name_scope("saver1"):
290        saver_module.Saver()
291      with ops_lib.name_scope("saver2"):
292        saver_module.Saver({"name": v})
293    ops_in_saver1_scope_but_not_save_scope = [
294        op for op in g.get_operations()
295        if (op.name.startswith("saver1/") and
296            not op.name.startswith("saver1/save/"))]
297    self.assertEqual(ops_in_saver1_scope_but_not_save_scope, [])
298    ops_in_saver2_scope_but_not_save_scope = [
299        op for op in g.get_operations()
300        if (op.name.startswith("saver2/") and
301            not op.name.startswith("saver2/save/"))]
302    self.assertEqual(ops_in_saver2_scope_but_not_save_scope, [])
303
  @test_util.run_deprecated_v1
  def testSaveCopyRestoreWithSaveRelativePaths(self):
    """Save, copy checkpoint dir and restore from copied dir.

    This only works for save_relative_paths=True: the checkpoint state file
    then records paths relative to its own directory, so the directory can be
    moved wholesale and still be discovered by latest_checkpoint().
    """
    save_dir1 = os.path.join(self.get_temp_dir(), "save_dir1")
    os.mkdir(save_dir1)
    save_path1 = os.path.join(save_dir1, "save_copy_restore")

    # Build a graph with 2 parameter nodes, and Save and
    # Restore nodes for them.
    v0 = variables.VariableV1(10.0, name="v0")
    v1 = variables.VariableV1(20.0, name="v1")
    v2 = saver_test_utils.CheckpointedOp(name="v2")
    v2_init = v2.insert("k1", 30.0)
    save = saver_module.Saver(
        var_list={
            "v0": v0,
            "v1": v1,
            "v2": v2.saveable},
        restore_sequentially=True,
        save_relative_paths=True)
    init_all_op = [variables.global_variables_initializer(), v2_init]

    with self.cached_session() as sess:
      # Initialize all variables
      self.evaluate(init_all_op)

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

      # Save the initialized values in the file at "save_path"
      val = save.save(sess, save_path1)
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path1, val)

    self.assertEqual(
        checkpoint_management.latest_checkpoint(save_dir1), save_path1)
    save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2")
    # Move (not copy) the whole checkpoint directory; relative paths mean the
    # checkpoint state file remains valid at the new location.
    os.renames(save_dir1, save_dir2)
    save_path2 = os.path.join(save_dir2, "save_copy_restore")
    self.assertEqual(
        checkpoint_management.latest_checkpoint(save_dir2), save_path2)

    # Start a second session.  In that session the parameter nodes
    # have not been initialized either.
    with self.cached_session() as sess:
      v0 = variables.VariableV1(-1.0, name="v0")
      v1 = variables.VariableV1(-1.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})

      # Assert that the variables are not initialized.
      self.assertEqual(
          len(variables.report_uninitialized_variables().eval()), 2)
      self.assertEqual(0, len(self.evaluate(v2.keys())))
      self.assertEqual(0, len(self.evaluate(v2.values())))

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path2)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))
373
374  @test_util.run_deprecated_v1
375  def testFilenameTensor(self):
376    v0 = variables.VariableV1(0, name="v0")
377    filename = b"somerandomfilename"
378    save = saver_module.Saver({"v0": v0}, filename=filename)
379    with self.cached_session() as sess:
380      tensor = sess.graph.get_tensor_by_name(
381          save.saver_def.filename_tensor_name)
382      self.assertEqual(self.evaluate(tensor), filename)
383
384  def testInvalidPath(self):
385    v0 = variables.VariableV1(0, name="v0")
386    for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
387      with self.cached_session() as sess:
388        save = saver_module.Saver({"v0": v0}, write_version=ver)
389        with self.assertRaisesRegexp(
390            ValueError, "The passed save_path is not a valid checkpoint:"):
391          save.restore(sess, "invalid path")
392
  @test_util.run_v1_only("b/120545219")
  def testInt64(self):
    """Saves and restores an int64 variable."""
    save_path = os.path.join(self.get_temp_dir(), "int64")

    with self.cached_session() as sess:
      # Build a graph with 1 node, and save and restore for them.
      v = variables.VariableV1(np.int64(15), name="v")
      save = saver_module.Saver({"v": v}, restore_sequentially=True)
      self.evaluate(variables.global_variables_initializer())

      # Save the initialized values in the file at "save_path"
      val = save.save(sess, save_path)
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path, val)

      # NOTE(review): cached_session() returns the same underlying session, so
      # this nested block re-uses the graph above; the fresh `v` below is an
      # uninitialized second variable in that same graph.
      with self.cached_session() as sess:
        v = variables.VariableV1(np.int64(-1), name="v")
        save = saver_module.Saver({"v": v})

      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v" in e.message):
        self.evaluate(v)

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(np.int64(15), self.evaluate(v))
420
  def testSomeErrors(self):
    """Duplicate checkpoint names — via slice info or renames — are rejected."""
    with ops_lib.Graph().as_default():
      v0 = variables.VariableV1([10.0], name="v0")
      v1 = variables.VariableV1([20.0], name="v1")
      v2 = variables.VariableV1([20.0], name="v2")
      # Force v2 to save under the name "v1" via its save-slice info
      # (private API, used here only to provoke the name collision).
      v2._set_save_slice_info(
          variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))

      # By default the name used for "v2" will be "v1" and raise an error.
      with self.assertRaisesRegexp(ValueError, "same name: v1"):
        saver_module.Saver([v0, v1, v2])

      # The names are different and will work.
      saver_module.Saver({"vee1": v1, "other": [v2]})

      # Partitioned variables also cause name conflicts.
      p_v1 = variable_scope.get_variable(
          "p_v1",
          shape=[4, 5],
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=2))
      p_v2 = variable_scope.get_variable(
          "p_v2",
          shape=[4, 5],
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=2))
      # Rename p_v2 to collide with p_v1 (private attribute, test-only).
      p_v2._name = "p_v1"
      with self.assertRaisesRegexp(ValueError, "same name: p_v1"):
        saver_module.Saver([p_v1, p_v2])
450
451  def testSameName(self):
452    with ops_lib.Graph().as_default():
453      v0 = variables.VariableV1([10.0], name="v0")
454      v2 = saver_test_utils.CheckpointedOp(name="v2")
455
456      # Saving one variable under two names raises an error.
457      with self.assertRaisesRegexp(
458          ValueError, "The same saveable will be restored with two names: v0"):
459        saver_module.Saver({"v0": v0, "v0too": v0})
460
461      # Ditto for custom saveables.
462      with self.assertRaisesRegexp(
463          ValueError, "The same saveable will be restored with two names: v2"):
464        saver_module.Saver({"v2": v2.saveable, "v2too": v2.saveable})
465
466      # Verify non-duplicate names work.
467      saver_module.Saver({"v0": v0, "v2": v2.saveable})
468
  @test_util.run_v1_only("b/120545219")
  def testBasicsWithListOfVariables(self):
    """Same round trip as basicSaveRestore, but with a list-form var_list."""
    save_path = os.path.join(self.get_temp_dir(), "basics_with_list")

    with self.session(graph=ops_lib.Graph()) as sess:
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = variables.VariableV1(10.0, name="v0")
      v1 = variables.VariableV1(20.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      v2_init = v2.insert("k1", 30.0)
      # List form: checkpoint keys come from the variables' own names.
      save = saver_module.Saver([v0, v1, v2.saveable])
      self.evaluate(variables.global_variables_initializer())
      v2_init.run()

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

      # Save the initialized values in the file at "save_path"
      val = save.save(sess, save_path)
      self.assertTrue(isinstance(val, six.string_types))
      self.assertEqual(save_path, val)

    # Start a second session.  In that session the variables
    # have not been initialized either.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variables.VariableV1(-1.0, name="v0")
      v1 = variables.VariableV1(-1.0, name="v1")
      v2 = saver_test_utils.CheckpointedOp(name="v2")
      save = saver_module.Saver([v0, v1, v2.saveable])

      # Reading an uninitialized variable must fail.
      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v0" in e.message):
        self.evaluate(v0)
      with self.assertRaisesWithPredicateMatch(
          errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
        self.evaluate(v1)
      self.assertEqual(0, len(self.evaluate(v2.keys())))
      self.assertEqual(0, len(self.evaluate(v2.values())))

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(v2.keys()))
      self.assertEqual(30.0, self.evaluate(v2.values()))

    # Build another graph with 2 nodes, initialized
    # differently, and a Restore node for them.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0_2 = variables.VariableV1(1000.0, name="v0")
      v1_2 = variables.VariableV1(2000.0, name="v1")
      v2_2 = saver_test_utils.CheckpointedOp(name="v2")
      save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable])
      v2_2.insert("k1000", 3000.0).run()
      self.evaluate(variables.global_variables_initializer())

      # Check that the parameter nodes have been initialized.
      self.assertEqual(1000.0, self.evaluate(v0_2))
      self.assertEqual(2000.0, self.evaluate(v1_2))
      self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))
      self.assertEqual(3000.0, self.evaluate(v2_2.values()))
      # Restore the values saved earlier in the parameter nodes.
      save2.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0_2))
      self.assertEqual(20.0, self.evaluate(v1_2))
      self.assertEqual(b"k1", self.evaluate(v2_2.keys()))
      self.assertEqual(30.0, self.evaluate(v2_2.values()))
542
543  def _SaveAndLoad(self, var_name, var_value, other_value, save_path):
544    with self.session(graph=ops_lib.Graph()) as sess:
545      var = resource_variable_ops.ResourceVariable(var_value, name=var_name)
546      save = saver_module.Saver({var_name: var})
547      if not context.executing_eagerly():
548        self.evaluate(var.initializer)
549      val = save.save(sess, save_path)
550      self.assertEqual(save_path, val)
551    with self.session(graph=ops_lib.Graph()) as sess:
552      var = resource_variable_ops.ResourceVariable(other_value, name=var_name)
553      save = saver_module.Saver({var_name: var})
554      save.restore(sess, save_path)
555      self.assertAllClose(var_value, self.evaluate(var))
556
557  def testCacheRereadsFile(self):
558    save_path = os.path.join(self.get_temp_dir(), "cache_rereads")
559    # Save and reload one Variable named "var0".
560    self._SaveAndLoad("var0", 0.0, 1.0, save_path)
561    # Save and reload one Variable named "var1" in the same file.
562    # The cached readers should know to re-read the file.
563    self._SaveAndLoad("var1", 1.1, 2.2, save_path)
564
565  @test_util.run_deprecated_v1
566  def testAllowEmpty(self):
567    save_path = os.path.join(self.get_temp_dir(), "allow_empty")
568    with self.cached_session() as sess:
569      _ = constant_op.constant(1)
570      save = saver_module.Saver(allow_empty=True)
571      val = save.save(sess, save_path)
572      self.assertIsNone(val)
573    with self.cached_session() as sess:
574      save = saver_module.Saver(allow_empty=True)
575      save.restore(sess, save_path)
576
  def testGPU(self):
    """Builds and saves a GPU-placed variable; skipped without a GPU."""
    if not test.is_gpu_available():
      return
    save_path = os.path.join(self.get_temp_dir(), "gpu")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_1 = variables.VariableV1(123.45)
      save = saver_module.Saver({"v0": v0_1})
      self.evaluate(variables.global_variables_initializer())
      save.save(sess, save_path)

    # NOTE(review): this second session builds a saver over a GPU variable but
    # never calls restore — presumably only graph construction is under test
    # here; confirm whether a restore/assert step was intended.
    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_2 = variables.VariableV1(543.21)
      save = saver_module.Saver({"v0": v0_2})
      self.evaluate(variables.global_variables_initializer())
593
  def testSharedServerOnGPU(self):
    """Same as testGPU, but with a sharded, allow_empty saver."""
    if not test.is_gpu_available():
      return
    save_path = os.path.join(self.get_temp_dir(), "gpu")
    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_1 = variables.VariableV1(123.45)
      save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True)
      self.evaluate(variables.global_variables_initializer())
      save.save(sess, save_path)

    # NOTE(review): as in testGPU, no restore is performed in this second
    # session — only saver construction/initialization is exercised.
    with session.Session("", graph=ops_lib.Graph()) as sess:
      with sess.graph.device(test.gpu_device_name()):
        v0_2 = variables.VariableV1(543.21)
      save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True)
      self.evaluate(variables.global_variables_initializer())
610
611  def testVariables(self):
612    save_path = os.path.join(self.get_temp_dir(), "variables")
613    with session.Session("", graph=ops_lib.Graph()) as sess:
614      one = variables.VariableV1(1.0)
615      twos = variables.VariableV1([2.0, 2.0, 2.0])
616      v2 = saver_test_utils.CheckpointedOp(name="v2")
617      init = variables.global_variables_initializer()
618      save = saver_module.Saver()
619      init.run()
620      v2.insert("k1", 3.0).run()
621      save.save(sess, save_path)
622
623    with session.Session("", graph=ops_lib.Graph()) as sess:
624      one = variables.VariableV1(0.0)
625      twos = variables.VariableV1([0.0, 0.0, 0.0])
626      v2 = saver_test_utils.CheckpointedOp(name="v2")
627      # Saver with no arg, defaults to 'all variables'.
628      save = saver_module.Saver()
629      save.restore(sess, save_path)
630      self.assertAllClose(1.0, self.evaluate(one))
631      self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos))
632      self.assertEqual(b"k1", self.evaluate(v2.keys()))
633      self.assertEqual(3.0, self.evaluate(v2.values()))
634
635  def testVarListShouldBeEmptyInDeferredBuild(self):
636    with ops_lib.Graph().as_default():
637      v = variables.VariableV1(1.0)
638      with self.assertRaisesRegexp(ValueError, "defer_build"):
639        saver_module.Saver([v], defer_build=True)
640
641  def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self):
642    save_path = os.path.join(self.get_temp_dir(), "error_deferred_build")
643    with ops_lib.Graph().as_default(), session.Session() as sess:
644      variables.VariableV1(1.0)
645      saver = saver_module.Saver(defer_build=True)
646      with self.assertRaisesRegexp(RuntimeError, "build"):
647        saver.save(sess, save_path)
648
649  def testDeferredBuild(self):
650    save_path = os.path.join(self.get_temp_dir(), "deferred_build")
651    with session.Session("", graph=ops_lib.Graph()) as sess:
652      one = variables.VariableV1(1.0)
653      save = saver_module.Saver(defer_build=True)
654      # if build is not deferred, saver cannot save the `twos`.
655      twos = variables.VariableV1([2.0, 2.0, 2.0])
656      init = variables.global_variables_initializer()
657      save.build()
658      init.run()
659      save.save(sess, save_path)
660
661    with session.Session("", graph=ops_lib.Graph()) as sess:
662      one = variables.VariableV1(0.0)
663      twos = variables.VariableV1([0.0, 0.0, 0.0])
664      # Saver with no arg, defaults to 'all variables'.
665      save = saver_module.Saver()
666      save.restore(sess, save_path)
667      self.assertAllClose(1.0, self.evaluate(one))
668      self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos))
669
670  @test_util.run_v1_only("b/120545219")
671  def testReshape(self):
672    save_path = os.path.join(self.get_temp_dir(), "variables_reshape")
673    with session.Session("", graph=ops_lib.Graph()) as sess:
674      var = variables.VariableV1([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
675      init = variables.global_variables_initializer()
676      save = saver_module.Saver()
677      init.run()
678      save.save(sess, save_path)
679
680    # Error when restoring with default reshape=False
681    with session.Session("", graph=ops_lib.Graph()) as sess:
682      var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
683      save = saver_module.Saver()
684      with self.assertRaisesRegexp(
685          errors_impl.InvalidArgumentError,
686          "Assign requires shapes of both tensors to match."):
687        save.restore(sess, save_path)
688
689    # Restored to new shape with reshape=True
690    with session.Session("", graph=ops_lib.Graph()) as sess:
691      var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
692      save = saver_module.Saver(reshape=True)
693      save.restore(sess, save_path)
694      self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
695                          self.evaluate(var))
696
  @test_util.run_in_graph_and_eager_modes
  def testSaveWithGlobalStep(self, pad_step_number=False):
    """Checks the checkpoint path suffix produced by global_step.

    Also callable directly with pad_step_number=True (see
    testSaveWithGlobalStepWithPadding) to exercise zero-padded step suffixes.

    Args:
      pad_step_number: Whether the Saver should zero-pad the step number in
        the checkpoint filename to 8 digits.
    """
    save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step")
    global_step_int = 5
    # Save and reload one Variable named "var0".
    self._SaveAndLoad("var0", 0.0, 1.0, save_path)
    # global_step may be given as a tensor or as a plain int; test both.
    for use_tensor in [True, False]:
      with self.session(graph=ops_lib.Graph()):
        var = resource_variable_ops.ResourceVariable(1.0, name="var0")
        save = saver_module.Saver(
            {
                var._shared_name: var
            }, pad_step_number=pad_step_number)
        if context.executing_eagerly():
          sess = None
        else:
          self.evaluate(var.initializer)
          sess = ops_lib.get_default_session()
        if use_tensor:
          global_step = constant_op.constant(global_step_int)
          val = save.save(sess, save_path, global_step=global_step)
        else:
          val = save.save(sess, save_path, global_step=global_step_int)
        if pad_step_number:
          # Padded form: "<prefix>-00000005".
          expected_save_path = "%s-%s" % (save_path,
                                          "{:08d}".format(global_step_int))
        else:
          expected_save_path = "%s-%d" % (save_path, global_step_int)
        self.assertEqual(expected_save_path, val)
726
  def testSaveWithGlobalStepWithPadding(self):
    """Same as testSaveWithGlobalStep, but with zero-padded step numbers."""
    self.testSaveWithGlobalStep(pad_step_number=True)
729
  def testSaveToNonexistingPath(self):
    """Saves to paths whose parent directory is missing or is a file.

    Whether save succeeds in that situation is implementation-dependent, so
    the test accepts either a successful save+restore round trip or a
    ValueError with the expected message.
    """
    file_io.write_string_to_file(
        os.path.join(self.get_temp_dir(), "actually_a_file"), "")
    paths = [
        os.path.join(self.get_temp_dir(), "nonexisting_dir/path"),
        os.path.join(self.get_temp_dir(), "other_nonexisting_dir/path1/path2"),
        os.path.join(self.get_temp_dir(), "actually_a_file/path"),
    ]

    for save_path in paths:
      # Build a graph with 2 parameter nodes, and Save and
      # Restore nodes for them.
      v0 = variables.VariableV1(10.0, name="v0")
      v1 = variables.VariableV1(20.0, name="v1")
      save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
      init_all_op = variables.global_variables_initializer()

      # In the case where the parent directory doesn't exist, whether or not the
      # save succeeds or fails is implementation dependent.  Therefore we allow
      # both cases.
      try:
        with self.cached_session() as sess:
          # Initialize all variables
          self.evaluate(init_all_op)

          # Check that the parameter nodes have been initialized.
          self.assertEqual(10.0, self.evaluate(v0))
          self.assertEqual(20.0, self.evaluate(v1))

          # Save the graph.
          save.save(sess, save_path)

        with self.cached_session() as sess:
          # Restore the saved values in the parameter nodes.
          save.restore(sess, save_path)
          # Check that the parameter nodes have been restored.
          self.assertEqual(10.0, self.evaluate(v0))
          self.assertEqual(20.0, self.evaluate(v1))
      except ValueError as exc:
        # If save raised, it must be exactly the "missing parent" error.
        error_msg_template = "Parent directory of {} doesn't exist, can't save."
        self.assertEqual(error_msg_template.format(save_path), str(exc))
771
772  def testSaveToURI(self):
773    # ParseURI functions don't work on Windows yet.
774    # TODO(jhseu): Remove this check when it works.
775    if os.name == "nt":
776      self.skipTest("Local URI support doesn't work on Windows")
777    save_path = "file://" + os.path.join(self.get_temp_dir(), "uri")
778
779    # Build a graph with 2 parameter nodes, and Save and
780    # Restore nodes for them.
781    v0 = variables.VariableV1(10.0, name="v0")
782    v1 = variables.VariableV1(20.0, name="v1")
783    save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
784    init_all_op = variables.global_variables_initializer()
785
786    with self.cached_session() as sess:
787      # Initialize all variables
788      self.evaluate(init_all_op)
789
790      # Check that the parameter nodes have been initialized.
791      self.assertEqual(10.0, self.evaluate(v0))
792      self.assertEqual(20.0, self.evaluate(v1))
793      save.save(sess, save_path)
794
795  def testSaveRestoreAndValidateVariableDtype(self):
796    for variable_op in [
797        variables.Variable, resource_variable_ops.ResourceVariable
798    ]:
799      save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")
800
801      # Build the first session.
802      with self.session(graph=ops_lib.Graph()) as sess:
803        v0 = variable_op(10.0, name="v0", dtype=dtypes.float32)
804
805        if not context.executing_eagerly():
806          self.evaluate([variables.global_variables_initializer()])
807
808        save = saver_module.Saver({"v0": v0})
809        save.save(sess, save_path)
810
811      # Start a second session.
812      with self.session(graph=ops_lib.Graph()) as sess:
813        v0_wrong_dtype = variable_op(1, name="v0", dtype=dtypes.int32)
814        # Restore the saved value with different dtype
815        # in the parameter nodes.
816        save = saver_module.Saver({"v0": v0_wrong_dtype})
817        with self.assertRaisesRegexp(errors.InvalidArgumentError,
818                                     "original dtype"):
819          save.restore(sess, save_path)
820
821  # Test restoring large tensors (triggers a thread pool)
822  def testRestoreLargeTensors(self):
823    save_dir = self.get_temp_dir()
824    def _model():
825      small_v = [variable_scope.get_variable(
826          "small%d" % i, shape=[10, 2], use_resource=True) for i in range(5)]
827      large_v = [variable_scope.get_variable(
828          "large%d" % i, shape=[32000, 1000], use_resource=True)
829                 for i in range(3)]
830      return small_v + large_v
831
832    save_graph = ops_lib.Graph()
833    with save_graph.as_default(), self.session(graph=save_graph) as sess:
834      orig_vars = _model()
835      self.evaluate(variables.global_variables_initializer())
836      save = saver_module.Saver(max_to_keep=1)
837      self.evaluate(variables.global_variables_initializer())
838      save.save(sess, save_dir)
839      orig_vals = self.evaluate(orig_vars)
840
841    restore_graph = ops_lib.Graph()
842    with restore_graph.as_default(), self.session(
843        graph=restore_graph) as sess:
844      restored_vars = _model()
845      save = saver_module.Saver(max_to_keep=1)
846      save.restore(sess, save_dir)
847      restored_vals = self.evaluate(restored_vars)
848
849    for orig, restored in zip(orig_vals, restored_vals):
850      self.assertAllEqual(orig, restored)
851
852
class SaveRestoreShardedTest(test.TestCase):
  """Tests saving/restoring with sharded (per-device) checkpoints."""

  # Checkpoint write format under test; SaveRestoreShardedTestV2 overrides
  # this with saver_pb2.SaverDef.V2.
  _WRITE_VERSION = saver_pb2.SaverDef.V1

  def _get_test_dir(self, dirname):
    """Creates (if needed) and returns a directory under the test temp dir."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def testBasics(self):
    """Saves variables placed on two CPUs and restores them per-shard."""
    save_path = os.path.join(self.get_temp_dir(), "sharded_basics")

    # Build a graph with 2 parameter nodes on different devices.
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        v0 = variables.VariableV1(10, name="v0")
        t0 = saver_test_utils.CheckpointedOp(name="t0")
      with sess.graph.device("/cpu:1"):
        v1 = variables.VariableV1(20, name="v1")
        t1 = saver_test_utils.CheckpointedOp(name="t1")
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1,
              "t0": t0.saveable,
              "t1": t1.saveable
          },
          write_version=self._WRITE_VERSION,
          sharded=True)
      self.evaluate(variables.global_variables_initializer())
      t0.insert("k1", 30.0).run()
      t1.insert("k2", 40.0).run()
      val = save.save(sess, save_path)
      # V1 returns a sharded filename pattern; V2 returns the prefix as-is.
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(save_path + "-?????-of-00002", val)
      else:
        self.assertEqual(save_path, val)
      meta_graph_filename = checkpoint_management.meta_graph_filename(val)
      self.assertEqual(save_path + ".meta", meta_graph_filename)

    if save._write_version is saver_pb2.SaverDef.V1:
      # Restore different ops from shard 0 of the saved files.
      with session.Session(
          target="",
          config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
        with sess.graph.device("/cpu:0"):
          v0 = variables.VariableV1(111, name="v0")
          t0 = saver_test_utils.CheckpointedOp(name="t0")
        save = saver_module.Saver(
            {
                "v0": v0,
                "t0": t0.saveable
            },
            write_version=self._WRITE_VERSION,
            sharded=True)
        self.evaluate(variables.global_variables_initializer())
        t0.insert("k11", 33.0).run()
        self.assertEqual(111, self.evaluate(v0))
        self.assertEqual(b"k11", self.evaluate(t0.keys()))
        self.assertEqual(33.0, self.evaluate(t0.values()))
        save.restore(sess, save_path + "-00000-of-00002")
        self.assertEqual(10, self.evaluate(v0))
        self.assertEqual(b"k1", self.evaluate(t0.keys()))
        self.assertEqual(30.0, self.evaluate(t0.values()))

      # Restore different ops from shard 1 of the saved files.
      with session.Session(
          target="",
          config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
        with sess.graph.device("/cpu:0"):
          v1 = variables.VariableV1(222)
          t1 = saver_test_utils.CheckpointedOp(name="t1")
        save = saver_module.Saver(
            {
                "v1": v1,
                "t1": t1.saveable
            },
            write_version=self._WRITE_VERSION,
            sharded=True)
        self.evaluate(variables.global_variables_initializer())
        t1.insert("k22", 44.0).run()
        self.assertEqual(222, self.evaluate(v1))
        self.assertEqual(b"k22", self.evaluate(t1.keys()))
        self.assertEqual(44.0, self.evaluate(t1.values()))
        save.restore(sess, save_path + "-00001-of-00002")
        self.assertEqual(20, self.evaluate(v1))
        self.assertEqual(b"k2", self.evaluate(t1.keys()))
        self.assertEqual(40.0, self.evaluate(t1.values()))

    # Now try a restore with the sharded filename.
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        v0 = variables.VariableV1(111, name="v0")
        t0 = saver_test_utils.CheckpointedOp(name="t0")
      with sess.graph.device("/cpu:1"):
        v1 = variables.VariableV1(222, name="v1")
        t1 = saver_test_utils.CheckpointedOp(name="t1")
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1,
              "t0": t0.saveable,
              "t1": t1.saveable
          },
          write_version=self._WRITE_VERSION,
          sharded=True)
      self.evaluate(variables.global_variables_initializer())
      t0.insert("k11", 33.0).run()
      t1.insert("k22", 44.0).run()
      self.assertEqual(111, self.evaluate(v0))
      self.assertEqual(222, self.evaluate(v1))
      self.assertEqual(b"k11", self.evaluate(t0.keys()))
      self.assertEqual(33.0, self.evaluate(t0.values()))
      self.assertEqual(b"k22", self.evaluate(t1.keys()))
      self.assertEqual(44.0, self.evaluate(t1.values()))
      save_path = os.path.join(self.get_temp_dir(), "sharded_basics")
      if save._write_version is saver_pb2.SaverDef.V1:
        save.restore(sess, save_path + "-?????-of-?????")
      else:
        save.restore(sess, save_path)
      # All values written in the first session must come back.
      self.assertEqual(10, self.evaluate(v0))
      self.assertEqual(20, self.evaluate(v1))
      self.assertEqual(b"k1", self.evaluate(t0.keys()))
      self.assertEqual(30.0, self.evaluate(t0.values()))
      self.assertEqual(b"k2", self.evaluate(t1.keys()))
      self.assertEqual(40.0, self.evaluate(t1.values()))

    if save._write_version is saver_pb2.SaverDef.V1:
      self.assertEqual(
          checkpoint_management.latest_checkpoint(self.get_temp_dir()),
          os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002"))
    else:
      self.assertEqual(
          checkpoint_management.latest_checkpoint(self.get_temp_dir()),
          os.path.join(self.get_temp_dir(), "sharded_basics"))

  @test_util.run_deprecated_v1
  def testSaverDef(self):
    """Checks that sharded=True is reflected in the exported SaverDef."""
    with self.cached_session():
      v0 = variables.VariableV1(123, name="v0")
      save = saver_module.Saver({"v0": v0}, sharded=True)
      sd = save.as_saver_def()
      self.assertTrue(sd.sharded)

  def _testPartitionedVariables(self, use_resource):
    """Round-trips full and partitioned variables through a checkpoint.

    Args:
      use_resource: Whether to create ResourceVariables (True) or
        ref variables (False).
    """
    var_full_shape = [10, 3]
    # Allows save/restore mechanism to work w/ different slicings.
    var_name = "my_var"
    saved_dir = self._get_test_dir("partitioned_variables")
    saved_path = os.path.join(saved_dir, "ckpt")

    call_saver_with_dict = False  # updated by test loop below

    def _save(partitioner=None):
      """Saves a (possibly partitioned) variable and returns its value."""
      with self.session(graph=ops_lib.Graph()) as sess:
        # Calls .eval() to return the ndarray that makes up the full variable.
        rnd = random_ops.random_uniform(var_full_shape).eval()

        if partitioner:
          vs = [
              variable_scope.get_variable(
                  var_name,
                  shape=var_full_shape,
                  initializer=rnd,
                  partitioner=partitioner,
                  use_resource=use_resource)
          ]
        else:
          if use_resource:
            vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)]
          else:
            vs = [variables.VariableV1(rnd, name=var_name)]

        self.evaluate(variables.global_variables_initializer())
        if call_saver_with_dict:
          saver = saver_module.Saver({var_name: vs[0]})
        else:
          saver = saver_module.Saver(vs)
        actual_path = saver.save(sess, saved_path)
        self.assertEqual(saved_path, actual_path)

        return rnd

    def _restore(partitioner=None):
      """Restores into a (possibly partitioned) variable; returns its value."""
      with self.session(graph=ops_lib.Graph()) as sess:
        if partitioner:
          new_vs = [
              variable_scope.get_variable(
                  var_name,
                  shape=var_full_shape,
                  initializer=array_ops.zeros(var_full_shape),
                  partitioner=partitioner)
          ]
        else:
          new_vs = [
              variables.VariableV1(
                  array_ops.zeros(
                      shape=var_full_shape),  # != original contents.
                  name=var_name)
          ]

        self.evaluate(variables.global_variables_initializer())
        if call_saver_with_dict:
          saver = saver_module.Saver({
              var_name: new_vs[0]
          })
        else:
          saver = saver_module.Saver(new_vs)
        saver.restore(sess, saved_path)

        if partitioner:
          return new_vs[0].as_tensor().eval()
        else:
          return new_vs[0].eval()

    # Exercise both ways of handing variables to the Saver constructor.
    for call_saver_with_dict in {False, True}:
      # Save PartitionedVariable and restore into full variable.
      saved_full = _save(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=2))
      restored_full = _restore()
      self.assertAllEqual(saved_full, restored_full)

      # Restores into the same number of partitions.
      restored_full = _restore(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=2))
      self.assertAllEqual(saved_full, restored_full)

      # Restores into a different number of partitions.
      restored_full = _restore(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=3))
      self.assertAllEqual(saved_full, restored_full)

      # Now, saves a full variable and restores PartitionedVariable.
      saved_full = _save()
      restored_full = _restore(
          partitioner=partitioned_variables.fixed_size_partitioner(
              num_shards=3))
      self.assertAllEqual(saved_full, restored_full)

  @test_util.run_deprecated_v1
  def testPartitionedVariable(self):
    """Partitioned round-trip with ref variables."""
    self._testPartitionedVariables(use_resource=False)

  @test_util.run_deprecated_v1
  def testPartitionedResourceVariable(self):
    """Partitioned round-trip with resource variables."""
    self._testPartitionedVariables(use_resource=True)
1106
1107
class SaveRestoreShardedTestV2(SaveRestoreShardedTest):
  """Re-runs SaveRestoreShardedTest with the V2 checkpoint format."""
  _WRITE_VERSION = saver_pb2.SaverDef.V2
1110
1111
1112class MaxToKeepTest(test.TestCase):
1113
1114  def _get_test_dir(self, dirname):
1115    test_dir = os.path.join(self.get_temp_dir(), dirname)
1116    gfile.MakeDirs(test_dir)
1117    return test_dir
1118
1119  def assertCheckpointState(self, model_checkpoint_path,
1120                            all_model_checkpoint_paths, save_dir):
1121    checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir)
1122    self.assertEqual(checkpoint_state.model_checkpoint_path,
1123                     model_checkpoint_path)
1124    self.assertEqual(checkpoint_state.all_model_checkpoint_paths,
1125                     all_model_checkpoint_paths)
1126
  def testMaxToKeepEager(self):
    """Verifies max_to_keep checkpoint rotation under eager execution."""
    with context.eager_mode():
      save_dir = self._get_test_dir("max_to_keep_eager")

      v = variable_scope.variable(10.0, name="v")
      save = saver_module.Saver({"v": v}, max_to_keep=2)
      self.evaluate(variables.global_variables_initializer())
      # NOTE(review): inside context.eager_mode() this condition is False,
      # so the assertion below is skipped under this test's configuration.
      if not context.executing_eagerly():
        self.assertEqual([], save.last_checkpoints)

      s1 = save.save(None, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s1],
          save_dir=save_dir)

      s2 = save.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s1, s2],
          save_dir=save_dir)

      # Third save exceeds max_to_keep=2, so the oldest (s1) is deleted.
      s3 = save.save(None, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertCheckpointState(
          model_checkpoint_path=s3,
          all_model_checkpoint_paths=[s2, s3],
          save_dir=save_dir)

      # Create a second helper, identical to the first.
      save2 = saver_module.Saver({"v": v}, max_to_keep=2)
      save2.set_last_checkpoints(save.last_checkpoints)

      # Exercise the first helper.

      # Adding s2 again (old s2 is removed first, then new s2 appended)
      s2 = save.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s3, s2],
          save_dir=save_dir)

      # Adding s1 (s3 should now be deleted as oldest in list)
      s1 = save.save(None, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
          save_dir=save_dir)

      # The second helper still thinks its list is [s2, s3] and is unaware
      # of what the first helper did on disk afterwards.
      s2 = save2.save(None, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save2.last_checkpoints)
      # Created by the first helper.
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      # Deleted by the first helper.
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
1197
  @test_util.run_deprecated_v1
  def testNonSharded(self):
    """Verifies max_to_keep rotation and meta-graph cleanup (non-sharded).

    Exercises three savers over the same checkpoint directory: one that
    creates the files, one that shares its bookkeeping via
    set_last_checkpoints, and one with no knowledge of prior checkpoints.
    """
    save_dir = self._get_test_dir("max_to_keep_non_sharded")

    with self.cached_session() as sess:
      v = variables.VariableV1(10.0, name="v")
      save = saver_module.Saver({"v": v}, max_to_keep=2)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual([], save.last_checkpoints)

      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s1],
          save_dir=save_dir)

      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s1, s2],
          save_dir=save_dir)

      # Third save exceeds max_to_keep=2: s1 is deleted.
      s3 = save.save(sess, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertCheckpointState(
          model_checkpoint_path=s3,
          all_model_checkpoint_paths=[s2, s3],
          save_dir=save_dir)

      # Create a second helper, identical to the first.
      save2 = saver_module.Saver(saver_def=save.as_saver_def())
      save2.set_last_checkpoints(save.last_checkpoints)

      # Create a third helper, with the same configuration but no knowledge of
      # previous checkpoints.
      save3 = saver_module.Saver(saver_def=save.as_saver_def())

      # Exercise the first helper.

      # Adding s2 again (old s2 is removed first, then new s2 appended)
      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s1))
      self.assertFalse(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s1)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s3)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s2)))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s3, s2],
          save_dir=save_dir)

      # Adding s1 (s3 should now be deleted as oldest in list)
      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertFalse(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s3)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s2)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s1)))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
          save_dir=save_dir)

      # Exercise the second helper.

      # Adding s2 again (old s2 is removed first, then new s2 appended)
      s2 = save2.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s3, s2], save2.last_checkpoints)
      # Created by the first helper.
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s1)))
      # Deleted by the first helper.
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertFalse(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s3)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s2)))
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s3, s2],
          save_dir=save_dir)

      # Adding s1 (s3 should now be deleted as oldest in list)
      s1 = save2.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save2.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertFalse(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s3)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s2)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s1)))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
          save_dir=save_dir)

      # Exercise the third helper.

      # Adding s2 again (but helper is unaware of previous s2)
      s2 = save3.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s2], save3.last_checkpoints)
      # Created by the first helper.
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s1)))
      # Deleted by the first helper.
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertFalse(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s3)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s2)))
      # Even though the file for s1 exists, this saver isn't aware of it, which
      # is why it doesn't end up in the checkpoint state.
      self.assertCheckpointState(
          model_checkpoint_path=s2,
          all_model_checkpoint_paths=[s2],
          save_dir=save_dir)

      # Adding s1 (s3 should not be deleted because helper is unaware of it)
      s1 = save3.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s2, s1], save3.last_checkpoints)
      self.assertFalse(checkpoint_management.checkpoint_exists(s3))
      self.assertFalse(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s3)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s2)))
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertTrue(
          checkpoint_management.checkpoint_exists(
              checkpoint_management.meta_graph_filename(s1)))
      self.assertCheckpointState(
          model_checkpoint_path=s1,
          all_model_checkpoint_paths=[s2, s1],
          save_dir=save_dir)
1374
  def testSharded(self):
    """Verifies max_to_keep rotation for sharded checkpoints.

    V1 checkpoints produce 2 files per save (one per shard); V2 produces 4
    files matching the prefix glob — the counts asserted below reflect that.
    """
    save_dir = self._get_test_dir("max_to_keep_sharded")

    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      with sess.graph.device("/cpu:0"):
        v0 = variables.VariableV1(111, name="v0")
      with sess.graph.device("/cpu:1"):
        v1 = variables.VariableV1(222, name="v1")
      save = saver_module.Saver(
          {
              "v0": v0,
              "v1": v1
          }, sharded=True, max_to_keep=2)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual([], save.last_checkpoints)

      s1 = save.save(sess, os.path.join(save_dir, "s1"))
      self.assertEqual([s1], save.last_checkpoints)
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s1)))
      else:
        self.assertEqual(4, len(gfile.Glob(s1 + "*")))

      self.assertTrue(
          gfile.Exists(checkpoint_management.meta_graph_filename(s1)))

      s2 = save.save(sess, os.path.join(save_dir, "s2"))
      self.assertEqual([s1, s2], save.last_checkpoints)
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s1)))
      else:
        self.assertEqual(4, len(gfile.Glob(s1 + "*")))
      self.assertTrue(
          gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s2)))
      else:
        self.assertEqual(4, len(gfile.Glob(s2 + "*")))
      self.assertTrue(
          gfile.Exists(checkpoint_management.meta_graph_filename(s2)))

      # Third save exceeds max_to_keep=2: every file of s1 must be gone.
      s3 = save.save(sess, os.path.join(save_dir, "s3"))
      self.assertEqual([s2, s3], save.last_checkpoints)
      self.assertEqual(0, len(gfile.Glob(s1 + "*")))
      self.assertFalse(
          gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s2)))
      else:
        self.assertEqual(4, len(gfile.Glob(s2 + "*")))
      self.assertTrue(
          gfile.Exists(checkpoint_management.meta_graph_filename(s2)))
      if save._write_version is saver_pb2.SaverDef.V1:
        self.assertEqual(2, len(gfile.Glob(s3)))
      else:
        self.assertEqual(4, len(gfile.Glob(s3 + "*")))
      self.assertTrue(
          gfile.Exists(checkpoint_management.meta_graph_filename(s3)))
1435
1436  def testNoMaxToKeep(self):
1437    save_dir = self._get_test_dir("no_max_to_keep")
1438    save_dir2 = self._get_test_dir("max_to_keep_0")
1439
1440    with self.cached_session() as sess:
1441      v = variables.VariableV1(10.0, name="v")
1442      self.evaluate(variables.global_variables_initializer())
1443
1444      # Test max_to_keep being None.
1445      save = saver_module.Saver({"v": v}, max_to_keep=None)
1446      self.assertEqual([], save.last_checkpoints)
1447      s1 = save.save(sess, os.path.join(save_dir, "s1"))
1448      self.assertEqual([], save.last_checkpoints)
1449      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1450      s2 = save.save(sess, os.path.join(save_dir, "s2"))
1451      self.assertEqual([], save.last_checkpoints)
1452      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
1453
1454      # Test max_to_keep being 0.
1455      save2 = saver_module.Saver({"v": v}, max_to_keep=0)
1456      self.assertEqual([], save2.last_checkpoints)
1457      s1 = save2.save(sess, os.path.join(save_dir2, "s1"))
1458      self.assertEqual([], save2.last_checkpoints)
1459      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1460      s2 = save2.save(sess, os.path.join(save_dir2, "s2"))
1461      self.assertEqual([], save2.last_checkpoints)
1462      self.assertTrue(checkpoint_management.checkpoint_exists(s2))
1463
1464  def testNoMetaGraph(self):
1465    save_dir = self._get_test_dir("no_meta_graph")
1466
1467    with self.cached_session() as sess:
1468      v = variables.VariableV1(10.0, name="v")
1469      save = saver_module.Saver({"v": v})
1470      self.evaluate(variables.global_variables_initializer())
1471
1472      s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)
1473      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
1474      self.assertFalse(
1475          gfile.Exists(checkpoint_management.meta_graph_filename(s1)))
1476
1477
class KeepCheckpointEveryNHoursTest(test.TestCase):
  """Tests the keep_checkpoint_every_n_hours option of saver_module.Saver."""

  def _get_test_dir(self, dirname):
    """Creates (if needed) and returns a scratch directory for this test."""
    path = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(path)
    return path

  @test_util.run_in_graph_and_eager_modes
  @test.mock.patch.object(saver_module, "time")
  def testNonSharded(self, mock_time):
    save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")

    with self.cached_session() as sess:
      v = variable_scope.variable([10.0], name="v")
      # Initialize up-front so the ~0.5s overhead of the first Run() call
      # (in fastbuild mode) does not perturb the timing below.
      self.evaluate(variables.global_variables_initializer())
      # Keep the 2 most recent checkpoints, plus one roughly every 0.7
      # (mocked) seconds.
      start = time.time()
      mock_time.time.return_value = start
      saver = saver_module.Saver(
          {"v": v}, max_to_keep=2, keep_checkpoint_every_n_hours=0.7 / 3600)
      self.assertEqual([], saver.last_checkpoints)

      def checkpoint(name):
        # Saves a checkpoint under `name` inside save_dir.
        return saver.save(sess, os.path.join(save_dir, name))

      # Advance the mocked clock by a full second so s1 is old enough to be
      # preserved by the keep_checkpoint_every_n_hours policy later on.
      mock_time.time.return_value = start + 1.0
      s1 = checkpoint("s1")
      self.assertEqual([s1], saver.last_checkpoints)

      s2 = checkpoint("s2")
      self.assertEqual([s1, s2], saver.last_checkpoints)

      # last_checkpoints is [s1, s2].  Saving s3 would normally delete s1
      # (max_to_keep=2), but s1 is older than 0.7s, so it must be kept on
      # disk even though it drops out of last_checkpoints.  (s1's presence
      # is only asserted at the end, to reduce timing variance here.)
      s3 = checkpoint("s3")
      self.assertEqual([s2, s3], saver.last_checkpoints)

      # last_checkpoints is now [s2, s3] with s1 still on disk.  Saving s4
      # evicts s2; s2 is very close in time to the already-kept s1, so this
      # time the evicted checkpoint really is deleted.
      s4 = checkpoint("s4")
      self.assertEqual([s3, s4], saver.last_checkpoints)

      # s1 survived the keep-every-n-hours policy; s2 did not.
      self.assertTrue(checkpoint_management.checkpoint_exists(s1))
      self.assertFalse(checkpoint_management.checkpoint_exists(s2))
      self.assertTrue(checkpoint_management.checkpoint_exists(s3))
      self.assertTrue(checkpoint_management.checkpoint_exists(s4))
1535
1536
class SaveRestoreWithVariableNameMap(test.TestCase):
  """Tests saving/restoring variables under remapped checkpoint names."""

  def _testNonReshape(self, variable_op):
    save_path = os.path.join(self.get_temp_dir(), "non_reshape")

    # Phase 1: save two variables under "save_prefix/"-prefixed names.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(10.0, name="v0")
      v1 = variable_op(20.0, name="v1")
      saver = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
      self.evaluate(variables.global_variables_initializer())

      # Sanity-check initialization before saving.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

      # Save under the mapped names; save() returns the checkpoint prefix.
      returned_path = saver.save(sess, save_path)
      self.assertTrue(isinstance(returned_path, six.string_types))
      self.assertEqual(save_path, returned_path)

      # The unmapped names must not be present in the checkpoint.
      saver = saver_module.Saver({"v0": v0, "v1": v1})
      with self.assertRaisesOpError("not found in checkpoint"):
        saver.restore(sess, save_path)

    # Phase 2: a fresh graph restores through the same name map.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(-1.0, name="v0")
      v1 = variable_op(-1.0, name="v1")

      if not context.executing_eagerly():
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(v0)
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(v1)

      saver = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
      saver.restore(sess, save_path)

      # The restored values must match what was saved.
      if not context.executing_eagerly():
        self.assertEqual(10.0, self.evaluate(v0))
        self.assertEqual(20.0, self.evaluate(v1))

    # Phase 3: the graph-side node names may differ too — only the name map
    # has to match the checkpoint.
    with self.session(graph=ops_lib.Graph()) as sess:
      v0 = variable_op(-1.0, name="restore_prefix/v0")
      v1 = variable_op(-1.0, name="restore_prefix/v1")

      if not context.executing_eagerly():
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(v0)
        with self.assertRaisesOpError("uninitialized"):
          self.evaluate(v1)

      # Restore the saved values into the renamed parameter nodes.
      saver = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
      saver.restore(sess, save_path)

      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

  @test_util.run_in_graph_and_eager_modes
  def testNonReshapeResourceVariable(self):
    self._testNonReshape(resource_variable_ops.ResourceVariable)

  def testNonReshapeVariable(self):
    self._testNonReshape(variables.Variable)
1611
1612
1613class MetaGraphTest(test.TestCase):
1614
1615  def _get_test_dir(self, dirname):
1616    test_dir = os.path.join(self.get_temp_dir(), dirname)
1617    gfile.MakeDirs(test_dir)
1618    return test_dir
1619
  @test_util.run_v1_only("b/120545219")
  def testAddCollectionDef(self):
    """Round-trips a MetaGraphDef carrying many collection value types.

    Builds a graph with control flow, a queue runner, and collections of
    ints, floats, strings, variables, and protos (stored as str, bytes and
    Any), exports it to a MetaGraphDef, re-imports it into a fresh graph,
    and checks that the re-exported MetaGraphDef equals the original.
    """
    test_dir = self._get_test_dir("good_collection")
    filename = os.path.join(test_dir, "metafile")
    with self.cached_session():
      # Creates a graph.
      v0 = variables.VariableV1(1.0, name="v0")
      # cond/while are added purely so their control-flow contexts end up in
      # the exported collections.
      control_flow_ops.cond(
          math_ops.less(v0, 10), lambda: math_ops.add(v0, 1),
          lambda: math_ops.subtract(v0, 1))
      control_flow_ops.while_loop(lambda i: math_ops.less(i, 10),
                                  lambda i: math_ops.add(i, 1), [v0])
      var = variables.VariableV1(constant_op.constant(0, dtype=dtypes.int64))
      count_up_to = var.count_up_to(3)
      input_queue = data_flow_ops.FIFOQueue(
          30, dtypes.float32, shared_name="collection_queue")
      qr = queue_runner_impl.QueueRunner(input_queue, [count_up_to])
      variables.global_variables_initializer()
      # Creates a saver.
      save = saver_module.Saver({"v0": v0})
      # Adds a set of collections.
      ops_lib.add_to_collection("int_collection", 3)
      ops_lib.add_to_collection("float_collection", 3.5)
      ops_lib.add_to_collection("string_collection", "hello")
      ops_lib.add_to_collection("variable_collection", v0)
      # Add QueueRunners.
      queue_runner_impl.add_queue_runner(qr)
      # Adds user_defined proto in three formats: string, bytes and Any.
      queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")
      ops_lib.add_to_collection("user_defined_string_collection",
                                str(queue_runner))
      ops_lib.add_to_collection("user_defined_bytes_collection",
                                queue_runner.SerializeToString())
      any_buf = Any()
      any_buf.Pack(queue_runner)
      ops_lib.add_to_collection("user_defined_any_collection", any_buf)

      # Generates MetaGraphDef.
      meta_graph_def = save.export_meta_graph(filename)
      self.assertTrue(meta_graph_def.HasField("saver_def"))
      self.assertTrue(meta_graph_def.HasField("graph_def"))
      self.assertTrue(meta_graph_def.HasField("meta_info_def"))
      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "")
      self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version,
                          "")
      collection_def = meta_graph_def.collection_def
      # 8 collections are added explicitly above; the remaining entries are
      # presumably graph-internal (variables, queue runners, control-flow
      # contexts) — 12 total.  NOTE(review): count is tied to the exact graph
      # construction sequence above.
      self.assertEqual(len(collection_def), 12)

    with ops_lib.Graph().as_default():
      # Restores from MetaGraphDef.
      new_saver = saver_module.import_meta_graph(filename)
      # Generates a new MetaGraphDef.
      new_meta_graph_def = new_saver.export_meta_graph()
      # It should be the same as the original.

    test_util.assert_meta_graph_protos_equal(
        self, meta_graph_def, new_meta_graph_def)
1677
1678  def testAddCollectionDefFails(self):
1679    with self.cached_session():
1680      # Creates a graph.
1681      v0 = variables.VariableV1(10.0, name="v0")
1682      # Creates a saver.
1683      save = saver_module.Saver({"v0": v0})
1684      # Generates MetaGraphDef.
1685      meta_graph_def = meta_graph_pb2.MetaGraphDef()
1686
1687      # Verifies that collection with unsupported key will not be added.
1688      ops_lib.add_to_collection(save, 3)
1689      save._add_collection_def(meta_graph_def, save)
1690      self.assertEqual(len(meta_graph_def.collection_def), 0)
1691
1692      # Verifies that collection where item type does not match expected
1693      # type will not be added.
1694      ops_lib.add_to_collection("int_collection", 3)
1695      ops_lib.add_to_collection("int_collection", 3.5)
1696      save._add_collection_def(meta_graph_def, "int_collection")
1697      self.assertEqual(len(meta_graph_def.collection_def), 0)
1698
1699  def _testMultiSaverCollectionSave(self, test_dir):
1700    filename = os.path.join(test_dir, "metafile")
1701    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
1702    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
1703    with self.session(graph=ops_lib.Graph()) as sess:
1704      # Creates a graph.
1705      v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
1706      v1 = variables.VariableV1(11.0, name="v1")
1707      # Creates 2 savers.
1708      saver0 = saver_module.Saver({"v0": v0}, name="saver0")
1709      saver1 = saver_module.Saver({"v1": v1}, name="saver1")
1710      ops_lib.add_to_collection("savers", saver0)
1711      ops_lib.add_to_collection("savers", saver1)
1712      self.evaluate(variables.global_variables_initializer())
1713      # Saves to different checkpoints.
1714      saver0.save(sess, saver0_ckpt)
1715      saver1.save(sess, saver1_ckpt)
1716      # Generates MetaGraphDef.
1717      meta_graph_def = saver_module.export_meta_graph(filename)
1718      meta_graph_def0 = saver0.export_meta_graph()
1719      meta_graph_def1 = saver1.export_meta_graph()
1720
1721      # Verifies that there is no saver_def in meta_graph_def.
1722      self.assertFalse(meta_graph_def.HasField("saver_def"))
1723      # Verifies that there is saver_def in meta_graph_def0 and 1.
1724      self.assertTrue(meta_graph_def0.HasField("saver_def"))
1725      self.assertTrue(meta_graph_def1.HasField("saver_def"))
1726
1727      # Verifies SAVERS is saved as bytes_list for meta_graph_def.
1728      collection_def = meta_graph_def.collection_def["savers"]
1729      kind = collection_def.WhichOneof("kind")
1730      self.assertEqual(kind, "bytes_list")
1731      # Verifies that there are 2 entries in SAVERS collection.
1732      savers = getattr(collection_def, kind)
1733      self.assertEqual(2, len(savers.value))
1734
1735      # Verifies SAVERS collection is saved as bytes_list for meta_graph_def0.
1736      collection_def = meta_graph_def0.collection_def["savers"]
1737      kind = collection_def.WhichOneof("kind")
1738      self.assertEqual(kind, "bytes_list")
1739      # Verifies that there are 2 entries in SAVERS collection.
1740      savers = getattr(collection_def, kind)
1741      self.assertEqual(2, len(savers.value))
1742
1743  def _testMultiSaverCollectionRestore(self, test_dir):
1744    filename = os.path.join(test_dir, "metafile")
1745    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
1746    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
1747    with self.session(graph=ops_lib.Graph()) as sess:
1748      # Imports from meta_graph.
1749      saver_module.import_meta_graph(filename)
1750      # Retrieves SAVERS collection. Verifies there are 2 entries.
1751      savers = ops_lib.get_collection("savers")
1752      self.assertEqual(2, len(savers))
1753      # Retrieves saver0. Verifies that new_saver0 can restore v0, but not v1.
1754      new_saver0 = savers[0]
1755      new_saver0.restore(sess, saver0_ckpt)
1756      v0 = sess.graph.get_tensor_by_name("v0:0")
1757      v1 = sess.graph.get_tensor_by_name("v1:0")
1758      self.assertAllEqual([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
1759                          self.evaluate(v0))
1760      self.assertEqual([3, 2], v0.get_shape())
1761      self.assertEqual([], v1.get_shape())
1762      with self.assertRaisesWithPredicateMatch(
1763          errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
1764        self.evaluate(v1)
1765      # Retrieves saver1. Verifies that new_saver1 can restore v1.
1766      new_saver1 = savers[1]
1767      new_saver1.restore(sess, saver1_ckpt)
1768      v1 = sess.graph.get_tensor_by_name("v1:0")
1769      self.assertEqual(11.0, self.evaluate(v1))
1770
1771  @test_util.run_v1_only("b/120545219")
1772  def testMultiSaverCollection(self):
1773    test_dir = self._get_test_dir("saver_collection")
1774    self._testMultiSaverCollectionSave(test_dir)
1775    self._testMultiSaverCollectionRestore(test_dir)
1776
  @test_util.run_v1_only("b/120545219")
  def testClearExtraneousSavers(self):
    """export_meta_graph(clear_extraneous_savers=True) drops other savers.

    Registers two savers in the "savers" collection, then checks that a
    per-saver export with clear_extraneous_savers=True keeps only that
    saver in the collection and omits the other saver's graph nodes.
    """
    test_dir = self._get_test_dir("clear_extraneous_savers")
    filename = os.path.join(test_dir, "metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
    with self.session(graph=ops_lib.Graph()) as sess:
      # Creates a graph.
      v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
      v1 = variables.VariableV1(11.0, name="v1")

      # Creates 2 savers.
      saver0 = saver_module.Saver({"v0": v0}, name="saver0")
      saver1 = saver_module.Saver({"v1": v1}, name="saver1")
      ops_lib.add_to_collection("savers", saver0)
      ops_lib.add_to_collection("savers", saver1)
      self.evaluate(variables.global_variables_initializer())

      # Saves to different checkpoints.
      saver0.save(sess, saver0_ckpt)
      saver1.save(sess, saver1_ckpt)

      # Generates MetaGraphDef.  Only meta_graph_def1 clears extraneous
      # savers; meta_graph_def0 keeps both for comparison.
      meta_graph_def = saver_module.export_meta_graph(filename)
      meta_graph_def0 = saver0.export_meta_graph()
      meta_graph_def1 = saver1.export_meta_graph(clear_extraneous_savers=True)

      # Verifies that there is no saver_def in meta_graph_def.
      self.assertFalse(meta_graph_def.HasField("saver_def"))
      # Verifies that there is saver_def in meta_graph_def0 and 1.
      self.assertTrue(meta_graph_def0.HasField("saver_def"))
      self.assertTrue(meta_graph_def1.HasField("saver_def"))

      # Verifies SAVERS is saved as bytes_list for meta_graph_def.
      collection_def = meta_graph_def.collection_def["savers"]
      kind = collection_def.WhichOneof("kind")
      self.assertEqual(kind, "bytes_list")

      # Verifies that there are 2 entries in SAVERS collection.
      savers = getattr(collection_def, kind)
      self.assertEqual(2, len(savers.value))

      # Verifies SAVERS collection is saved as bytes_list for meta_graph_def1.
      collection_def = meta_graph_def1.collection_def["savers"]
      kind = collection_def.WhichOneof("kind")
      self.assertEqual(kind, "bytes_list")

      # Verifies that there is 1 entry in SAVERS collection: clearing removed
      # saver0 from saver1's export.
      savers = getattr(collection_def, kind)
      self.assertEqual(1, len(savers.value))

      # Verifies that saver0 graph nodes are omitted from the saver1 export.
      # NOTE: the exact counts (33 vs 21) are tied to the current graph
      # construction sequence above.
      self.assertEqual(33, len(meta_graph_def0.graph_def.node))
      self.assertEqual(21, len(meta_graph_def1.graph_def.node))
1831
1832  @test_util.run_deprecated_v1
1833  def testBinaryAndTextFormat(self):
1834    test_dir = self._get_test_dir("binary_and_text")
1835    filename = os.path.join(test_dir, "metafile")
1836    with self.session(graph=ops_lib.Graph()):
1837      # Creates a graph.
1838      variables.VariableV1(10.0, name="v0")
1839      # Exports the graph as binary format.
1840      saver_module.export_meta_graph(filename, as_text=False)
1841    with self.session(graph=ops_lib.Graph()):
1842      # Imports the binary format graph.
1843      saver = saver_module.import_meta_graph(filename)
1844      self.assertIsNotNone(saver)
1845      # Exports the graph as text format.
1846      saver.export_meta_graph(filename, as_text=True)
1847    with self.session(graph=ops_lib.Graph()):
1848      # Imports the text format graph.
1849      saver_module.import_meta_graph(filename)
1850      # Writes wrong contents to the file.
1851      graph_io.write_graph(saver.as_saver_def(),
1852                           os.path.dirname(filename),
1853                           os.path.basename(filename))
1854    with self.session(graph=ops_lib.Graph()):
1855      # Import should fail.
1856      with self.assertRaisesWithPredicateMatch(IOError,
1857                                               lambda e: "Cannot parse file"):
1858        saver_module.import_meta_graph(filename)
1859      # Deletes the file
1860      gfile.Remove(filename)
1861      with self.assertRaisesWithPredicateMatch(IOError,
1862                                               lambda e: "does not exist"):
1863        saver_module.import_meta_graph(filename)
1864
1865  @test_util.run_v1_only("b/120545219")
1866  def testSliceVariable(self):
1867    test_dir = self._get_test_dir("slice_saver")
1868    filename = os.path.join(test_dir, "metafile")
1869    with self.cached_session():
1870      v1 = variables.VariableV1([20.0], name="v1")
1871      v2 = variables.VariableV1([20.0], name="v2")
1872      v2._set_save_slice_info(
1873          variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))
1874
1875      # The names are different and will work.
1876      slice_saver = saver_module.Saver({"first": v1, "second": v2})
1877      self.evaluate(variables.global_variables_initializer())
1878      # Exports to meta_graph
1879      meta_graph_def = slice_saver.export_meta_graph(filename)
1880
1881    with ops_lib.Graph().as_default():
1882      # Restores from MetaGraphDef.
1883      new_saver = saver_module.import_meta_graph(filename)
1884      self.assertIsNotNone(new_saver)
1885      # Generates a new MetaGraphDef.
1886      new_meta_graph_def = new_saver.export_meta_graph()
1887      # It should be the same as the original.
1888      test_util.assert_meta_graph_protos_equal(self, meta_graph_def,
1889                                               new_meta_graph_def)
1890
  def _testGraphExtensionSave(self, test_dir):
    """Builds a 3-layer inference graph, checkpoints it, and exports it.

    The hidden-layer biases are deliberately routed through a cond and a
    while loop so that the exported MetaGraphDef exercises save/restore of
    control-flow contexts.  Writes a checkpoint ("saver0.ckpt") and a
    MetaGraphDef ("metafile") into `test_dir`; the "logits" tensor is
    published via a collection for the companion restore test.
    """
    filename = os.path.join(test_dir, "metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    # Creates an inference graph.
    # Hidden 1
    images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28])
    with ops_lib.name_scope("hidden1"):
      weights = variables.VariableV1(
          random_ops.truncated_normal(
              [28, 128], stddev=1.0 / math.sqrt(float(28))),
          name="weights")
      # The use of control_flow_ops.cond here is purely for adding test coverage
      # the save and restore of control flow context (which doesn't make any
      # sense here from a machine learning perspective).  The typical biases is
      # a simple Variable without the conditions.
      biases = variables.VariableV1(
          control_flow_ops.cond(
              math_ops.less(random.random(), 0.5),
              lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
          name="biases")
      hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases)
    # Hidden 2
    with ops_lib.name_scope("hidden2"):
      weights = variables.VariableV1(
          random_ops.truncated_normal(
              [128, 32], stddev=1.0 / math.sqrt(float(128))),
          name="weights")

      # The use of control_flow_ops.while_loop here is purely for adding test
      # coverage the save and restore of control flow context (which doesn't
      # make any sense here from a machine learning perspective).  The typical
      # biases is a simple Variable without the conditions.
      def loop_cond(it, _):
        # Two iterations of the bias-accumulation loop.
        return it < 2

      def loop_body(it, biases):
        biases += constant_op.constant(0.1, shape=[32])
        return it + 1, biases

      _, biases = control_flow_ops.while_loop(
          loop_cond, loop_body,
          [constant_op.constant(0),
           variables.VariableV1(array_ops.zeros([32]))])
      hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
    # Linear
    with ops_lib.name_scope("softmax_linear"):
      weights = variables.VariableV1(
          random_ops.truncated_normal(
              [32, 10], stddev=1.0 / math.sqrt(float(32))),
          name="weights")
      biases = variables.VariableV1(array_ops.zeros([10]), name="biases")
      logits = math_ops.matmul(hidden2, weights) + biases
      # Publish logits so the restore test can extend the graph from it.
      ops_lib.add_to_collection("logits", logits)
    init_all_op = variables.global_variables_initializer()

    with self.cached_session() as sess:
      # Initializes all the variables.
      self.evaluate(init_all_op)
      # Runs to logit.
      self.evaluate(logits)
      # Creates a saver.
      saver0 = saver_module.Saver()
      saver0.save(sess, saver0_ckpt)
      # Generates MetaGraphDef.
      saver0.export_meta_graph(filename)
1956
  def _testGraphExtensionRestore(self, test_dir):
    """Imports the inference graph, extends it with training ops, re-exports.

    Restores the checkpoint written by _testGraphExtensionSave, builds
    loss/optimizer/train ops on top of the imported "logits" tensor, runs
    one training step, and exports the extended graph to "train_metafile".
    """
    filename = os.path.join(test_dir, "metafile")
    train_filename = os.path.join(test_dir, "train_metafile")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    with self.session(graph=ops_lib.Graph()) as sess:
      # Restores from MetaGraphDef.
      new_saver = saver_module.import_meta_graph(filename)
      # Generates a new MetaGraphDef.
      new_saver.export_meta_graph()
      # Restores from checkpoint.
      new_saver.restore(sess, saver0_ckpt)
      # Adds loss and train.
      labels = constant_op.constant(0, dtypes.int32, shape=[100], name="labels")
      batch_size = array_ops.size(labels)
      labels = array_ops.expand_dims(labels, 1)
      indices = array_ops.expand_dims(math_ops.range(0, batch_size), 1)
      concated = array_ops.concat([indices, labels], 1)
      # One-hot encode the labels into a dense [batch_size, 10] tensor.
      onehot_labels = sparse_ops.sparse_to_dense(
          concated, array_ops.stack([batch_size, 10]), 1.0, 0.0)
      # "logits" was published by the save-side test via a collection.
      logits = ops_lib.get_collection("logits")[0]
      cross_entropy = nn_ops.softmax_cross_entropy_with_logits(
          labels=onehot_labels, logits=logits, name="xentropy")
      loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean")

      summary.scalar("loss", loss)
      # Creates the gradient descent optimizer with the given learning rate.
      optimizer = gradient_descent.GradientDescentOptimizer(0.01)

      # Runs train_op.
      train_op = optimizer.minimize(loss)
      # Publish train_op for _testRestoreFromTrainGraphWithControlContext.
      ops_lib.add_to_collection("train_op", train_op)

      # Runs train_op.
      self.evaluate(train_op)

      # Generates MetaGraphDef.
      saver_module.export_meta_graph(train_filename)
1994
1995  def _testRestoreFromTrainGraphWithControlContext(self, test_dir):
1996    train_filename = os.path.join(test_dir, "train_metafile")
1997    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
1998    with self.session(graph=ops_lib.Graph()) as sess:
1999      # Restores from MetaGraphDef.
2000      new_saver = saver_module.import_meta_graph(train_filename)
2001      # Restores from checkpoint.
2002      new_saver.restore(sess, saver0_ckpt)
2003      train_op = ops_lib.get_collection("train_op")[0]
2004      self.evaluate(train_op)
2005
2006  @test_util.run_deprecated_v1
2007  def testGraphExtension(self):
2008    test_dir = self._get_test_dir("graph_extension")
2009    self._testGraphExtensionSave(test_dir)
2010    self._testGraphExtensionRestore(test_dir)
2011    self._testRestoreFromTrainGraphWithControlContext(test_dir)
2012
  def _testGradientSerDes(self, graph_fn):
    """Tests that gradients can be computed after exporting and importing.

    Builds a graph, exports it, and verifies that it can be imported and the
    gradient can be built and run correctly.

    Args:
      graph_fn: takes a single float Tensor argument as input, outputs a single
        Tensor
    """
    test_dir = self._get_test_dir("nested_control_flow")
    filename = os.path.join(test_dir, "metafile")
    saver_ckpt = os.path.join(test_dir, "saver.ckpt")

    # Create while loop using `outer_body_fn`.
    with ops_lib.Graph().as_default():
      var = variables.VariableV1(0.0)
      # Record names so the same tensors can be looked up after import.
      var_name = var.name
      output = graph_fn(var)
      output_name = output.name
      init_op = variables.global_variables_initializer()

      # Generate a MetaGraphDef containing the while loop.
      with session.Session() as sess:
        self.evaluate(init_op)
        self.evaluate(output)
        saver = saver_module.Saver()
        saver.save(sess, saver_ckpt)
        saver.export_meta_graph(filename)

      # Build and run the gradients of the while loop. We use this below to
      # verify that the gradients are correct with an imported MetaGraphDef.
      grad = gradients_impl.gradients([output], [var])
      # Turn off constant folding to avoid breaking testNestedControlFlowSerDes.
      # It appears that a missing control dependency in the gradient graph
      # causes the fetch node to not be triggered.
      no_constfold_config = config_pb2.ConfigProto()
      no_constfold_config.graph_options.rewrite_options.constant_folding = (
          rewriter_config_pb2.RewriterConfig.OFF)
      with session.Session(config=no_constfold_config) as sess:
        self.evaluate(init_op)
        expected_grad_value = self.evaluate(grad)

    # Restore the MetaGraphDef into a new Graph.
    with ops_lib.Graph().as_default():
      # The session is only needed for the restore call; the imported graph
      # is the default graph and is used directly below.
      with session.Session() as sess:
        saver = saver_module.import_meta_graph(filename)
        saver.restore(sess, saver_ckpt)

      # Make sure we can still build gradients and get the same result.
      var = ops_lib.get_default_graph().get_tensor_by_name(var_name)
      output = ops_lib.get_default_graph().get_tensor_by_name(output_name)
      grad = gradients_impl.gradients([output], [var])

      init_op = variables.global_variables_initializer()

      # Reuse the constant-folding-off config from the first graph above.
      with session.Session(config=no_constfold_config) as sess:
        self.evaluate(init_op)
        actual_grad_value = self.evaluate(grad)
        self.assertEqual(expected_grad_value, actual_grad_value)
2073
2074  def _testWhileLoopAndGradientSerDes(self, outer_body_fn):
2075    # Build a while loop with `outer_body_fn`, export it, and verify that it can
2076    # be imported and the gradient can be built and run correctly.
2077    # pylint: disable=g-long-lambda
2078    return self._testGradientSerDes(
2079        lambda x: control_flow_ops.while_loop(
2080            lambda i, y: i < 5, outer_body_fn, [0, x])[1])
2081    # pylint: enable=g-long-lambda
2082
2083  def testNestedWhileLoopsSerDes(self):
2084    # Test two simple nested while loops.
2085    def body(i, x):
2086      _, r = control_flow_ops.while_loop(lambda j, y: j < 3,
2087                                         lambda j, y: (j + 1, y + x),
2088                                         [0, 0.0])
2089      return i + 1, x + r
2090    self._testWhileLoopAndGradientSerDes(body)
2091
2092  def testNestedControlFlowSerDes(self):
2093    # Test while loop in a cond in a while loop.
2094    # pylint: disable=g-long-lambda
2095    def body(i, x):
2096      cond_result = control_flow_ops.cond(
2097          i > 0,
2098          lambda: control_flow_ops.while_loop(
2099              lambda j, y: j < 3,
2100              lambda j, y: (j + 1, y + x),
2101              [0, 0.0])[1],
2102          lambda: x)
2103      return i + 1, cond_result
2104    # pylint: enable=g-long-lambda
2105    self._testWhileLoopAndGradientSerDes(body)
2106
2107  def testNestedCondsSerDes(self):
2108    # Test conds in a cond.
2109    # pylint: disable=g-long-lambda
2110    self._testGradientSerDes(lambda x: control_flow_ops.cond(
2111        x > 0,
2112        lambda: control_flow_ops.cond(x > 3,
2113                                      lambda: array_ops.identity(x),
2114                                      lambda: math_ops.multiply(x, 2.0)),
2115        lambda: control_flow_ops.cond(x < -3,
2116                                      lambda: constant_op.constant(1.0),
2117                                      lambda: math_ops.multiply(x, -1.0))))
2118    # pylint: enable=g-long-lambda
2119
  @test_util.run_v1_only("b/120545219")
  def testStrippedOpListDef(self):
    """Checks that export_meta_graph records a minimal stripped op list."""
    with self.cached_session():
      # Creates a graph.
      v0 = variables.VariableV1(0.0)
      var = variables.VariableV1(10.0)
      math_ops.add(v0, var)

      # A Defun adds a function to the graph; its ops ("Sub" below) must also
      # appear in the stripped op list.
      @function.Defun(dtypes.float32)
      def minus_one(x):
        return x - 1

      minus_one(array_ops.identity(v0))
      save = saver_module.Saver({"v0": v0})
      variables.global_variables_initializer()

      # Generates MetaGraphDef.
      meta_graph_def = save.export_meta_graph()
      # The expected lists below are alphabetically sorted; they include the
      # save/restore ops the Saver adds to the graph.
      ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op]
      if save._write_version is saver_pb2.SaverDef.V1:
        # V1 checkpoints are written with the legacy SaveSlices op.
        self.assertEqual(ops, [
            "Add", "Assign", "Const", "Identity", "NoOp",
            "PlaceholderWithDefault", "RestoreV2", "SaveSlices", "Sub",
            "VariableV2"
        ])
      else:
        # V2 checkpoints use SaveV2 instead.
        self.assertEqual(ops, [
            "Add", "Assign", "Const", "Identity", "NoOp",
            "PlaceholderWithDefault", "RestoreV2", "SaveV2", "Sub", "VariableV2"
        ])

      # Test calling stripped_op_list_for_graph directly
      op_list = meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def)
      self.assertEqual(ops, [o.name for o in op_list.op])
      # Stripped op defs must not carry their summaries/descriptions.
      for o in op_list.op:
        self.assertEqual(o.summary, "")
        self.assertEqual(o.description, "")
2157
  @test_util.run_deprecated_v1
  def testStripDefaultValuedAttrs(self):
    """Verifies that default valued attrs are stripped, unless disabled."""

    # With strip_default_attrs enabled, attributes "T" (float32) and "Tout"
    # (complex64) in the "Complex" op must be removed.
    with self.cached_session():
      real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
      # The node named "complex" is the op whose attrs are inspected below.
      math_ops.complex(real_num, imag_num, name="complex")

      save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
      variables.global_variables_initializer()

      meta_graph_def = save.export_meta_graph(strip_default_attrs=True)
      node_def = test_util.get_node_def_from_graph("complex",
                                                   meta_graph_def.graph_def)
      self.assertNotIn("T", node_def.attr)
      self.assertNotIn("Tout", node_def.attr)

    # With strip_default_attrs disabled, attributes "T" (float32) and "Tout"
    # (complex64) in the "Complex" op must *not* be removed, even if they map
    # to their defaults.
    # A fresh graph is used so the node names don't collide with the first
    # half of the test.
    with self.session(graph=ops_lib.Graph()):
      real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")

      save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
      variables.global_variables_initializer()

      meta_graph_def = save.export_meta_graph(strip_default_attrs=False)
      node_def = test_util.get_node_def_from_graph("complex",
                                                   meta_graph_def.graph_def)
      self.assertIn("T", node_def.attr)
      self.assertIn("Tout", node_def.attr)
2194
  @test_util.run_deprecated_v1
  def testImportIntoNamescope(self):
    """Imports a saved meta graph under an explicit import_scope."""
    # Test that we can import a meta graph into a namescope.
    test_dir = self._get_test_dir("import_into_namescope")
    filename = os.path.join(test_dir, "ckpt")
    image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
    label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
    with session.Session() as sess:
      # Builds and checkpoints a small softmax classifier with an Adam
      # training op named "optimize".
      weights = variables.VariableV1(
          random_ops.random_uniform([784, 10]), name="weights")
      bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
      logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
      nn_ops.softmax(logit, name="prediction")
      cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
                                                      logits=logit, name="cost")
      adam.AdamOptimizer().minimize(cost, name="optimize")
      saver = saver_module.Saver()
      self.evaluate(variables.global_variables_initializer())
      saver.save(sess, filename)

    graph = ops_lib.Graph()
    with session.Session(graph=graph) as sess:
      new_saver = saver_module.import_meta_graph(
          filename + ".meta", graph=graph, import_scope="new_model")
      new_saver.restore(sess, filename)
      # All imported ops/tensors now live under the "new_model/" prefix, so
      # the training op can be fetched and the placeholders fed by their
      # scoped names.
      sess.run(["new_model/optimize"], {
          "new_model/image:0": np.random.random([1, 784]),
          "new_model/label:0": np.random.randint(
              10, size=[1, 10])
      })
2225
  def testImportIntoNamescopeWithoutVariables(self):
    """import_meta_graph returns a Saver only when there is work to restore."""
    # Save a simple graph that contains no variables into a checkpoint.
    test_dir = self._get_test_dir("no_vars_graph")
    filename = os.path.join(test_dir, "ckpt")
    graph_1 = ops_lib.Graph()
    with session.Session(graph=graph_1) as sess:
      constant_op.constant([1, 2, 3], name="x")
      constant_op.constant([1, 2, 3], name="y")
      # allow_empty is required since the graph has no variables to save.
      saver = saver_module.Saver(allow_empty=True)
      saver.save(sess, filename)

    # Create a fresh graph.
    graph_2 = ops_lib.Graph()
    with session.Session(graph=graph_2) as sess:
      # Restore the above checkpoint under scope "subgraph_1".
      new_saver_1 = saver_module.import_meta_graph(
          filename + ".meta", graph=graph_2, import_scope="subgraph_1")
      # There are no variables to restore, so import_meta_graph should not
      # return a Saver.
      self.assertIsNone(new_saver_1)

      # Create a variable in graph_2 under scope "my_scope".
      variables.VariableV1(array_ops.zeros([10]), name="my_scope/my_var")
      self.evaluate(variables.global_variables_initializer())
      # Restore the checkpoint into a different scope "subgraph_2".
      new_saver_2 = saver_module.import_meta_graph(
          filename + ".meta", graph=graph_2, import_scope="subgraph_2")
      # Because the variable does not live in scope "subgraph_2",
      # import_meta_graph should not attempt to restore the variable. So,
      # import_meta_graph still won't return a Saver instance.
      self.assertIsNone(new_saver_2)

      # However, if we restore the checkpoint under scope "my_scope",
      # import_meta_graph will detect the variable and return a Saver for
      # restoring it. This should happen even when the variable does not
      # originate from graph_1.
      new_saver_3 = saver_module.import_meta_graph(
          filename + ".meta", graph=graph_2, import_scope="my_scope")
      self.assertIsInstance(new_saver_3, saver_module.Saver)
2265
  @test_util.run_deprecated_v1
  def testImportIntoImplicitNamescope(self):
    """Imports a meta graph inside an active name_scope (no import_scope)."""
    # Test that we can import a meta graph into an implicit namescope.
    test_dir = self._get_test_dir("import_into_namescope")
    filename = os.path.join(test_dir, "ckpt")
    image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
    label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
    with session.Session() as sess:
      # Builds and checkpoints a small softmax classifier with an Adam
      # training op named "optimize".
      weights = variables.VariableV1(
          random_ops.random_uniform([784, 10]), name="weights")
      bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
      logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
      nn_ops.softmax(logit, name="prediction")
      cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
                                                      logits=logit, name="cost")
      adam.AdamOptimizer().minimize(cost, name="optimize")
      saver = saver_module.Saver()
      self.evaluate(variables.global_variables_initializer())
      saver.save(sess, filename)

    graph = ops_lib.Graph()
    with session.Session(graph=graph) as sess:
      # Unlike testImportIntoNamescope, the scope comes from the surrounding
      # name_scope rather than the import_scope argument; the imported ops
      # still end up under "new_model/".
      with ops_lib.name_scope("new_model"):
        new_saver = saver_module.import_meta_graph(
            filename + ".meta", graph=graph)

      new_saver.restore(sess, filename)
      sess.run(["new_model/optimize"], {
          "new_model/image:0": np.random.random([1, 784]),
          "new_model/label:0": np.random.randint(
              10, size=[1, 10])
      })
2298
  def testClearDevicesOnImport(self):
    """import_meta_graph(clear_devices=True) drops unsatisfiable devices."""
    # Test that we import a graph without its devices and run successfully.
    with ops_lib.Graph().as_default():
      # Pin the whole model to a GPU on a remote "ps" job that does not exist
      # in this test's environment.
      with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
        image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
        label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
        weights = variables.VariableV1(
            random_ops.random_uniform([784, 10]), name="weights")
        bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
        logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
        nn_ops.softmax(logit, name="prediction")
        cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
                                                        logits=logit)
        adam.AdamOptimizer().minimize(cost, name="optimize")
      meta_graph_def = saver_module.export_meta_graph()

    with session.Session(graph=ops_lib.Graph()) as sess:
      # With devices kept, running the graph must fail.
      saver_module.import_meta_graph(
          meta_graph_def, clear_devices=False, import_scope="new_model")
      # Device refers to GPU, which is not available here.
      with self.assertRaises(errors_impl.InvalidArgumentError):
        self.evaluate(variables.global_variables_initializer())

    with session.Session(graph=ops_lib.Graph()) as sess:
      # With devices cleared at import time, the graph runs locally.
      saver_module.import_meta_graph(
          meta_graph_def, clear_devices=True, import_scope="new_model")
      self.evaluate(variables.global_variables_initializer())
      sess.run(["new_model/optimize"], {
          "new_model/image:0": np.random.random([1, 784]),
          "new_model/label:0": np.random.randint(
              10, size=[1, 10])
      })
2331
  def testClearDevicesOnExport(self):
    """export_meta_graph(clear_devices=True) drops devices before import."""
    # Test that we export a graph without its devices and run successfully.
    with ops_lib.Graph().as_default():
      # Pin the whole model to a GPU on a remote "ps" job that does not exist
      # in this test's environment.
      with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
        image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
        label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
        weights = variables.VariableV1(
            random_ops.random_uniform([784, 10]), name="weights")
        bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
        logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
        nn_ops.softmax(logit, name="prediction")
        cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
                                                        logits=logit)
        adam.AdamOptimizer().minimize(cost, name="optimize")
      # Devices are stripped at export time, so a plain import suffices.
      meta_graph_def = saver_module.export_meta_graph(clear_devices=True)
      graph_io.write_graph(meta_graph_def, self.get_temp_dir(),
                           "meta_graph.pbtxt")

    with session.Session(graph=ops_lib.Graph()) as sess:
      saver_module.import_meta_graph(meta_graph_def, import_scope="new_model")
      self.evaluate(variables.global_variables_initializer())
      sess.run(["new_model/optimize"], {
          "new_model/image:0": np.random.random([1, 784]),
          "new_model/label:0": np.random.randint(
              10, size=[1, 10])
      })
2358
  def testPreserveDatasetAndFunctions(self):
    """MetaGraphDef export/import preserves dataset iterators and functions."""
    with ops_lib.Graph().as_default() as g:
      # A one-shot iterator over the squares 0, 1, 4, ..., 81; the map
      # function becomes a graph function that must survive export/import.
      dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x)
      iterator = dataset_ops.make_one_shot_iterator(dataset)
      next_element = iterator.get_next()
      _ = array_ops.identity(next_element, name="output")

      # Generate three MetaGraphDef protos using different code paths.
      meta_graph_def_simple = saver_module.export_meta_graph()
      meta_graph_def_devices_cleared = saver_module.export_meta_graph(
          clear_devices=True)
      meta_graph_def_from_graph_def = saver_module.export_meta_graph(
          clear_devices=True, graph_def=g.as_graph_def())

    for meta_graph_def in [meta_graph_def_simple,
                           meta_graph_def_devices_cleared,
                           meta_graph_def_from_graph_def]:
      with session.Session(graph=ops_lib.Graph()) as sess:
        saver_module.import_meta_graph(meta_graph_def, import_scope="new_model")
        self.evaluate(variables.global_variables_initializer())
        # The imported iterator must yield the same ten squares...
        for i in range(10):
          self.assertEqual(i * i, sess.run("new_model/output:0"))
        # ...and then signal exhaustion like the original would.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run("new_model/output:0")
2383
2384
class CheckpointReaderTest(test.TestCase):
  """Tests for reading checkpoints with pywrap_tensorflow.NewCheckpointReader.

  `_WRITE_VERSION` selects the checkpoint format to write;
  CheckpointReaderForV2Test re-runs these tests against the V2 format.
  """

  _WRITE_VERSION = saver_pb2.SaverDef.V1

  @test_util.run_deprecated_v1
  def testDebugString(self):
    """Saves two variables and exercises the reader's introspection APIs."""
    # Builds a graph.
    v0 = variables.VariableV1(
        [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
    v1 = variables.VariableV1(
        [[[1], [2]], [[3], [4]], [[5], [6]]], dtype=dtypes.float32, name="v1")
    init_all_op = variables.global_variables_initializer()
    save = saver_module.Saver(
        {
            "v0": v0,
            "v1": v1
        }, write_version=self._WRITE_VERSION)
    # Embed the write version in the filename so V1 and V2 runs don't
    # clobber each other's checkpoints.
    save_path = os.path.join(self.get_temp_dir(),
                             "ckpt_for_debug_string" + str(self._WRITE_VERSION))
    with self.cached_session() as sess:
      self.evaluate(init_all_op)
      # Saves a checkpoint.
      save.save(sess, save_path)

      # Creates a reader.
      reader = pywrap_tensorflow.NewCheckpointReader(save_path)
      # Verifies that the tensors exist.
      self.assertTrue(reader.has_tensor("v0"))
      self.assertTrue(reader.has_tensor("v1"))
      debug_string = reader.debug_string()
      # Verifies that debug string contains the right strings.  assertIn
      # (rather than assertTrue(... in ...)) reports the actual debug string
      # on failure.
      self.assertIn(compat.as_bytes("v0 (DT_FLOAT) [2,3]"), debug_string)
      self.assertIn(compat.as_bytes("v1 (DT_FLOAT) [3,2,1]"), debug_string)
      # Verifies get_variable_to_shape_map() returns the correct information.
      var_map = reader.get_variable_to_shape_map()
      self.assertEqual([2, 3], var_map["v0"])
      self.assertEqual([3, 2, 1], var_map["v1"])
      # Verifies get_tensor() returns the tensor value.
      v0_tensor = reader.get_tensor("v0")
      v1_tensor = reader.get_tensor("v1")
      self.assertAllEqual(v0.eval(), v0_tensor)
      self.assertAllEqual(v1.eval(), v1_tensor)
      # Verifies get_tensor() fails for non-existent tensors.
      with self.assertRaisesRegexp(errors.NotFoundError,
                                   "v3 not found in checkpoint"):
        reader.get_tensor("v3")

  def testNonexistentPath(self):
    """Opening a reader on a missing path raises NotFoundError."""
    with self.assertRaisesRegexp(errors.NotFoundError,
                                 "Unsuccessful TensorSliceReader"):
      pywrap_tensorflow.NewCheckpointReader("non-existent")
2436
2437
class CheckpointReaderForV2Test(CheckpointReaderTest):
  # Re-runs all CheckpointReaderTest cases against the V2 checkpoint format.
  _WRITE_VERSION = saver_pb2.SaverDef.V2
2440
2441
class WriteGraphTest(test.TestCase):
  """Tests for graph_io.write_graph."""

  def _get_test_dir(self, dirname):
    """Creates (if needed) and returns a scratch directory under temp."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def testWriteGraph(self):
    """write_graph writes the pbtxt at the requested path and returns it."""
    test_dir = self._get_test_dir("write_graph_dir")
    variables.VariableV1(
        [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
    expected_path = os.path.join(test_dir, "l1", "graph.pbtxt")
    actual_path = graph_io.write_graph(ops_lib.get_default_graph(),
                                       os.path.join(test_dir, "l1"),
                                       "graph.pbtxt")
    self.assertEqual(actual_path, expected_path)
    self.assertTrue(os.path.exists(actual_path))

  def testRecursiveCreate(self):
    """write_graph creates missing intermediate directories."""
    test_dir = self._get_test_dir("deep_dir")
    variables.VariableV1(
        [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
    expected_path = os.path.join(test_dir, "l1", "l2", "l3", "graph.pbtxt")
    actual_path = graph_io.write_graph(
        ops_lib.get_default_graph().as_graph_def(),
        os.path.join(test_dir, "l1", "l2", "l3"), "graph.pbtxt")
    self.assertEqual(actual_path, expected_path)
    self.assertTrue(os.path.exists(actual_path))
2469
2470
class ScopedGraphTest(test.TestCase):
  """Tests for exporting, importing, and copying sub-graphs by name scope."""

  def _get_test_dir(self, dirname):
    """Creates (if needed) and returns a scratch directory under temp."""
    test_dir = os.path.join(self.get_temp_dir(), dirname)
    gfile.MakeDirs(test_dir)
    return test_dir

  def _testScopedSave(self, test_dir, exported_filename, ckpt_filename):
    """Builds a 3-layer net, exports scope "hidden1", and checkpoints it."""
    graph = ops_lib.Graph()
    with graph.as_default():
      # Creates an inference graph.
      # Hidden 1
      images = constant_op.constant(
          1.2, dtypes.float32, shape=[100, 28], name="images")
      with ops_lib.name_scope("hidden1"):
        weights1 = variables.VariableV1(
            random_ops.truncated_normal(
                [28, 128], stddev=1.0 / math.sqrt(float(28))),
            name="weights")
        # The use of control_flow_ops.cond here is purely for adding test
        # coverage of the save and restore of control flow context (which
        # doesn't make any sense here from a machine learning perspective).
        # The typical biases is a simple Variable without the conditions.
        biases1 = variables.VariableV1(
            control_flow_ops.cond(
                math_ops.less(random.random(), 0.5),
                lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
            name="biases")
        hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1)

      # Hidden 2
      with ops_lib.name_scope("hidden2"):
        weights2 = variables.VariableV1(
            random_ops.truncated_normal(
                [128, 32], stddev=1.0 / math.sqrt(float(128))),
            name="weights")

        # The use of control_flow_ops.while_loop here is purely for adding
        # test coverage of the save and restore of control flow context
        # (which doesn't make any sense here from a machine learning
        # perspective).  The typical biases is a simple Variable without the
        # conditions.
        def loop_cond(it, _):
          return it < 2

        def loop_body(it, biases2):
          biases2 += constant_op.constant(0.1, shape=[32])
          return it + 1, biases2

        _, biases2 = control_flow_ops.while_loop(loop_cond, loop_body, [
            constant_op.constant(0), variables.VariableV1(array_ops.zeros([32]))
        ])
        hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)
      # Linear
      with ops_lib.name_scope("softmax_linear"):
        weights3 = variables.VariableV1(
            random_ops.truncated_normal(
                [32, 10], stddev=1.0 / math.sqrt(float(32))),
            name="weights")
        biases3 = variables.VariableV1(array_ops.zeros([10]), name="biases")
        logits = math_ops.matmul(hidden2, weights3) + biases3
        ops_lib.add_to_collection("logits", logits)

        # Adds user_defined proto in three formats: string, bytes and Any.
        # Any proto should just pass through.
        queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue")
        ops_lib.add_to_collection("user_defined_string_collection",
                                  str(queue_runner))
        ops_lib.add_to_collection("user_defined_bytes_collection",
                                  queue_runner.SerializeToString())
        any_buf = Any()
        any_buf.Pack(queue_runner)
        ops_lib.add_to_collection("user_defined_any_collection", any_buf)

      # Export only the "hidden1" scope; var_list maps scope-stripped tensor
      # names to the variables under that scope.
      _, var_list = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, exported_filename),
          graph=ops_lib.get_default_graph(),
          export_scope="hidden1")
      self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

    with self.session(graph=graph) as sess:
      self.evaluate(variables.global_variables_initializer())
      saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
      saver.save(sess, os.path.join(test_dir, ckpt_filename), write_state=False)

  def _testScopedRestore(self, test_dir, exported_filename,
                         new_exported_filename, ckpt_filename):
    """Imports the "hidden1" export under "new_hidden1" and restores it.

    NOTE(review): `new_exported_filename` is currently unused by this helper.
    """
    graph = ops_lib.Graph()
    # Create all the missing inputs.
    with graph.as_default():
      new_image = constant_op.constant(
          1.2, dtypes.float32, shape=[100, 28], name="images")
    # The exported subgraph took `images` as an unbound input; feed the new
    # constant in its place via input_map.
    var_list = meta_graph.import_scoped_meta_graph(
        os.path.join(test_dir, exported_filename),
        graph=graph,
        input_map={"$unbound_inputs_images": new_image},
        import_scope="new_hidden1")
    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
    hidden1 = graph.as_graph_element("new_hidden1/Relu:0")
    weights1 = graph.as_graph_element("new_hidden1/weights:0")
    biases1 = graph.as_graph_element("new_hidden1/biases:0")

    with graph.as_default():
      # Hidden 2
      with ops_lib.name_scope("hidden2"):
        weights = variables.VariableV1(
            random_ops.truncated_normal(
                [128, 32], stddev=1.0 / math.sqrt(float(128))),
            name="weights")

        # The use of control_flow_ops.while_loop here is purely for adding
        # test coverage of the save and restore of control flow context
        # (which doesn't make any sense here from a machine learning
        # perspective).  The typical biases is a simple Variable without the
        # conditions.
        def loop_cond(it, _):
          return it < 2

        def loop_body(it, biases):
          biases += constant_op.constant(0.1, shape=[32])
          return it + 1, biases

        _, biases = control_flow_ops.while_loop(loop_cond, loop_body, [
            constant_op.constant(0), variables.VariableV1(array_ops.zeros([32]))
        ])
        hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
      # Linear
      with ops_lib.name_scope("softmax_linear"):
        weights = variables.VariableV1(
            random_ops.truncated_normal(
                [32, 10], stddev=1.0 / math.sqrt(float(32))),
            name="weights")
        biases = variables.VariableV1(array_ops.zeros([10]), name="biases")
        logits = math_ops.matmul(hidden2, weights) + biases
        ops_lib.add_to_collection("logits", logits)

      # The rest of the variables.
      rest_variables = list(
          set(variables.global_variables()) - set(var_list.keys()))
      init_rest_op = variables.variables_initializer(rest_variables)

    with self.session(graph=graph) as sess:
      saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
      saver.restore(sess, os.path.join(test_dir, ckpt_filename))
      # Verify that we have restored weights1 and biases1.
      self.evaluate([weights1, biases1])
      # Initialize the rest of the variables and run logits.
      self.evaluate(init_rest_op)
      self.evaluate(logits)

  # Verifies that we can save the subgraph under "hidden1" and restore it
  # into "new_hidden1" in the new graph.
  @test_util.run_deprecated_v1
  def testScopedSaveAndRestore(self):
    test_dir = self._get_test_dir("scoped_export_import")
    ckpt_filename = "ckpt"
    self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename)
    self._testScopedRestore(test_dir, "exported_hidden1.pbtxt",
                            "exported_new_hidden1.pbtxt", ckpt_filename)

  # Verifies that we can copy the subgraph under "hidden1" and copy it
  # to different name scope in the same graph or different graph.
  @test_util.run_deprecated_v1
  def testCopyScopedGraph(self):
    test_dir = self._get_test_dir("scoped_copy")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    graph1 = ops_lib.Graph()
    with graph1.as_default():
      with ops_lib.name_scope("hidden1"):
        images = constant_op.constant(
            1.0, dtypes.float32, shape=[3, 2], name="images")
        weights1 = variables.VariableV1(
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
        biases1 = variables.VariableV1([0.1] * 3, name="biases")
        nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")

    # Run the graph and save scoped checkpoint.
    with self.session(graph=graph1) as sess:
      self.evaluate(variables.global_variables_initializer())
      _, var_list_1 = meta_graph.export_scoped_meta_graph(
          export_scope="hidden1")
      saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
      saver.save(sess, saver0_ckpt, write_state=False)

    # Expected relu output for the fixed inputs above.
    expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))

    # Verifies copy to the same graph with the same name fails.
    with graph1.as_default():
      with self.assertRaisesWithPredicateMatch(
          ValueError, lambda e: "need to be different" in str(e)):
        meta_graph.copy_scoped_meta_graph(
            from_scope="hidden1", to_scope="hidden1")

    # Verifies copy to the same graph.
    with graph1.as_default():
      var_list_2 = meta_graph.copy_scoped_meta_graph(
          from_scope="hidden1", to_scope="hidden2")

    with self.session(graph=graph1) as sess:
      saver1 = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
      saver1.restore(sess, saver0_ckpt)
      saver2 = saver_module.Saver(var_list=var_list_2, max_to_keep=1)
      saver2.restore(sess, saver0_ckpt)
      self.assertAllClose(expected, sess.run("hidden1/relu:0"))
      self.assertAllClose(expected, sess.run("hidden2/relu:0"))

    # Verifies copy to different graph.
    graph2 = ops_lib.Graph()
    new_var_list_1 = meta_graph.copy_scoped_meta_graph(
        from_scope="hidden1",
        to_scope="new_hidden1",
        from_graph=graph1,
        to_graph=graph2)

    with self.session(graph=graph2) as sess:
      saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)
      saver3.restore(sess, saver0_ckpt)
      self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))

  @test_util.run_deprecated_v1
  def testExportGraphDefWithScope(self):
    """Exports scope "hidden1" from an explicit graph_def, then copies it."""
    test_dir = self._get_test_dir("export_graph_def")
    saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
    graph1 = ops_lib.Graph()
    with graph1.as_default():
      with ops_lib.name_scope("hidden1"):
        images = constant_op.constant(
            1.0, dtypes.float32, shape=[3, 2], name="images")
        weights1 = variables.VariableV1(
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
        biases1 = variables.VariableV1([0.1] * 3, name="biases")
        nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")

    # Run the graph and save scoped checkpoint.
    with self.session(graph=graph1) as sess:
      self.evaluate(variables.global_variables_initializer())
      # Pass graph_def explicitly to exercise that code path.
      _, var_list_1 = meta_graph.export_scoped_meta_graph(
          graph_def=graph1.as_graph_def(), export_scope="hidden1")
      saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1)
      saver.save(sess, saver0_ckpt, write_state=False)

    # Expected relu output for the fixed inputs above.
    expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3))

    # Verifies that we can run successfully after restoring.
    graph2 = ops_lib.Graph()
    new_var_list_1 = meta_graph.copy_scoped_meta_graph(
        from_scope="hidden1",
        to_scope="new_hidden1",
        from_graph=graph1,
        to_graph=graph2)

    with self.session(graph=graph2) as sess:
      saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1)
      saver3.restore(sess, saver0_ckpt)
      self.assertAllClose(expected, sess.run("new_hidden1/relu:0"))

  @test_util.run_deprecated_v1
  def testSerializeSaverWithScope(self):
    """Savers stored in the SAVERS collection survive scoped graph copies."""
    test_dir = self._get_test_dir("export_graph_def")
    saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
    saver2_ckpt = os.path.join(test_dir, "saver2.ckpt")
    graph = ops_lib.Graph()
    with graph.as_default():
      # saver1 is created inside the "hidden1" scope...
      with ops_lib.name_scope("hidden1"):
        variable1 = variables.VariableV1([1.0], name="variable1")
        saver1 = saver_module.Saver(var_list=[variable1])
        graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1)

      # ...while saver2 is created outside "hidden2" but named into it.
      with ops_lib.name_scope("hidden2"):
        variable2 = variables.VariableV1([2.0], name="variable2")
      saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/")
      graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2)

    with self.session(graph=graph) as sess:
      self.evaluate(variables.global_variables_initializer())
      saver1.save(sess, saver1_ckpt, write_state=False)
      saver2.save(sess, saver2_ckpt, write_state=False)

    graph1 = ops_lib.Graph()
    var_dict1 = meta_graph.copy_scoped_meta_graph(
        from_scope="hidden1",
        to_scope="new_hidden1",
        from_graph=graph,
        to_graph=graph1)
    self.assertEqual(1, len(var_dict1))

    # The copied scope must carry its Saver along in the SAVERS collection.
    saver_list1 = graph1.get_collection(ops_lib.GraphKeys.SAVERS)
    self.assertEqual(1, len(saver_list1))

    with self.session(graph=graph1) as sess:
      saver_list1[0].restore(sess, saver1_ckpt)
      self.assertEqual(1.0, self.evaluate(var_dict1["variable1:0"]))

    graph2 = ops_lib.Graph()
    var_dict2 = meta_graph.copy_scoped_meta_graph(
        from_scope="hidden2",
        to_scope="new_hidden2",
        from_graph=graph,
        to_graph=graph2)
    self.assertEqual(1, len(var_dict2))

    saver_list2 = graph2.get_collection(ops_lib.GraphKeys.SAVERS)
    self.assertEqual(1, len(saver_list2))

    with self.session(graph=graph2) as sess:
      saver_list2[0].restore(sess, saver2_ckpt)
      self.assertEqual(2.0, self.evaluate(var_dict2["variable2:0"]))
2776
2777
class _OwnsAVariableSimple(trackable_base.Trackable):
  """Minimal Trackable wrapping one variable, savable with tf.train.Saver."""

  def __init__(self):
    self.non_dep_variable = variable_scope.get_variable(
        name="non_dep_variable", initializer=6., use_resource=True)

  def _gather_saveables_for_checkpoint(self):
    # Expose the wrapped variable under the standard checkpoint value key.
    saveables = {trackable_base.VARIABLE_VALUE_KEY: self.non_dep_variable}
    return saveables

  @property
  def name(self):
    # The Saver sorts saveables by name before parsing, so expose one.
    return self.non_dep_variable.name
2792
2793
class _MirroringSaveable(
    saver_module.BaseSaverBuilder.ResourceVariableSaveable):
  """Saveable that checkpoints one variable but restores into two."""

  def __init__(self, primary_variable, mirrored_variable, name):
    self._primary_variable = primary_variable
    self._mirrored_variable = mirrored_variable
    # Only the primary variable's value is written to the checkpoint.
    super(_MirroringSaveable, self).__init__(
        self._primary_variable, "", name)

  def restore(self, restored_tensors, restored_shapes):
    """Assign the single restored value to both variables."""
    restored_value, = restored_tensors
    assign_primary = self._primary_variable.assign(restored_value)
    assign_mirror = self._mirrored_variable.assign(restored_value)
    return control_flow_ops.group(assign_primary, assign_mirror)
2809
2810
class _OwnsMirroredVariables(trackable_base.Trackable):
  """Trackable whose checkpoint saveable restores into two variables."""

  def __init__(self):
    self.non_dep_variable = variable_scope.get_variable(
        name="non_dep_variable", initializer=6., use_resource=True)
    self.mirrored = variable_scope.get_variable(
        name="mirrored", initializer=15., use_resource=True)

  def _gather_saveables_for_checkpoint(self):
    # Bind the name as a default argument so it is captured at definition
    # time; the factory itself is invoked lazily by the saver.
    def _make_saveable(name=self.non_dep_variable.name):
      return _MirroringSaveable(self.non_dep_variable, self.mirrored, name)
    return {trackable_base.VARIABLE_VALUE_KEY: _make_saveable}

  @property
  def name(self):
    # The Saver sorts saveables by name before parsing, so expose one.
    return self.non_dep_variable.name
2832
2833
class NonLayerTrackable(trackable_tracking.AutoTrackable):
  """Trackable that is not a Keras Layer, owning a single scalar variable."""

  def __init__(self):
    super(NonLayerTrackable, self).__init__()
    self.a_variable = trackable_utils.add_variable(
        self, shape=[], name="a_variable")
2840
2841
class MyModel(training.Model):
  """Small two-layer Model used as a checkpointing fixture in these tests."""

  def __init__(self):
    super(MyModel, self).__init__()
    self._named_dense = core.Dense(1, use_bias=True)
    self._second = core.Dense(1, use_bias=False)
    # Trackables that are not Layers should still be tracked by the Model.
    self._non_layer = NonLayerTrackable()

  def call(self, values):
    # Two stacked dense layers; only the first has a bias term.
    hidden = self._named_dense(values)
    return self._second(hidden)
2855
2856
class TrackableCompatibilityTests(test.TestCase):
  """Tests that the name-based tf.train.Saver interoperates with Trackables."""

  # TODO(allenl): Track down python3 reference cycles in these tests.
  @test_util.run_in_graph_and_eager_modes
  def testNotSaveableButIsTrackable(self):
    """A Trackable without a custom SaveableObject still saves/restores."""
    v = _OwnsAVariableSimple()
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    # Both list-style and dict-style var_lists must accept the Trackable.
    for saver in (saver_module.Saver(var_list=[v]),
                  saver_module.Saver(var_list={"v": v})):
      with self.cached_session() as sess:
        self.evaluate(v.non_dep_variable.assign(42.))
        save_path = saver.save(sess, prefix)
        self.evaluate(v.non_dep_variable.assign(43.))
        saver.restore(sess, save_path)
        # The sentinel 43. must be overwritten by the checkpointed 42.
        self.assertEqual(42., self.evaluate(v.non_dep_variable))

  @test_util.run_in_graph_and_eager_modes
  def testMoreComplexSaveableReturned(self):
    """A Trackable returning a custom SaveableObject factory works too."""
    v = _OwnsMirroredVariables()
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    self.evaluate(v.non_dep_variable.assign(42.))
    for saver in (saver_module.Saver(var_list=[v]),
                  saver_module.Saver(var_list={"v": v})):
      with self.cached_session() as sess:
        save_path = saver.save(sess, prefix)
        self.evaluate(v.non_dep_variable.assign(43.))
        self.evaluate(v.mirrored.assign(44.))
        saver.restore(sess, save_path)
        # _MirroringSaveable restores the one saved value into both variables.
        self.assertEqual(42., self.evaluate(v.non_dep_variable))
        self.assertEqual(42., self.evaluate(v.mirrored))

  def testSingleTensorEvaluation(self):
    """A SaveSpec tensor callable is evaluated once per save, not on restore."""

    class _CountingSaveable(saver_module.BaseSaverBuilder.SaveableObject):
      """SaveableObject that counts evaluations of its save tensor."""

      def __init__(self, name):
        self.eval_count = 0
        def _tensor():
          # Invoked by the saver when it materializes the tensor to save.
          self.eval_count += 1
          return constant_op.constant([1.])
        dummy_op = constant_op.constant([2.])
        super(_CountingSaveable, self).__init__(
            dummy_op,
            [saver_module.BaseSaverBuilder.SaveSpec(
                _tensor, "", name, dtype=dummy_op.dtype)],
            name)

      def restore(self, restored_tensors, restored_shapes):
        """No-op restore; this saveable only tracks save-time evaluations."""
        pass

    with context.eager_mode():
      v = _CountingSaveable("foo")
      saver = saver_module.Saver(var_list=[v])
      test_dir = self.get_temp_dir()
      prefix = os.path.join(test_dir, "ckpt")
      with self.cached_session() as sess:
        save_path = saver.save(sess, prefix)
        self.assertEqual(1, v.eval_count)
        saver.restore(sess, save_path)
        # Restoring must not trigger another evaluation of the save tensor.
        self.assertEqual(1, v.eval_count)

  def _initialized_model(self):
    """Builds a model + Adam optimizer, trains one step, sets known values.

    Returns:
      A trackable_utils.Checkpoint rooting the model, optimizer and step.
    """
    input_value = constant_op.constant([[3.]])
    model = MyModel()
    optimizer = adam.AdamOptimizer(0.001)
    optimizer_step = training_util.get_or_create_global_step()
    root_trackable = trackable_utils.Checkpoint(
        optimizer=optimizer, model=model, optimizer_step=optimizer_step)
    train_op = optimizer.minimize(
        functools.partial(model, input_value),
        global_step=optimizer_step)
    self.evaluate(trackable_utils.gather_initializers(
        root_trackable))
    self.evaluate(train_op)
    # A regular variable, a slot variable, and a non-slot Optimizer variable
    # with known values to check when loading.
    self.evaluate(model._named_dense.bias.assign([1.]))
    self.evaluate(optimizer.get_slot(
        var=model._named_dense.bias, name="m").assign([2.]))
    beta1_power, _ = optimizer._get_beta_accumulators()
    self.evaluate(beta1_power.assign(3.))
    return root_trackable

  def _set_sentinels(self, root_trackable):
    """Overwrites the known values with sentinels so a restore is visible."""
    self.evaluate(root_trackable.model._named_dense.bias.assign([101.]))
    self.evaluate(
        root_trackable.optimizer.get_slot(
            var=root_trackable.model._named_dense.bias, name="m")
        .assign([102.]))
    beta1_power, _ = root_trackable.optimizer._get_beta_accumulators()
    self.evaluate(beta1_power.assign(103.))

  def _check_sentinels(self, root_trackable):
    """Asserts the values set by _initialized_model have been restored."""
    self.assertAllEqual(
        [1.], self.evaluate(root_trackable.model._named_dense.bias))
    self.assertAllEqual([2.], self.evaluate(
        root_trackable.optimizer.get_slot(
            var=root_trackable.model._named_dense.bias, name="m")))
    beta1_power, _ = root_trackable.optimizer._get_beta_accumulators()
    self.assertAllEqual(3., self.evaluate(beta1_power))

  def testVariableNotFoundErrorRaised(self):
    # Restore does some tricky exception handling to figure out if it should
    # load an object-based checkpoint. Tests that the exception handling isn't
    # too broad.
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    a = resource_variable_ops.ResourceVariable(1., name="a")
    b = resource_variable_ops.ResourceVariable(1., name="b")
    a_saver = saver_module.Saver([a])
    b_saver = saver_module.Saver([b])
    with self.cached_session() as sess:
      self.evaluate(a.initializer)
      save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
      # The checkpoint contains only "a"; restoring "b" must fail clearly.
      with self.assertRaisesRegexp(
          errors.NotFoundError, "Key b not found in checkpoint"):
        b_saver.restore(sess=sess, save_path=save_path)

      with self.assertRaises(errors.NotFoundError) as cs:
        b_saver.restore(sess=sess, save_path=save_path)

      # Make sure we don't have a confusing "During handling of the above
      # exception" block in Python 3.
      self.assertNotIn("NewCheckpointReader", cs.exception.message)

  @test_util.run_v1_only("b/120545219")
  def testGraphChangedForRestoreErrorRaised(self):
    """Restoring into a graph whose variable shape changed raises clearly."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    with ops_lib.Graph().as_default() as g:
      # Save a scalar variable named "a".
      a = variables.VariableV1(1., name="a")
      a_saver = saver_module.Saver([a])

      with self.session(graph=g) as sess:
        self.evaluate(a.initializer)
        save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)

    with ops_lib.Graph().as_default() as g:
      # Rebuild "a" as a vector; restore must detect the graph mismatch.
      a = variables.VariableV1([1.], name="a")
      a_saver = saver_module.Saver([a])
      with self.session(graph=g) as sess:
        with self.assertRaisesRegexp(
            errors.InvalidArgumentError,
            "a mismatch between the current graph and the graph"):
          a_saver.restore(sess=sess, save_path=save_path)

  def testLoadFromObjectBasedGraph(self):
    """A name-based Saver loads object-based checkpoints in graph mode."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    save_graph = ops_lib.Graph()
    with save_graph.as_default(), self.session(graph=save_graph) as sess:
      root = self._initialized_model()
      object_saver = trackable_utils.Checkpoint(root=root)
      save_path = object_saver.save(file_prefix=checkpoint_prefix)

      # An incompatible object-based checkpoint to check error messages
      var = resource_variable_ops.ResourceVariable(1., name="a")
      self.evaluate(var.initializer)
      second_saver = trackable_utils.Checkpoint(v=var)
      second_path = second_saver.save(file_prefix=os.path.join(
          checkpoint_directory, "second"))

    restore_graph = ops_lib.Graph()
    with restore_graph.as_default(), self.session(
        graph=restore_graph) as sess:
      root = self._initialized_model()
      self._set_sentinels(root)
      saver = saver_module.Saver()
      saver.restore(sess=sess, save_path=save_path)
      self._check_sentinels(root)
      before_second_restore_ops = restore_graph.get_operations()
      # Test that multiple restores do not pollute the graph
      saver.restore(sess=sess, save_path=save_path)
      self.assertEqual(before_second_restore_ops,
                       restore_graph.get_operations())
      with self.assertRaisesRegexp(errors.NotFoundError,
                                   "Could not find some variables"):
        saver.restore(sess=sess, save_path=second_path)

  def testLoadFromObjectBasedEager(self):
    """A name-based Saver loads object-based checkpoints in eager mode."""
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    save_graph = ops_lib.Graph()
    with save_graph.as_default(), self.session(graph=save_graph):
      root = self._initialized_model()
      object_saver = trackable_utils.Checkpoint(root=root)
      save_path = object_saver.save(file_prefix=checkpoint_prefix)

    with context.eager_mode():
      root = self._initialized_model()
      self._set_sentinels(root)
      # Eager mode has no graph collections, so list the variables explicitly.
      saver = saver_module.Saver(
          root.model.variables + root.optimizer.variables())
      saver.restore(sess=None, save_path=save_path)
      self._check_sentinels(root)
3059
3060
if __name__ == "__main__":
  # Run all tests in this module under the TensorFlow test runner.
  test.main()
3063