1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Tests for tensorflow.python.training.saver.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import math 23import os 24import random 25import time 26 27import numpy as np 28import six 29 30from google.protobuf.any_pb2 import Any 31 32from tensorflow.core.protobuf import config_pb2 33from tensorflow.core.protobuf import meta_graph_pb2 34from tensorflow.core.protobuf import queue_runner_pb2 35from tensorflow.core.protobuf import rewriter_config_pb2 36from tensorflow.core.protobuf import saver_pb2 37from tensorflow.python import pywrap_tensorflow 38from tensorflow.python.client import session 39from tensorflow.python.data.ops import dataset_ops 40from tensorflow.python.eager import context 41from tensorflow.python.framework import constant_op 42from tensorflow.python.framework import dtypes 43from tensorflow.python.framework import errors 44from tensorflow.python.framework import errors_impl 45from tensorflow.python.framework import function 46from tensorflow.python.framework import graph_io 47from tensorflow.python.framework import meta_graph 48from tensorflow.python.framework import ops as ops_lib 49from tensorflow.python.framework import test_util 50from tensorflow.python.keras.engine import training 51from tensorflow.python.keras.layers import core 52from tensorflow.python.lib.io import file_io 53from tensorflow.python.ops import array_ops 54from tensorflow.python.ops import control_flow_ops 55from tensorflow.python.ops import data_flow_ops 56from tensorflow.python.ops import gradients_impl 57from tensorflow.python.ops import math_ops 58from tensorflow.python.ops import nn_ops 59from tensorflow.python.ops import partitioned_variables 60from tensorflow.python.ops import random_ops 61from tensorflow.python.ops import resource_variable_ops 62from tensorflow.python.ops import sparse_ops 63from tensorflow.python.ops import variable_scope 64from tensorflow.python.ops import variables 65import tensorflow.python.ops.nn_grad # pylint: disable=unused-import 66from tensorflow.python.platform import gfile 67from tensorflow.python.platform import test 68from tensorflow.python.summary import summary 69from tensorflow.python.training import adam 70from tensorflow.python.training import checkpoint_management 71from tensorflow.python.training import gradient_descent 72from tensorflow.python.training import queue_runner_impl 73from tensorflow.python.training import saver as saver_module 74from tensorflow.python.training import saver_test_utils 75from tensorflow.python.training import training_util 76from tensorflow.python.training.tracking import base as trackable_base 77from tensorflow.python.training.tracking import tracking as trackable_tracking 78from tensorflow.python.training.tracking import util as trackable_utils 79from tensorflow.python.util import compat 80 81 82class SaverTest(test.TestCase): 83 84 def basicSaveRestore(self, variable_op): 85 save_path = os.path.join(self.get_temp_dir(), "basic_save_restore") 86 87 with self.session(graph=ops_lib.Graph()) as sess: 88 # Build a graph with 2 parameter nodes, and Save and 89 # Restore nodes for them. 90 v0 = variable_op(10.0, name="v0") 91 v1 = variable_op(20.0, name="v1") 92 v2 = saver_test_utils.CheckpointedOp(name="v2") 93 v2_init = v2.insert("k1", 30.0) 94 95 # Initialize all variables 96 if not context.executing_eagerly(): 97 self.evaluate([variables.global_variables_initializer(), v2_init]) 98 99 # Check that the parameter nodes have been initialized. 100 self.assertEqual(10.0, self.evaluate(v0)) 101 self.assertEqual(20.0, self.evaluate(v1)) 102 self.assertEqual(b"k1", self.evaluate(v2.keys())) 103 self.assertEqual(30.0, self.evaluate(v2.values())) 104 105 # Save the initialized values in the file at "save_path" 106 save = saver_module.Saver( 107 { 108 "v0": v0, 109 "v1": v1, 110 "v2": v2.saveable 111 }, restore_sequentially=True) 112 val = save.save(sess, save_path) 113 self.assertTrue(isinstance(val, six.string_types)) 114 self.assertEqual(save_path, val) 115 116 # Start a second session. In that session the parameter nodes 117 # have not been initialized either. 118 with self.session(graph=ops_lib.Graph()) as sess: 119 v0 = variable_op(-1.0, name="v0") 120 v1 = variable_op(-1.0, name="v1") 121 v2 = saver_test_utils.CheckpointedOp(name="v2") 122 123 # Assert that the variables are not initialized. 124 if not context.executing_eagerly(): 125 self.assertEqual( 126 len(variables.report_uninitialized_variables().eval()), 2) 127 self.assertEqual(0, len(self.evaluate(v2.keys()))) 128 self.assertEqual(0, len(self.evaluate(v2.values()))) 129 # Restore the saved values in the parameter nodes. 130 save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable}) 131 save.restore(sess, save_path) 132 # Check that the parameter nodes have been restored. 133 self.assertEqual(10.0, self.evaluate(v0)) 134 self.assertEqual(20.0, self.evaluate(v1)) 135 self.assertEqual(b"k1", self.evaluate(v2.keys())) 136 self.assertEqual(30.0, self.evaluate(v2.values())) 137 138 # Build another graph with 2 nodes, initialized 139 # differently, and a Restore node for them. 140 with self.session(graph=ops_lib.Graph()) as sess: 141 v0_2 = variable_op(1000.0, name="v0") 142 v1_2 = variable_op(2000.0, name="v1") 143 v2_2 = saver_test_utils.CheckpointedOp(name="v2") 144 v2_init = v2_2.insert("k1000", 3000.0) 145 146 # Check that the parameter nodes have been initialized. 147 if not context.executing_eagerly(): 148 init_all_op = [variables.global_variables_initializer(), v2_init] 149 self.evaluate(init_all_op) 150 # TODO(xpan): Why _mutable_hash_table_v2 doesn't create empty 151 # table as it claims in eager mode? 152 self.assertEqual(b"k1000", self.evaluate(v2_2.keys())) 153 self.assertEqual(3000.0, self.evaluate(v2_2.values())) 154 self.assertEqual(1000.0, self.evaluate(v0_2)) 155 self.assertEqual(2000.0, self.evaluate(v1_2)) 156 157 # Restore the values saved earlier in the parameter nodes. 158 save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable}) 159 save2.restore(sess, save_path) 160 # Check that the parameter nodes have been restored. 161 self.assertEqual(10.0, self.evaluate(v0_2)) 162 self.assertEqual(20.0, self.evaluate(v1_2)) 163 self.assertEqual(b"k1", self.evaluate(v2_2.keys())) 164 self.assertEqual(30.0, self.evaluate(v2_2.values())) 165 166 def testBasic(self): 167 self.basicSaveRestore(variables.Variable) 168 169 @test_util.run_in_graph_and_eager_modes 170 def testResourceBasic(self): 171 self.basicSaveRestore(resource_variable_ops.ResourceVariable) 172 173 @test_util.run_deprecated_v1 174 def testResourceColocation(self): 175 partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2) 176 with ops_lib.device("/job:ps/device:GPU:0"): 177 v = variable_scope.get_variable("v0", 178 shape=[10, 2], 179 partitioner=partitioner, 180 use_resource=True) 181 saver_module.Saver({"v0": v}).build() 182 save_op = None 183 for op in ops_lib.get_default_graph().get_operations(): 184 if op.type == "SaveV2": 185 save_op = op 186 break 187 assert save_op is not None 188 for save_inp in save_op.inputs[3:]: 189 # Input to SaveV2 op is placed on CPU of the same device as the Variable. 190 self.assertEqual("/job:ps/device:CPU:0", save_inp.device) 191 192 def testResourceVariableReadOpsAddedDeterministically(self): 193 graph_defs = [] 194 num_graphs = 10 195 for _ in range(num_graphs): 196 with ops_lib.Graph().as_default() as g: 197 for i in range(20): 198 resource_variable_ops.ResourceVariable(i, name="var%s" % i) 199 saver_module.Saver() 200 graph_defs.append(g.as_graph_def()) 201 for i in range(num_graphs - 1): 202 self.assertEqual(graph_defs[i], graph_defs[i + 1]) 203 204 def testEagerBasic(self): 205 with context.eager_mode(): 206 ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt") 207 208 v1 = resource_variable_ops.ResourceVariable(3.14, name="v1") 209 v2 = resource_variable_ops.ResourceVariable([1, 2], name="v2") 210 save = saver_module.Saver([v1, v2]) 211 save.save(None, ckpt_prefix) 212 213 v1.assign(0.0) 214 v2.assign([0, 0]) 215 self.assertNear(0.0, self.evaluate(v1), 1e-5) 216 self.assertAllEqual([0, 0], self.evaluate(v2)) 217 218 save.restore(None, ckpt_prefix) 219 self.assertNear(3.14, self.evaluate(v1), 1e-5) 220 self.assertAllEqual([1, 2], self.evaluate(v2)) 221 222 def testEagerGraphCompatibility(self): 223 # Save from graph mode and restore from eager mode. 224 graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt") 225 with context.graph_mode(): 226 with self.session(graph=ops_lib.Graph()) as sess: 227 # Create a graph model and save the checkpoint. 228 w1 = resource_variable_ops.ResourceVariable(1.0, name="w1") 229 w2 = resource_variable_ops.ResourceVariable(2.0, name="w2") 230 graph_saver = saver_module.Saver([w1, w2]) 231 self.evaluate(variables.global_variables_initializer()) 232 graph_saver.save(sess, graph_ckpt_prefix) 233 234 with context.eager_mode(): 235 ops_lib._default_graph_stack.reset() # pylint: disable=protected-access 236 ops_lib.reset_default_graph() 237 238 w1 = resource_variable_ops.ResourceVariable(0.0, name="w1") 239 w2 = resource_variable_ops.ResourceVariable(0.0, name="w2") 240 241 graph_saver = saver_module.Saver([w1, w2]) 242 graph_saver.restore(None, graph_ckpt_prefix) 243 244 self.assertAllEqual(self.evaluate(w1), 1.0) 245 self.assertAllEqual(self.evaluate(w2), 2.0) 246 247 # Save from eager mode and restore from graph mode. 248 eager_ckpt_prefix = os.path.join(self.get_temp_dir(), "eager_ckpt") 249 with context.eager_mode(): 250 ops_lib._default_graph_stack.reset() # pylint: disable=protected-access 251 ops_lib.reset_default_graph() 252 253 w3 = resource_variable_ops.ResourceVariable(3.0, name="w3") 254 w4 = resource_variable_ops.ResourceVariable(4.0, name="w4") 255 256 graph_saver = saver_module.Saver([w3, w4]) 257 graph_saver.save(None, eager_ckpt_prefix) 258 259 with context.graph_mode(): 260 with self.session(graph=ops_lib.Graph()) as sess: 261 w3 = resource_variable_ops.ResourceVariable(0.0, name="w3") 262 w4 = resource_variable_ops.ResourceVariable(0.0, name="w4") 263 graph_saver = saver_module.Saver([w3, w4]) 264 self.evaluate(variables.global_variables_initializer()) 265 graph_saver.restore(sess, eager_ckpt_prefix) 266 self.assertAllEqual(w3.eval(), 3.0) 267 self.assertAllEqual(w4.eval(), 4.0) 268 269 @test_util.run_in_graph_and_eager_modes 270 def testResourceSaveRestoreCachingDevice(self): 271 save_path = os.path.join(self.get_temp_dir(), "resource_cache") 272 with self.session(graph=ops_lib.Graph()) as sess: 273 v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0", 274 name="v") 275 if context.executing_eagerly(): 276 sess = None 277 else: 278 self.evaluate(variables.global_variables_initializer()) 279 save = saver_module.Saver([v]) 280 save.save(sess, save_path) 281 282 save2 = saver_module.Saver([v]) 283 save2.restore(sess, save_path) 284 self.assertEquals(self.evaluate(v), [1]) 285 286 def testNoAdditionalOpsAddedBySaverForResourceVariablesOutsideSaveScope(self): 287 with ops_lib.Graph().as_default() as g: 288 v = resource_variable_ops.ResourceVariable(1.0, name="v") 289 with ops_lib.name_scope("saver1"): 290 saver_module.Saver() 291 with ops_lib.name_scope("saver2"): 292 saver_module.Saver({"name": v}) 293 ops_in_saver1_scope_but_not_save_scope = [ 294 op for op in g.get_operations() 295 if (op.name.startswith("saver1/") and 296 not op.name.startswith("saver1/save/"))] 297 self.assertEqual(ops_in_saver1_scope_but_not_save_scope, []) 298 ops_in_saver2_scope_but_not_save_scope = [ 299 op for op in g.get_operations() 300 if (op.name.startswith("saver2/") and 301 not op.name.startswith("saver2/save/"))] 302 self.assertEqual(ops_in_saver2_scope_but_not_save_scope, []) 303 304 @test_util.run_deprecated_v1 305 def testSaveCopyRestoreWithSaveRelativePaths(self): 306 """Save, copy checkpoint dir and restore from copied dir. 307 308 This only works for save_relative_paths=True. 309 """ 310 save_dir1 = os.path.join(self.get_temp_dir(), "save_dir1") 311 os.mkdir(save_dir1) 312 save_path1 = os.path.join(save_dir1, "save_copy_restore") 313 314 # Build a graph with 2 parameter nodes, and Save and 315 # Restore nodes for them. 316 v0 = variables.VariableV1(10.0, name="v0") 317 v1 = variables.VariableV1(20.0, name="v1") 318 v2 = saver_test_utils.CheckpointedOp(name="v2") 319 v2_init = v2.insert("k1", 30.0) 320 save = saver_module.Saver( 321 var_list={ 322 "v0": v0, 323 "v1": v1, 324 "v2": v2.saveable}, 325 restore_sequentially=True, 326 save_relative_paths=True) 327 init_all_op = [variables.global_variables_initializer(), v2_init] 328 329 with self.cached_session() as sess: 330 # Initialize all variables 331 self.evaluate(init_all_op) 332 333 # Check that the parameter nodes have been initialized. 334 self.assertEqual(10.0, self.evaluate(v0)) 335 self.assertEqual(20.0, self.evaluate(v1)) 336 self.assertEqual(b"k1", self.evaluate(v2.keys())) 337 self.assertEqual(30.0, self.evaluate(v2.values())) 338 339 # Save the initialized values in the file at "save_path" 340 val = save.save(sess, save_path1) 341 self.assertTrue(isinstance(val, six.string_types)) 342 self.assertEqual(save_path1, val) 343 344 self.assertEqual( 345 checkpoint_management.latest_checkpoint(save_dir1), save_path1) 346 save_dir2 = os.path.join(self.get_temp_dir(), "save_dir2") 347 os.renames(save_dir1, save_dir2) 348 save_path2 = os.path.join(save_dir2, "save_copy_restore") 349 self.assertEqual( 350 checkpoint_management.latest_checkpoint(save_dir2), save_path2) 351 352 # Start a second session. In that session the parameter nodes 353 # have not been initialized either. 354 with self.cached_session() as sess: 355 v0 = variables.VariableV1(-1.0, name="v0") 356 v1 = variables.VariableV1(-1.0, name="v1") 357 v2 = saver_test_utils.CheckpointedOp(name="v2") 358 save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable}) 359 360 # Assert that the variables are not initialized. 361 self.assertEqual( 362 len(variables.report_uninitialized_variables().eval()), 2) 363 self.assertEqual(0, len(self.evaluate(v2.keys()))) 364 self.assertEqual(0, len(self.evaluate(v2.values()))) 365 366 # Restore the saved values in the parameter nodes. 367 save.restore(sess, save_path2) 368 # Check that the parameter nodes have been restored. 369 self.assertEqual(10.0, self.evaluate(v0)) 370 self.assertEqual(20.0, self.evaluate(v1)) 371 self.assertEqual(b"k1", self.evaluate(v2.keys())) 372 self.assertEqual(30.0, self.evaluate(v2.values())) 373 374 @test_util.run_deprecated_v1 375 def testFilenameTensor(self): 376 v0 = variables.VariableV1(0, name="v0") 377 filename = b"somerandomfilename" 378 save = saver_module.Saver({"v0": v0}, filename=filename) 379 with self.cached_session() as sess: 380 tensor = sess.graph.get_tensor_by_name( 381 save.saver_def.filename_tensor_name) 382 self.assertEqual(self.evaluate(tensor), filename) 383 384 def testInvalidPath(self): 385 v0 = variables.VariableV1(0, name="v0") 386 for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2): 387 with self.cached_session() as sess: 388 save = saver_module.Saver({"v0": v0}, write_version=ver) 389 with self.assertRaisesRegexp( 390 ValueError, "The passed save_path is not a valid checkpoint:"): 391 save.restore(sess, "invalid path") 392 393 @test_util.run_v1_only("b/120545219") 394 def testInt64(self): 395 save_path = os.path.join(self.get_temp_dir(), "int64") 396 397 with self.cached_session() as sess: 398 # Build a graph with 1 node, and save and restore for them. 399 v = variables.VariableV1(np.int64(15), name="v") 400 save = saver_module.Saver({"v": v}, restore_sequentially=True) 401 self.evaluate(variables.global_variables_initializer()) 402 403 # Save the initialized values in the file at "save_path" 404 val = save.save(sess, save_path) 405 self.assertTrue(isinstance(val, six.string_types)) 406 self.assertEqual(save_path, val) 407 408 with self.cached_session() as sess: 409 v = variables.VariableV1(np.int64(-1), name="v") 410 save = saver_module.Saver({"v": v}) 411 412 with self.assertRaisesWithPredicateMatch( 413 errors_impl.OpError, lambda e: "uninitialized value v" in e.message): 414 self.evaluate(v) 415 416 # Restore the saved values in the parameter nodes. 417 save.restore(sess, save_path) 418 # Check that the parameter nodes have been restored. 419 self.assertEqual(np.int64(15), self.evaluate(v)) 420 421 def testSomeErrors(self): 422 with ops_lib.Graph().as_default(): 423 v0 = variables.VariableV1([10.0], name="v0") 424 v1 = variables.VariableV1([20.0], name="v1") 425 v2 = variables.VariableV1([20.0], name="v2") 426 v2._set_save_slice_info( 427 variables.Variable.SaveSliceInfo("v1", [1], [0], [1])) 428 429 # By default the name used for "v2" will be "v1" and raise an error. 430 with self.assertRaisesRegexp(ValueError, "same name: v1"): 431 saver_module.Saver([v0, v1, v2]) 432 433 # The names are different and will work. 434 saver_module.Saver({"vee1": v1, "other": [v2]}) 435 436 # Partitioned variables also cause name conflicts. 437 p_v1 = variable_scope.get_variable( 438 "p_v1", 439 shape=[4, 5], 440 partitioner=partitioned_variables.fixed_size_partitioner( 441 num_shards=2)) 442 p_v2 = variable_scope.get_variable( 443 "p_v2", 444 shape=[4, 5], 445 partitioner=partitioned_variables.fixed_size_partitioner( 446 num_shards=2)) 447 p_v2._name = "p_v1" 448 with self.assertRaisesRegexp(ValueError, "same name: p_v1"): 449 saver_module.Saver([p_v1, p_v2]) 450 451 def testSameName(self): 452 with ops_lib.Graph().as_default(): 453 v0 = variables.VariableV1([10.0], name="v0") 454 v2 = saver_test_utils.CheckpointedOp(name="v2") 455 456 # Saving one variable under two names raises an error. 457 with self.assertRaisesRegexp( 458 ValueError, "The same saveable will be restored with two names: v0"): 459 saver_module.Saver({"v0": v0, "v0too": v0}) 460 461 # Ditto for custom saveables. 462 with self.assertRaisesRegexp( 463 ValueError, "The same saveable will be restored with two names: v2"): 464 saver_module.Saver({"v2": v2.saveable, "v2too": v2.saveable}) 465 466 # Verify non-duplicate names work. 467 saver_module.Saver({"v0": v0, "v2": v2.saveable}) 468 469 @test_util.run_v1_only("b/120545219") 470 def testBasicsWithListOfVariables(self): 471 save_path = os.path.join(self.get_temp_dir(), "basics_with_list") 472 473 with self.session(graph=ops_lib.Graph()) as sess: 474 # Build a graph with 2 parameter nodes, and Save and 475 # Restore nodes for them. 476 v0 = variables.VariableV1(10.0, name="v0") 477 v1 = variables.VariableV1(20.0, name="v1") 478 v2 = saver_test_utils.CheckpointedOp(name="v2") 479 v2_init = v2.insert("k1", 30.0) 480 save = saver_module.Saver([v0, v1, v2.saveable]) 481 self.evaluate(variables.global_variables_initializer()) 482 v2_init.run() 483 484 # Check that the parameter nodes have been initialized. 485 self.assertEqual(10.0, self.evaluate(v0)) 486 self.assertEqual(20.0, self.evaluate(v1)) 487 self.assertEqual(b"k1", self.evaluate(v2.keys())) 488 self.assertEqual(30.0, self.evaluate(v2.values())) 489 490 # Save the initialized values in the file at "save_path" 491 val = save.save(sess, save_path) 492 self.assertTrue(isinstance(val, six.string_types)) 493 self.assertEqual(save_path, val) 494 495 # Start a second session. In that session the variables 496 # have not been initialized either. 497 with self.session(graph=ops_lib.Graph()) as sess: 498 v0 = variables.VariableV1(-1.0, name="v0") 499 v1 = variables.VariableV1(-1.0, name="v1") 500 v2 = saver_test_utils.CheckpointedOp(name="v2") 501 save = saver_module.Saver([v0, v1, v2.saveable]) 502 503 with self.assertRaisesWithPredicateMatch( 504 errors_impl.OpError, lambda e: "uninitialized value v0" in e.message): 505 self.evaluate(v0) 506 with self.assertRaisesWithPredicateMatch( 507 errors_impl.OpError, lambda e: "uninitialized value v1" in e.message): 508 self.evaluate(v1) 509 self.assertEqual(0, len(self.evaluate(v2.keys()))) 510 self.assertEqual(0, len(self.evaluate(v2.values()))) 511 512 # Restore the saved values in the parameter nodes. 513 save.restore(sess, save_path) 514 # Check that the parameter nodes have been restored. 515 self.assertEqual(10.0, self.evaluate(v0)) 516 self.assertEqual(20.0, self.evaluate(v1)) 517 self.assertEqual(b"k1", self.evaluate(v2.keys())) 518 self.assertEqual(30.0, self.evaluate(v2.values())) 519 520 # Build another graph with 2 nodes, initialized 521 # differently, and a Restore node for them. 522 with self.session(graph=ops_lib.Graph()) as sess: 523 v0_2 = variables.VariableV1(1000.0, name="v0") 524 v1_2 = variables.VariableV1(2000.0, name="v1") 525 v2_2 = saver_test_utils.CheckpointedOp(name="v2") 526 save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable]) 527 v2_2.insert("k1000", 3000.0).run() 528 self.evaluate(variables.global_variables_initializer()) 529 530 # Check that the parameter nodes have been initialized. 531 self.assertEqual(1000.0, self.evaluate(v0_2)) 532 self.assertEqual(2000.0, self.evaluate(v1_2)) 533 self.assertEqual(b"k1000", self.evaluate(v2_2.keys())) 534 self.assertEqual(3000.0, self.evaluate(v2_2.values())) 535 # Restore the values saved earlier in the parameter nodes. 536 save2.restore(sess, save_path) 537 # Check that the parameter nodes have been restored. 538 self.assertEqual(10.0, self.evaluate(v0_2)) 539 self.assertEqual(20.0, self.evaluate(v1_2)) 540 self.assertEqual(b"k1", self.evaluate(v2_2.keys())) 541 self.assertEqual(30.0, self.evaluate(v2_2.values())) 542 543 def _SaveAndLoad(self, var_name, var_value, other_value, save_path): 544 with self.session(graph=ops_lib.Graph()) as sess: 545 var = resource_variable_ops.ResourceVariable(var_value, name=var_name) 546 save = saver_module.Saver({var_name: var}) 547 if not context.executing_eagerly(): 548 self.evaluate(var.initializer) 549 val = save.save(sess, save_path) 550 self.assertEqual(save_path, val) 551 with self.session(graph=ops_lib.Graph()) as sess: 552 var = resource_variable_ops.ResourceVariable(other_value, name=var_name) 553 save = saver_module.Saver({var_name: var}) 554 save.restore(sess, save_path) 555 self.assertAllClose(var_value, self.evaluate(var)) 556 557 def testCacheRereadsFile(self): 558 save_path = os.path.join(self.get_temp_dir(), "cache_rereads") 559 # Save and reload one Variable named "var0". 560 self._SaveAndLoad("var0", 0.0, 1.0, save_path) 561 # Save and reload one Variable named "var1" in the same file. 562 # The cached readers should know to re-read the file. 563 self._SaveAndLoad("var1", 1.1, 2.2, save_path) 564 565 @test_util.run_deprecated_v1 566 def testAllowEmpty(self): 567 save_path = os.path.join(self.get_temp_dir(), "allow_empty") 568 with self.cached_session() as sess: 569 _ = constant_op.constant(1) 570 save = saver_module.Saver(allow_empty=True) 571 val = save.save(sess, save_path) 572 self.assertIsNone(val) 573 with self.cached_session() as sess: 574 save = saver_module.Saver(allow_empty=True) 575 save.restore(sess, save_path) 576 577 def testGPU(self): 578 if not test.is_gpu_available(): 579 return 580 save_path = os.path.join(self.get_temp_dir(), "gpu") 581 with session.Session("", graph=ops_lib.Graph()) as sess: 582 with sess.graph.device(test.gpu_device_name()): 583 v0_1 = variables.VariableV1(123.45) 584 save = saver_module.Saver({"v0": v0_1}) 585 self.evaluate(variables.global_variables_initializer()) 586 save.save(sess, save_path) 587 588 with session.Session("", graph=ops_lib.Graph()) as sess: 589 with sess.graph.device(test.gpu_device_name()): 590 v0_2 = variables.VariableV1(543.21) 591 save = saver_module.Saver({"v0": v0_2}) 592 self.evaluate(variables.global_variables_initializer()) 593 594 def testSharedServerOnGPU(self): 595 if not test.is_gpu_available(): 596 return 597 save_path = os.path.join(self.get_temp_dir(), "gpu") 598 with session.Session("", graph=ops_lib.Graph()) as sess: 599 with sess.graph.device(test.gpu_device_name()): 600 v0_1 = variables.VariableV1(123.45) 601 save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True) 602 self.evaluate(variables.global_variables_initializer()) 603 save.save(sess, save_path) 604 605 with session.Session("", graph=ops_lib.Graph()) as sess: 606 with sess.graph.device(test.gpu_device_name()): 607 v0_2 = variables.VariableV1(543.21) 608 save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True) 609 self.evaluate(variables.global_variables_initializer()) 610 611 def testVariables(self): 612 save_path = os.path.join(self.get_temp_dir(), "variables") 613 with session.Session("", graph=ops_lib.Graph()) as sess: 614 one = variables.VariableV1(1.0) 615 twos = variables.VariableV1([2.0, 2.0, 2.0]) 616 v2 = saver_test_utils.CheckpointedOp(name="v2") 617 init = variables.global_variables_initializer() 618 save = saver_module.Saver() 619 init.run() 620 v2.insert("k1", 3.0).run() 621 save.save(sess, save_path) 622 623 with session.Session("", graph=ops_lib.Graph()) as sess: 624 one = variables.VariableV1(0.0) 625 twos = variables.VariableV1([0.0, 0.0, 0.0]) 626 v2 = saver_test_utils.CheckpointedOp(name="v2") 627 # Saver with no arg, defaults to 'all variables'. 628 save = saver_module.Saver() 629 save.restore(sess, save_path) 630 self.assertAllClose(1.0, self.evaluate(one)) 631 self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos)) 632 self.assertEqual(b"k1", self.evaluate(v2.keys())) 633 self.assertEqual(3.0, self.evaluate(v2.values())) 634 635 def testVarListShouldBeEmptyInDeferredBuild(self): 636 with ops_lib.Graph().as_default(): 637 v = variables.VariableV1(1.0) 638 with self.assertRaisesRegexp(ValueError, "defer_build"): 639 saver_module.Saver([v], defer_build=True) 640 641 def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self): 642 save_path = os.path.join(self.get_temp_dir(), "error_deferred_build") 643 with ops_lib.Graph().as_default(), session.Session() as sess: 644 variables.VariableV1(1.0) 645 saver = saver_module.Saver(defer_build=True) 646 with self.assertRaisesRegexp(RuntimeError, "build"): 647 saver.save(sess, save_path) 648 649 def testDeferredBuild(self): 650 save_path = os.path.join(self.get_temp_dir(), "deferred_build") 651 with session.Session("", graph=ops_lib.Graph()) as sess: 652 one = variables.VariableV1(1.0) 653 save = saver_module.Saver(defer_build=True) 654 # if build is not deferred, saver cannot save the `twos`. 655 twos = variables.VariableV1([2.0, 2.0, 2.0]) 656 init = variables.global_variables_initializer() 657 save.build() 658 init.run() 659 save.save(sess, save_path) 660 661 with session.Session("", graph=ops_lib.Graph()) as sess: 662 one = variables.VariableV1(0.0) 663 twos = variables.VariableV1([0.0, 0.0, 0.0]) 664 # Saver with no arg, defaults to 'all variables'. 665 save = saver_module.Saver() 666 save.restore(sess, save_path) 667 self.assertAllClose(1.0, self.evaluate(one)) 668 self.assertAllClose([2.0, 2.0, 2.0], self.evaluate(twos)) 669 670 @test_util.run_v1_only("b/120545219") 671 def testReshape(self): 672 save_path = os.path.join(self.get_temp_dir(), "variables_reshape") 673 with session.Session("", graph=ops_lib.Graph()) as sess: 674 var = variables.VariableV1([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 675 init = variables.global_variables_initializer() 676 save = saver_module.Saver() 677 init.run() 678 save.save(sess, save_path) 679 680 # Error when restoring with default reshape=False 681 with session.Session("", graph=ops_lib.Graph()) as sess: 682 var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) 683 save = saver_module.Saver() 684 with self.assertRaisesRegexp( 685 errors_impl.InvalidArgumentError, 686 "Assign requires shapes of both tensors to match."): 687 save.restore(sess, save_path) 688 689 # Restored to new shape with reshape=True 690 with session.Session("", graph=ops_lib.Graph()) as sess: 691 var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) 692 save = saver_module.Saver(reshape=True) 693 save.restore(sess, save_path) 694 self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 695 self.evaluate(var)) 696 697 @test_util.run_in_graph_and_eager_modes 698 def testSaveWithGlobalStep(self, pad_step_number=False): 699 save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step") 700 global_step_int = 5 701 # Save and reload one Variable named "var0". 702 self._SaveAndLoad("var0", 0.0, 1.0, save_path) 703 for use_tensor in [True, False]: 704 with self.session(graph=ops_lib.Graph()): 705 var = resource_variable_ops.ResourceVariable(1.0, name="var0") 706 save = saver_module.Saver( 707 { 708 var._shared_name: var 709 }, pad_step_number=pad_step_number) 710 if context.executing_eagerly(): 711 sess = None 712 else: 713 self.evaluate(var.initializer) 714 sess = ops_lib.get_default_session() 715 if use_tensor: 716 global_step = constant_op.constant(global_step_int) 717 val = save.save(sess, save_path, global_step=global_step) 718 else: 719 val = save.save(sess, save_path, global_step=global_step_int) 720 if pad_step_number: 721 expected_save_path = "%s-%s" % (save_path, 722 "{:08d}".format(global_step_int)) 723 else: 724 expected_save_path = "%s-%d" % (save_path, global_step_int) 725 self.assertEqual(expected_save_path, val) 726 727 def testSaveWithGlobalStepWithPadding(self): 728 self.testSaveWithGlobalStep(pad_step_number=True) 729 730 def testSaveToNonexistingPath(self): 731 file_io.write_string_to_file( 732 os.path.join(self.get_temp_dir(), "actually_a_file"), "") 733 paths = [ 734 os.path.join(self.get_temp_dir(), "nonexisting_dir/path"), 735 os.path.join(self.get_temp_dir(), "other_nonexisting_dir/path1/path2"), 736 os.path.join(self.get_temp_dir(), "actually_a_file/path"), 737 ] 738 739 for save_path in paths: 740 # Build a graph with 2 parameter nodes, and Save and 741 # Restore nodes for them. 742 v0 = variables.VariableV1(10.0, name="v0") 743 v1 = variables.VariableV1(20.0, name="v1") 744 save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True) 745 init_all_op = variables.global_variables_initializer() 746 747 # In the case where the parent directory doesn't exist, whether or not the 748 # save succeeds or fails is implementation dependent. Therefore we allow 749 # both cases. 750 try: 751 with self.cached_session() as sess: 752 # Initialize all variables 753 self.evaluate(init_all_op) 754 755 # Check that the parameter nodes have been initialized. 756 self.assertEqual(10.0, self.evaluate(v0)) 757 self.assertEqual(20.0, self.evaluate(v1)) 758 759 # Save the graph. 760 save.save(sess, save_path) 761 762 with self.cached_session() as sess: 763 # Restore the saved values in the parameter nodes. 764 save.restore(sess, save_path) 765 # Check that the parameter nodes have been restored. 766 self.assertEqual(10.0, self.evaluate(v0)) 767 self.assertEqual(20.0, self.evaluate(v1)) 768 except ValueError as exc: 769 error_msg_template = "Parent directory of {} doesn't exist, can't save." 770 self.assertEqual(error_msg_template.format(save_path), str(exc)) 771 772 def testSaveToURI(self): 773 # ParseURI functions don't work on Windows yet. 774 # TODO(jhseu): Remove this check when it works. 775 if os.name == "nt": 776 self.skipTest("Local URI support doesn't work on Windows") 777 save_path = "file://" + os.path.join(self.get_temp_dir(), "uri") 778 779 # Build a graph with 2 parameter nodes, and Save and 780 # Restore nodes for them. 781 v0 = variables.VariableV1(10.0, name="v0") 782 v1 = variables.VariableV1(20.0, name="v1") 783 save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True) 784 init_all_op = variables.global_variables_initializer() 785 786 with self.cached_session() as sess: 787 # Initialize all variables 788 self.evaluate(init_all_op) 789 790 # Check that the parameter nodes have been initialized. 791 self.assertEqual(10.0, self.evaluate(v0)) 792 self.assertEqual(20.0, self.evaluate(v1)) 793 save.save(sess, save_path) 794 795 def testSaveRestoreAndValidateVariableDtype(self): 796 for variable_op in [ 797 variables.Variable, resource_variable_ops.ResourceVariable 798 ]: 799 save_path = os.path.join(self.get_temp_dir(), "basic_save_restore") 800 801 # Build the first session. 802 with self.session(graph=ops_lib.Graph()) as sess: 803 v0 = variable_op(10.0, name="v0", dtype=dtypes.float32) 804 805 if not context.executing_eagerly(): 806 self.evaluate([variables.global_variables_initializer()]) 807 808 save = saver_module.Saver({"v0": v0}) 809 save.save(sess, save_path) 810 811 # Start a second session. 812 with self.session(graph=ops_lib.Graph()) as sess: 813 v0_wrong_dtype = variable_op(1, name="v0", dtype=dtypes.int32) 814 # Restore the saved value with different dtype 815 # in the parameter nodes. 816 save = saver_module.Saver({"v0": v0_wrong_dtype}) 817 with self.assertRaisesRegexp(errors.InvalidArgumentError, 818 "original dtype"): 819 save.restore(sess, save_path) 820 821 # Test restoring large tensors (triggers a thread pool) 822 def testRestoreLargeTensors(self): 823 save_dir = self.get_temp_dir() 824 def _model(): 825 small_v = [variable_scope.get_variable( 826 "small%d" % i, shape=[10, 2], use_resource=True) for i in range(5)] 827 large_v = [variable_scope.get_variable( 828 "large%d" % i, shape=[32000, 1000], use_resource=True) 829 for i in range(3)] 830 return small_v + large_v 831 832 save_graph = ops_lib.Graph() 833 with save_graph.as_default(), self.session(graph=save_graph) as sess: 834 orig_vars = _model() 835 self.evaluate(variables.global_variables_initializer()) 836 save = saver_module.Saver(max_to_keep=1) 837 self.evaluate(variables.global_variables_initializer()) 838 save.save(sess, save_dir) 839 orig_vals = self.evaluate(orig_vars) 840 841 restore_graph = ops_lib.Graph() 842 with restore_graph.as_default(), self.session( 843 graph=restore_graph) as sess: 844 restored_vars = _model() 845 save = saver_module.Saver(max_to_keep=1) 846 save.restore(sess, save_dir) 847 restored_vals = self.evaluate(restored_vars) 848 849 for orig, restored in zip(orig_vals, restored_vals): 850 self.assertAllEqual(orig, restored) 851 852 853class SaveRestoreShardedTest(test.TestCase): 854 855 _WRITE_VERSION = saver_pb2.SaverDef.V1 856 857 def _get_test_dir(self, dirname): 858 test_dir = os.path.join(self.get_temp_dir(), dirname) 859 gfile.MakeDirs(test_dir) 860 return test_dir 861 862 def testBasics(self): 863 save_path = os.path.join(self.get_temp_dir(), "sharded_basics") 864 865 # Build a graph with 2 parameter nodes on different devices. 866 with session.Session( 867 target="", 868 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 869 with sess.graph.device("/cpu:0"): 870 v0 = variables.VariableV1(10, name="v0") 871 t0 = saver_test_utils.CheckpointedOp(name="t0") 872 with sess.graph.device("/cpu:1"): 873 v1 = variables.VariableV1(20, name="v1") 874 t1 = saver_test_utils.CheckpointedOp(name="t1") 875 save = saver_module.Saver( 876 { 877 "v0": v0, 878 "v1": v1, 879 "t0": t0.saveable, 880 "t1": t1.saveable 881 }, 882 write_version=self._WRITE_VERSION, 883 sharded=True) 884 self.evaluate(variables.global_variables_initializer()) 885 t0.insert("k1", 30.0).run() 886 t1.insert("k2", 40.0).run() 887 val = save.save(sess, save_path) 888 if save._write_version is saver_pb2.SaverDef.V1: 889 self.assertEqual(save_path + "-?????-of-00002", val) 890 else: 891 self.assertEqual(save_path, val) 892 meta_graph_filename = checkpoint_management.meta_graph_filename(val) 893 self.assertEqual(save_path + ".meta", meta_graph_filename) 894 895 if save._write_version is saver_pb2.SaverDef.V1: 896 # Restore different ops from shard 0 of the saved files. 897 with session.Session( 898 target="", 899 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 900 with sess.graph.device("/cpu:0"): 901 v0 = variables.VariableV1(111, name="v0") 902 t0 = saver_test_utils.CheckpointedOp(name="t0") 903 save = saver_module.Saver( 904 { 905 "v0": v0, 906 "t0": t0.saveable 907 }, 908 write_version=self._WRITE_VERSION, 909 sharded=True) 910 self.evaluate(variables.global_variables_initializer()) 911 t0.insert("k11", 33.0).run() 912 self.assertEqual(111, self.evaluate(v0)) 913 self.assertEqual(b"k11", self.evaluate(t0.keys())) 914 self.assertEqual(33.0, self.evaluate(t0.values())) 915 save.restore(sess, save_path + "-00000-of-00002") 916 self.assertEqual(10, self.evaluate(v0)) 917 self.assertEqual(b"k1", self.evaluate(t0.keys())) 918 self.assertEqual(30.0, self.evaluate(t0.values())) 919 920 # Restore different ops from shard 1 of the saved files. 921 with session.Session( 922 target="", 923 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 924 with sess.graph.device("/cpu:0"): 925 v1 = variables.VariableV1(222) 926 t1 = saver_test_utils.CheckpointedOp(name="t1") 927 save = saver_module.Saver( 928 { 929 "v1": v1, 930 "t1": t1.saveable 931 }, 932 write_version=self._WRITE_VERSION, 933 sharded=True) 934 self.evaluate(variables.global_variables_initializer()) 935 t1.insert("k22", 44.0).run() 936 self.assertEqual(222, self.evaluate(v1)) 937 self.assertEqual(b"k22", self.evaluate(t1.keys())) 938 self.assertEqual(44.0, self.evaluate(t1.values())) 939 save.restore(sess, save_path + "-00001-of-00002") 940 self.assertEqual(20, self.evaluate(v1)) 941 self.assertEqual(b"k2", self.evaluate(t1.keys())) 942 self.assertEqual(40.0, self.evaluate(t1.values())) 943 944 # Now try a restore with the sharded filename. 945 with session.Session( 946 target="", 947 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 948 with sess.graph.device("/cpu:0"): 949 v0 = variables.VariableV1(111, name="v0") 950 t0 = saver_test_utils.CheckpointedOp(name="t0") 951 with sess.graph.device("/cpu:1"): 952 v1 = variables.VariableV1(222, name="v1") 953 t1 = saver_test_utils.CheckpointedOp(name="t1") 954 save = saver_module.Saver( 955 { 956 "v0": v0, 957 "v1": v1, 958 "t0": t0.saveable, 959 "t1": t1.saveable 960 }, 961 write_version=self._WRITE_VERSION, 962 sharded=True) 963 self.evaluate(variables.global_variables_initializer()) 964 t0.insert("k11", 33.0).run() 965 t1.insert("k22", 44.0).run() 966 self.assertEqual(111, self.evaluate(v0)) 967 self.assertEqual(222, self.evaluate(v1)) 968 self.assertEqual(b"k11", self.evaluate(t0.keys())) 969 self.assertEqual(33.0, self.evaluate(t0.values())) 970 self.assertEqual(b"k22", self.evaluate(t1.keys())) 971 self.assertEqual(44.0, self.evaluate(t1.values())) 972 save_path = os.path.join(self.get_temp_dir(), "sharded_basics") 973 if save._write_version is saver_pb2.SaverDef.V1: 974 save.restore(sess, save_path + "-?????-of-?????") 975 else: 976 save.restore(sess, save_path) 977 self.assertEqual(10, self.evaluate(v0)) 978 self.assertEqual(20, self.evaluate(v1)) 979 self.assertEqual(b"k1", self.evaluate(t0.keys())) 980 self.assertEqual(30.0, self.evaluate(t0.values())) 981 self.assertEqual(b"k2", self.evaluate(t1.keys())) 982 self.assertEqual(40.0, self.evaluate(t1.values())) 983 984 if save._write_version is saver_pb2.SaverDef.V1: 985 self.assertEqual( 986 checkpoint_management.latest_checkpoint(self.get_temp_dir()), 987 os.path.join(self.get_temp_dir(), "sharded_basics-?????-of-00002")) 988 else: 989 self.assertEqual( 990 checkpoint_management.latest_checkpoint(self.get_temp_dir()), 991 os.path.join(self.get_temp_dir(), "sharded_basics")) 992 993 @test_util.run_deprecated_v1 994 def testSaverDef(self): 995 with self.cached_session(): 996 v0 = variables.VariableV1(123, name="v0") 997 save = saver_module.Saver({"v0": v0}, sharded=True) 998 sd = save.as_saver_def() 999 self.assertTrue(sd.sharded) 1000 1001 def _testPartitionedVariables(self, use_resource): 1002 var_full_shape = [10, 3] 1003 # Allows save/restore mechanism to work w/ different slicings. 1004 var_name = "my_var" 1005 saved_dir = self._get_test_dir("partitioned_variables") 1006 saved_path = os.path.join(saved_dir, "ckpt") 1007 1008 call_saver_with_dict = False # updated by test loop below 1009 1010 def _save(partitioner=None): 1011 with self.session(graph=ops_lib.Graph()) as sess: 1012 # Calls .eval() to return the ndarray that makes up the full variable. 1013 rnd = random_ops.random_uniform(var_full_shape).eval() 1014 1015 if partitioner: 1016 vs = [ 1017 variable_scope.get_variable( 1018 var_name, 1019 shape=var_full_shape, 1020 initializer=rnd, 1021 partitioner=partitioner, 1022 use_resource=use_resource) 1023 ] 1024 else: 1025 if use_resource: 1026 vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)] 1027 else: 1028 vs = [variables.VariableV1(rnd, name=var_name)] 1029 1030 self.evaluate(variables.global_variables_initializer()) 1031 if call_saver_with_dict: 1032 saver = saver_module.Saver({var_name: vs[0]}) 1033 else: 1034 saver = saver_module.Saver(vs) 1035 actual_path = saver.save(sess, saved_path) 1036 self.assertEqual(saved_path, actual_path) 1037 1038 return rnd 1039 1040 def _restore(partitioner=None): 1041 with self.session(graph=ops_lib.Graph()) as sess: 1042 if partitioner: 1043 new_vs = [ 1044 variable_scope.get_variable( 1045 var_name, 1046 shape=var_full_shape, 1047 initializer=array_ops.zeros(var_full_shape), 1048 partitioner=partitioner) 1049 ] 1050 else: 1051 new_vs = [ 1052 variables.VariableV1( 1053 array_ops.zeros( 1054 shape=var_full_shape), # != original contents. 1055 name=var_name) 1056 ] 1057 1058 self.evaluate(variables.global_variables_initializer()) 1059 if call_saver_with_dict: 1060 saver = saver_module.Saver({ 1061 var_name: new_vs[0] 1062 }) 1063 else: 1064 saver = saver_module.Saver(new_vs) 1065 saver.restore(sess, saved_path) 1066 1067 if partitioner: 1068 return new_vs[0].as_tensor().eval() 1069 else: 1070 return new_vs[0].eval() 1071 1072 for call_saver_with_dict in {False, True}: 1073 # Save PartitionedVariable and restore into full variable. 1074 saved_full = _save( 1075 partitioner=partitioned_variables.fixed_size_partitioner( 1076 num_shards=2)) 1077 restored_full = _restore() 1078 self.assertAllEqual(saved_full, restored_full) 1079 1080 # Restores into the same number of partitions. 1081 restored_full = _restore( 1082 partitioner=partitioned_variables.fixed_size_partitioner( 1083 num_shards=2)) 1084 self.assertAllEqual(saved_full, restored_full) 1085 1086 # Restores into a different number of partitions. 1087 restored_full = _restore( 1088 partitioner=partitioned_variables.fixed_size_partitioner( 1089 num_shards=3)) 1090 self.assertAllEqual(saved_full, restored_full) 1091 1092 # Now, saves a full variable and restores PartitionedVariable. 1093 saved_full = _save() 1094 restored_full = _restore( 1095 partitioner=partitioned_variables.fixed_size_partitioner( 1096 num_shards=3)) 1097 self.assertAllEqual(saved_full, restored_full) 1098 1099 @test_util.run_deprecated_v1 1100 def testPartitionedVariable(self): 1101 self._testPartitionedVariables(use_resource=False) 1102 1103 @test_util.run_deprecated_v1 1104 def testPartitionedResourceVariable(self): 1105 self._testPartitionedVariables(use_resource=True) 1106 1107 1108class SaveRestoreShardedTestV2(SaveRestoreShardedTest): 1109 _WRITE_VERSION = saver_pb2.SaverDef.V2 1110 1111 1112class MaxToKeepTest(test.TestCase): 1113 1114 def _get_test_dir(self, dirname): 1115 test_dir = os.path.join(self.get_temp_dir(), dirname) 1116 gfile.MakeDirs(test_dir) 1117 return test_dir 1118 1119 def assertCheckpointState(self, model_checkpoint_path, 1120 all_model_checkpoint_paths, save_dir): 1121 checkpoint_state = checkpoint_management.get_checkpoint_state(save_dir) 1122 self.assertEqual(checkpoint_state.model_checkpoint_path, 1123 model_checkpoint_path) 1124 self.assertEqual(checkpoint_state.all_model_checkpoint_paths, 1125 all_model_checkpoint_paths) 1126 1127 def testMaxToKeepEager(self): 1128 with context.eager_mode(): 1129 save_dir = self._get_test_dir("max_to_keep_eager") 1130 1131 v = variable_scope.variable(10.0, name="v") 1132 save = saver_module.Saver({"v": v}, max_to_keep=2) 1133 self.evaluate(variables.global_variables_initializer()) 1134 if not context.executing_eagerly(): 1135 self.assertEqual([], save.last_checkpoints) 1136 1137 s1 = save.save(None, os.path.join(save_dir, "s1")) 1138 self.assertEqual([s1], save.last_checkpoints) 1139 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1140 self.assertCheckpointState( 1141 model_checkpoint_path=s1, 1142 all_model_checkpoint_paths=[s1], 1143 save_dir=save_dir) 1144 1145 s2 = save.save(None, os.path.join(save_dir, "s2")) 1146 self.assertEqual([s1, s2], save.last_checkpoints) 1147 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1148 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1149 self.assertCheckpointState( 1150 model_checkpoint_path=s2, 1151 all_model_checkpoint_paths=[s1, s2], 1152 save_dir=save_dir) 1153 1154 s3 = save.save(None, os.path.join(save_dir, "s3")) 1155 self.assertEqual([s2, s3], save.last_checkpoints) 1156 self.assertFalse(checkpoint_management.checkpoint_exists(s1)) 1157 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1158 self.assertTrue(checkpoint_management.checkpoint_exists(s3)) 1159 self.assertCheckpointState( 1160 model_checkpoint_path=s3, 1161 all_model_checkpoint_paths=[s2, s3], 1162 save_dir=save_dir) 1163 1164 # Create a second helper, identical to the first. 1165 save2 = saver_module.Saver({"v": v}, max_to_keep=2) 1166 save2.set_last_checkpoints(save.last_checkpoints) 1167 1168 # Exercise the first helper. 1169 1170 # Adding s2 again (old s2 is removed first, then new s2 appended) 1171 s2 = save.save(None, os.path.join(save_dir, "s2")) 1172 self.assertEqual([s3, s2], save.last_checkpoints) 1173 self.assertFalse(checkpoint_management.checkpoint_exists(s1)) 1174 self.assertTrue(checkpoint_management.checkpoint_exists(s3)) 1175 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1176 self.assertCheckpointState( 1177 model_checkpoint_path=s2, 1178 all_model_checkpoint_paths=[s3, s2], 1179 save_dir=save_dir) 1180 1181 # Adding s1 (s3 should now be deleted as oldest in list) 1182 s1 = save.save(None, os.path.join(save_dir, "s1")) 1183 self.assertEqual([s2, s1], save.last_checkpoints) 1184 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1185 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1186 self.assertCheckpointState( 1187 model_checkpoint_path=s1, 1188 all_model_checkpoint_paths=[s2, s1], 1189 save_dir=save_dir) 1190 1191 s2 = save2.save(None, os.path.join(save_dir, "s2")) 1192 self.assertEqual([s3, s2], save2.last_checkpoints) 1193 # Created by the first helper. 1194 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1195 # Deleted by the first helper. 1196 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1197 1198 @test_util.run_deprecated_v1 1199 def testNonSharded(self): 1200 save_dir = self._get_test_dir("max_to_keep_non_sharded") 1201 1202 with self.cached_session() as sess: 1203 v = variables.VariableV1(10.0, name="v") 1204 save = saver_module.Saver({"v": v}, max_to_keep=2) 1205 self.evaluate(variables.global_variables_initializer()) 1206 self.assertEqual([], save.last_checkpoints) 1207 1208 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1209 self.assertEqual([s1], save.last_checkpoints) 1210 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1211 self.assertCheckpointState( 1212 model_checkpoint_path=s1, 1213 all_model_checkpoint_paths=[s1], 1214 save_dir=save_dir) 1215 1216 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1217 self.assertEqual([s1, s2], save.last_checkpoints) 1218 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1219 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1220 self.assertCheckpointState( 1221 model_checkpoint_path=s2, 1222 all_model_checkpoint_paths=[s1, s2], 1223 save_dir=save_dir) 1224 1225 s3 = save.save(sess, os.path.join(save_dir, "s3")) 1226 self.assertEqual([s2, s3], save.last_checkpoints) 1227 self.assertFalse(checkpoint_management.checkpoint_exists(s1)) 1228 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1229 self.assertTrue(checkpoint_management.checkpoint_exists(s3)) 1230 self.assertCheckpointState( 1231 model_checkpoint_path=s3, 1232 all_model_checkpoint_paths=[s2, s3], 1233 save_dir=save_dir) 1234 1235 # Create a second helper, identical to the first. 1236 save2 = saver_module.Saver(saver_def=save.as_saver_def()) 1237 save2.set_last_checkpoints(save.last_checkpoints) 1238 1239 # Create a third helper, with the same configuration but no knowledge of 1240 # previous checkpoints. 1241 save3 = saver_module.Saver(saver_def=save.as_saver_def()) 1242 1243 # Exercise the first helper. 1244 1245 # Adding s2 again (old s2 is removed first, then new s2 appended) 1246 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1247 self.assertEqual([s3, s2], save.last_checkpoints) 1248 self.assertFalse(checkpoint_management.checkpoint_exists(s1)) 1249 self.assertFalse( 1250 checkpoint_management.checkpoint_exists( 1251 checkpoint_management.meta_graph_filename(s1))) 1252 self.assertTrue(checkpoint_management.checkpoint_exists(s3)) 1253 self.assertTrue( 1254 checkpoint_management.checkpoint_exists( 1255 checkpoint_management.meta_graph_filename(s3))) 1256 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1257 self.assertTrue( 1258 checkpoint_management.checkpoint_exists( 1259 checkpoint_management.meta_graph_filename(s2))) 1260 self.assertCheckpointState( 1261 model_checkpoint_path=s2, 1262 all_model_checkpoint_paths=[s3, s2], 1263 save_dir=save_dir) 1264 1265 # Adding s1 (s3 should now be deleted as oldest in list) 1266 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1267 self.assertEqual([s2, s1], save.last_checkpoints) 1268 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1269 self.assertFalse( 1270 checkpoint_management.checkpoint_exists( 1271 checkpoint_management.meta_graph_filename(s3))) 1272 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1273 self.assertTrue( 1274 checkpoint_management.checkpoint_exists( 1275 checkpoint_management.meta_graph_filename(s2))) 1276 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1277 self.assertTrue( 1278 checkpoint_management.checkpoint_exists( 1279 checkpoint_management.meta_graph_filename(s1))) 1280 self.assertCheckpointState( 1281 model_checkpoint_path=s1, 1282 all_model_checkpoint_paths=[s2, s1], 1283 save_dir=save_dir) 1284 1285 # Exercise the second helper. 1286 1287 # Adding s2 again (old s2 is removed first, then new s2 appended) 1288 s2 = save2.save(sess, os.path.join(save_dir, "s2")) 1289 self.assertEqual([s3, s2], save2.last_checkpoints) 1290 # Created by the first helper. 1291 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1292 self.assertTrue( 1293 checkpoint_management.checkpoint_exists( 1294 checkpoint_management.meta_graph_filename(s1))) 1295 # Deleted by the first helper. 1296 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1297 self.assertFalse( 1298 checkpoint_management.checkpoint_exists( 1299 checkpoint_management.meta_graph_filename(s3))) 1300 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1301 self.assertTrue( 1302 checkpoint_management.checkpoint_exists( 1303 checkpoint_management.meta_graph_filename(s2))) 1304 self.assertCheckpointState( 1305 model_checkpoint_path=s2, 1306 all_model_checkpoint_paths=[s3, s2], 1307 save_dir=save_dir) 1308 1309 # Adding s1 (s3 should now be deleted as oldest in list) 1310 s1 = save2.save(sess, os.path.join(save_dir, "s1")) 1311 self.assertEqual([s2, s1], save2.last_checkpoints) 1312 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1313 self.assertFalse( 1314 checkpoint_management.checkpoint_exists( 1315 checkpoint_management.meta_graph_filename(s3))) 1316 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1317 self.assertTrue( 1318 checkpoint_management.checkpoint_exists( 1319 checkpoint_management.meta_graph_filename(s2))) 1320 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1321 self.assertTrue( 1322 checkpoint_management.checkpoint_exists( 1323 checkpoint_management.meta_graph_filename(s1))) 1324 self.assertCheckpointState( 1325 model_checkpoint_path=s1, 1326 all_model_checkpoint_paths=[s2, s1], 1327 save_dir=save_dir) 1328 1329 # Exercise the third helper. 1330 1331 # Adding s2 again (but helper is unaware of previous s2) 1332 s2 = save3.save(sess, os.path.join(save_dir, "s2")) 1333 self.assertEqual([s2], save3.last_checkpoints) 1334 # Created by the first helper. 1335 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1336 self.assertTrue( 1337 checkpoint_management.checkpoint_exists( 1338 checkpoint_management.meta_graph_filename(s1))) 1339 # Deleted by the first helper. 1340 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1341 self.assertFalse( 1342 checkpoint_management.checkpoint_exists( 1343 checkpoint_management.meta_graph_filename(s3))) 1344 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1345 self.assertTrue( 1346 checkpoint_management.checkpoint_exists( 1347 checkpoint_management.meta_graph_filename(s2))) 1348 # Even though the file for s1 exists, this saver isn't aware of it, which 1349 # is why it doesn't end up in the checkpoint state. 1350 self.assertCheckpointState( 1351 model_checkpoint_path=s2, 1352 all_model_checkpoint_paths=[s2], 1353 save_dir=save_dir) 1354 1355 # Adding s1 (s3 should not be deleted because helper is unaware of it) 1356 s1 = save3.save(sess, os.path.join(save_dir, "s1")) 1357 self.assertEqual([s2, s1], save3.last_checkpoints) 1358 self.assertFalse(checkpoint_management.checkpoint_exists(s3)) 1359 self.assertFalse( 1360 checkpoint_management.checkpoint_exists( 1361 checkpoint_management.meta_graph_filename(s3))) 1362 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1363 self.assertTrue( 1364 checkpoint_management.checkpoint_exists( 1365 checkpoint_management.meta_graph_filename(s2))) 1366 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1367 self.assertTrue( 1368 checkpoint_management.checkpoint_exists( 1369 checkpoint_management.meta_graph_filename(s1))) 1370 self.assertCheckpointState( 1371 model_checkpoint_path=s1, 1372 all_model_checkpoint_paths=[s2, s1], 1373 save_dir=save_dir) 1374 1375 def testSharded(self): 1376 save_dir = self._get_test_dir("max_to_keep_sharded") 1377 1378 with session.Session( 1379 target="", 1380 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 1381 with sess.graph.device("/cpu:0"): 1382 v0 = variables.VariableV1(111, name="v0") 1383 with sess.graph.device("/cpu:1"): 1384 v1 = variables.VariableV1(222, name="v1") 1385 save = saver_module.Saver( 1386 { 1387 "v0": v0, 1388 "v1": v1 1389 }, sharded=True, max_to_keep=2) 1390 self.evaluate(variables.global_variables_initializer()) 1391 self.assertEqual([], save.last_checkpoints) 1392 1393 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1394 self.assertEqual([s1], save.last_checkpoints) 1395 if save._write_version is saver_pb2.SaverDef.V1: 1396 self.assertEqual(2, len(gfile.Glob(s1))) 1397 else: 1398 self.assertEqual(4, len(gfile.Glob(s1 + "*"))) 1399 1400 self.assertTrue( 1401 gfile.Exists(checkpoint_management.meta_graph_filename(s1))) 1402 1403 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1404 self.assertEqual([s1, s2], save.last_checkpoints) 1405 if save._write_version is saver_pb2.SaverDef.V1: 1406 self.assertEqual(2, len(gfile.Glob(s1))) 1407 else: 1408 self.assertEqual(4, len(gfile.Glob(s1 + "*"))) 1409 self.assertTrue( 1410 gfile.Exists(checkpoint_management.meta_graph_filename(s1))) 1411 if save._write_version is saver_pb2.SaverDef.V1: 1412 self.assertEqual(2, len(gfile.Glob(s2))) 1413 else: 1414 self.assertEqual(4, len(gfile.Glob(s2 + "*"))) 1415 self.assertTrue( 1416 gfile.Exists(checkpoint_management.meta_graph_filename(s2))) 1417 1418 s3 = save.save(sess, os.path.join(save_dir, "s3")) 1419 self.assertEqual([s2, s3], save.last_checkpoints) 1420 self.assertEqual(0, len(gfile.Glob(s1 + "*"))) 1421 self.assertFalse( 1422 gfile.Exists(checkpoint_management.meta_graph_filename(s1))) 1423 if save._write_version is saver_pb2.SaverDef.V1: 1424 self.assertEqual(2, len(gfile.Glob(s2))) 1425 else: 1426 self.assertEqual(4, len(gfile.Glob(s2 + "*"))) 1427 self.assertTrue( 1428 gfile.Exists(checkpoint_management.meta_graph_filename(s2))) 1429 if save._write_version is saver_pb2.SaverDef.V1: 1430 self.assertEqual(2, len(gfile.Glob(s3))) 1431 else: 1432 self.assertEqual(4, len(gfile.Glob(s3 + "*"))) 1433 self.assertTrue( 1434 gfile.Exists(checkpoint_management.meta_graph_filename(s3))) 1435 1436 def testNoMaxToKeep(self): 1437 save_dir = self._get_test_dir("no_max_to_keep") 1438 save_dir2 = self._get_test_dir("max_to_keep_0") 1439 1440 with self.cached_session() as sess: 1441 v = variables.VariableV1(10.0, name="v") 1442 self.evaluate(variables.global_variables_initializer()) 1443 1444 # Test max_to_keep being None. 1445 save = saver_module.Saver({"v": v}, max_to_keep=None) 1446 self.assertEqual([], save.last_checkpoints) 1447 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1448 self.assertEqual([], save.last_checkpoints) 1449 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1450 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1451 self.assertEqual([], save.last_checkpoints) 1452 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1453 1454 # Test max_to_keep being 0. 1455 save2 = saver_module.Saver({"v": v}, max_to_keep=0) 1456 self.assertEqual([], save2.last_checkpoints) 1457 s1 = save2.save(sess, os.path.join(save_dir2, "s1")) 1458 self.assertEqual([], save2.last_checkpoints) 1459 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1460 s2 = save2.save(sess, os.path.join(save_dir2, "s2")) 1461 self.assertEqual([], save2.last_checkpoints) 1462 self.assertTrue(checkpoint_management.checkpoint_exists(s2)) 1463 1464 def testNoMetaGraph(self): 1465 save_dir = self._get_test_dir("no_meta_graph") 1466 1467 with self.cached_session() as sess: 1468 v = variables.VariableV1(10.0, name="v") 1469 save = saver_module.Saver({"v": v}) 1470 self.evaluate(variables.global_variables_initializer()) 1471 1472 s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False) 1473 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1474 self.assertFalse( 1475 gfile.Exists(checkpoint_management.meta_graph_filename(s1))) 1476 1477 1478class KeepCheckpointEveryNHoursTest(test.TestCase): 1479 1480 def _get_test_dir(self, dirname): 1481 test_dir = os.path.join(self.get_temp_dir(), dirname) 1482 gfile.MakeDirs(test_dir) 1483 return test_dir 1484 1485 @test_util.run_in_graph_and_eager_modes 1486 @test.mock.patch.object(saver_module, "time") 1487 def testNonSharded(self, mock_time): 1488 save_dir = self._get_test_dir("keep_checkpoint_every_n_hours") 1489 1490 with self.cached_session() as sess: 1491 v = variable_scope.variable([10.0], name="v") 1492 # Run the initializer NOW to avoid the 0.5s overhead of the first Run() 1493 # call, which throws the test timing off in fastbuild mode. 1494 self.evaluate(variables.global_variables_initializer()) 1495 # Create a saver that will keep the last 2 checkpoints plus one every 0.7 1496 # seconds. 1497 start_time = time.time() 1498 mock_time.time.return_value = start_time 1499 save = saver_module.Saver( 1500 { 1501 "v": v 1502 }, max_to_keep=2, keep_checkpoint_every_n_hours=0.7 / 3600) 1503 self.assertEqual([], save.last_checkpoints) 1504 1505 # Wait till 1 seconds have elapsed so s1 will be old enough to keep. 1506 # sleep may return early, don't trust it. 1507 mock_time.time.return_value = start_time + 1.0 1508 s1 = save.save(sess, os.path.join(save_dir, "s1")) 1509 self.assertEqual([s1], save.last_checkpoints) 1510 1511 s2 = save.save(sess, os.path.join(save_dir, "s2")) 1512 self.assertEqual([s1, s2], save.last_checkpoints) 1513 1514 # We now have 2 'last_checkpoints': [s1, s2]. The next call to Save(), 1515 # would normally delete s1, because max_to_keep is 2. However, s1 is 1516 # older than 0.7s so we must keep it. 1517 s3 = save.save(sess, os.path.join(save_dir, "s3")) 1518 self.assertEqual([s2, s3], save.last_checkpoints) 1519 1520 # s1 should still be here, we are Not checking now to reduce time 1521 # variance in the test. 1522 1523 # We now have 2 'last_checkpoints': [s2, s3], and s1 on disk. The next 1524 # call to Save(), will delete s2, because max_to_keep is 2, and because 1525 # we already kept the old s1. s2 is very close in time to s1 so it gets 1526 # deleted. 1527 s4 = save.save(sess, os.path.join(save_dir, "s4")) 1528 self.assertEqual([s3, s4], save.last_checkpoints) 1529 1530 # Check that s1 is still here, but s2 is gone. 1531 self.assertTrue(checkpoint_management.checkpoint_exists(s1)) 1532 self.assertFalse(checkpoint_management.checkpoint_exists(s2)) 1533 self.assertTrue(checkpoint_management.checkpoint_exists(s3)) 1534 self.assertTrue(checkpoint_management.checkpoint_exists(s4)) 1535 1536 1537class SaveRestoreWithVariableNameMap(test.TestCase): 1538 1539 def _testNonReshape(self, variable_op): 1540 save_path = os.path.join(self.get_temp_dir(), "non_reshape") 1541 1542 with self.session(graph=ops_lib.Graph()) as sess: 1543 # Build a graph with 2 parameter nodes, and Save and 1544 # Restore nodes for them. 1545 v0 = variable_op(10.0, name="v0") 1546 v1 = variable_op(20.0, name="v1") 1547 save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1}) 1548 self.evaluate(variables.global_variables_initializer()) 1549 1550 # Check that the parameter nodes have been initialized. 1551 self.assertEqual(10.0, self.evaluate(v0)) 1552 self.assertEqual(20.0, self.evaluate(v1)) 1553 1554 # Save the initialized values in the file at "save_path" 1555 # Use a variable name map to set the saved tensor names 1556 val = save.save(sess, save_path) 1557 self.assertTrue(isinstance(val, six.string_types)) 1558 self.assertEqual(save_path, val) 1559 1560 # Verify that the original names are not in the Saved file 1561 save = saver_module.Saver({"v0": v0, "v1": v1}) 1562 with self.assertRaisesOpError("not found in checkpoint"): 1563 save.restore(sess, save_path) 1564 1565 # Verify that the mapped names are present in the Saved file and can be 1566 # Restored using remapped names. 1567 with self.session(graph=ops_lib.Graph()) as sess: 1568 v0 = variable_op(-1.0, name="v0") 1569 v1 = variable_op(-1.0, name="v1") 1570 1571 if not context.executing_eagerly(): 1572 with self.assertRaisesOpError("uninitialized"): 1573 self.evaluate(v0) 1574 with self.assertRaisesOpError("uninitialized"): 1575 self.evaluate(v1) 1576 1577 save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1}) 1578 save.restore(sess, save_path) 1579 1580 # Check that the parameter nodes have been restored. 1581 if not context.executing_eagerly(): 1582 self.assertEqual(10.0, self.evaluate(v0)) 1583 self.assertEqual(20.0, self.evaluate(v1)) 1584 1585 # Add a prefix to the node names in the current graph and Restore using 1586 # remapped names. 1587 with self.session(graph=ops_lib.Graph()) as sess: 1588 v0 = variable_op(-1.0, name="restore_prefix/v0") 1589 v1 = variable_op(-1.0, name="restore_prefix/v1") 1590 1591 if not context.executing_eagerly(): 1592 with self.assertRaisesOpError("uninitialized"): 1593 self.evaluate(v0) 1594 with self.assertRaisesOpError("uninitialized"): 1595 self.evaluate(v1) 1596 1597 # Restore the saved values in the parameter nodes. 1598 save = saver_module.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1}) 1599 save.restore(sess, save_path) 1600 1601 # Check that the parameter nodes have been restored. 1602 self.assertEqual(10.0, self.evaluate(v0)) 1603 self.assertEqual(20.0, self.evaluate(v1)) 1604 1605 @test_util.run_in_graph_and_eager_modes 1606 def testNonReshapeResourceVariable(self): 1607 self._testNonReshape(resource_variable_ops.ResourceVariable) 1608 1609 def testNonReshapeVariable(self): 1610 self._testNonReshape(variables.Variable) 1611 1612 1613class MetaGraphTest(test.TestCase): 1614 1615 def _get_test_dir(self, dirname): 1616 test_dir = os.path.join(self.get_temp_dir(), dirname) 1617 gfile.MakeDirs(test_dir) 1618 return test_dir 1619 1620 @test_util.run_v1_only("b/120545219") 1621 def testAddCollectionDef(self): 1622 test_dir = self._get_test_dir("good_collection") 1623 filename = os.path.join(test_dir, "metafile") 1624 with self.cached_session(): 1625 # Creates a graph. 1626 v0 = variables.VariableV1(1.0, name="v0") 1627 control_flow_ops.cond( 1628 math_ops.less(v0, 10), lambda: math_ops.add(v0, 1), 1629 lambda: math_ops.subtract(v0, 1)) 1630 control_flow_ops.while_loop(lambda i: math_ops.less(i, 10), 1631 lambda i: math_ops.add(i, 1), [v0]) 1632 var = variables.VariableV1(constant_op.constant(0, dtype=dtypes.int64)) 1633 count_up_to = var.count_up_to(3) 1634 input_queue = data_flow_ops.FIFOQueue( 1635 30, dtypes.float32, shared_name="collection_queue") 1636 qr = queue_runner_impl.QueueRunner(input_queue, [count_up_to]) 1637 variables.global_variables_initializer() 1638 # Creates a saver. 1639 save = saver_module.Saver({"v0": v0}) 1640 # Adds a set of collections. 1641 ops_lib.add_to_collection("int_collection", 3) 1642 ops_lib.add_to_collection("float_collection", 3.5) 1643 ops_lib.add_to_collection("string_collection", "hello") 1644 ops_lib.add_to_collection("variable_collection", v0) 1645 # Add QueueRunners. 1646 queue_runner_impl.add_queue_runner(qr) 1647 # Adds user_defined proto in three formats: string, bytes and Any. 1648 queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue") 1649 ops_lib.add_to_collection("user_defined_string_collection", 1650 str(queue_runner)) 1651 ops_lib.add_to_collection("user_defined_bytes_collection", 1652 queue_runner.SerializeToString()) 1653 any_buf = Any() 1654 any_buf.Pack(queue_runner) 1655 ops_lib.add_to_collection("user_defined_any_collection", any_buf) 1656 1657 # Generates MetaGraphDef. 1658 meta_graph_def = save.export_meta_graph(filename) 1659 self.assertTrue(meta_graph_def.HasField("saver_def")) 1660 self.assertTrue(meta_graph_def.HasField("graph_def")) 1661 self.assertTrue(meta_graph_def.HasField("meta_info_def")) 1662 self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "") 1663 self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version, 1664 "") 1665 collection_def = meta_graph_def.collection_def 1666 self.assertEqual(len(collection_def), 12) 1667 1668 with ops_lib.Graph().as_default(): 1669 # Restores from MetaGraphDef. 1670 new_saver = saver_module.import_meta_graph(filename) 1671 # Generates a new MetaGraphDef. 1672 new_meta_graph_def = new_saver.export_meta_graph() 1673 # It should be the same as the original. 1674 1675 test_util.assert_meta_graph_protos_equal( 1676 self, meta_graph_def, new_meta_graph_def) 1677 1678 def testAddCollectionDefFails(self): 1679 with self.cached_session(): 1680 # Creates a graph. 1681 v0 = variables.VariableV1(10.0, name="v0") 1682 # Creates a saver. 1683 save = saver_module.Saver({"v0": v0}) 1684 # Generates MetaGraphDef. 1685 meta_graph_def = meta_graph_pb2.MetaGraphDef() 1686 1687 # Verifies that collection with unsupported key will not be added. 1688 ops_lib.add_to_collection(save, 3) 1689 save._add_collection_def(meta_graph_def, save) 1690 self.assertEqual(len(meta_graph_def.collection_def), 0) 1691 1692 # Verifies that collection where item type does not match expected 1693 # type will not be added. 1694 ops_lib.add_to_collection("int_collection", 3) 1695 ops_lib.add_to_collection("int_collection", 3.5) 1696 save._add_collection_def(meta_graph_def, "int_collection") 1697 self.assertEqual(len(meta_graph_def.collection_def), 0) 1698 1699 def _testMultiSaverCollectionSave(self, test_dir): 1700 filename = os.path.join(test_dir, "metafile") 1701 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1702 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") 1703 with self.session(graph=ops_lib.Graph()) as sess: 1704 # Creates a graph. 1705 v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0") 1706 v1 = variables.VariableV1(11.0, name="v1") 1707 # Creates 2 savers. 1708 saver0 = saver_module.Saver({"v0": v0}, name="saver0") 1709 saver1 = saver_module.Saver({"v1": v1}, name="saver1") 1710 ops_lib.add_to_collection("savers", saver0) 1711 ops_lib.add_to_collection("savers", saver1) 1712 self.evaluate(variables.global_variables_initializer()) 1713 # Saves to different checkpoints. 1714 saver0.save(sess, saver0_ckpt) 1715 saver1.save(sess, saver1_ckpt) 1716 # Generates MetaGraphDef. 1717 meta_graph_def = saver_module.export_meta_graph(filename) 1718 meta_graph_def0 = saver0.export_meta_graph() 1719 meta_graph_def1 = saver1.export_meta_graph() 1720 1721 # Verifies that there is no saver_def in meta_graph_def. 1722 self.assertFalse(meta_graph_def.HasField("saver_def")) 1723 # Verifies that there is saver_def in meta_graph_def0 and 1. 1724 self.assertTrue(meta_graph_def0.HasField("saver_def")) 1725 self.assertTrue(meta_graph_def1.HasField("saver_def")) 1726 1727 # Verifies SAVERS is saved as bytes_list for meta_graph_def. 1728 collection_def = meta_graph_def.collection_def["savers"] 1729 kind = collection_def.WhichOneof("kind") 1730 self.assertEqual(kind, "bytes_list") 1731 # Verifies that there are 2 entries in SAVERS collection. 1732 savers = getattr(collection_def, kind) 1733 self.assertEqual(2, len(savers.value)) 1734 1735 # Verifies SAVERS collection is saved as bytes_list for meta_graph_def0. 1736 collection_def = meta_graph_def0.collection_def["savers"] 1737 kind = collection_def.WhichOneof("kind") 1738 self.assertEqual(kind, "bytes_list") 1739 # Verifies that there are 2 entries in SAVERS collection. 1740 savers = getattr(collection_def, kind) 1741 self.assertEqual(2, len(savers.value)) 1742 1743 def _testMultiSaverCollectionRestore(self, test_dir): 1744 filename = os.path.join(test_dir, "metafile") 1745 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1746 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") 1747 with self.session(graph=ops_lib.Graph()) as sess: 1748 # Imports from meta_graph. 1749 saver_module.import_meta_graph(filename) 1750 # Retrieves SAVERS collection. Verifies there are 2 entries. 1751 savers = ops_lib.get_collection("savers") 1752 self.assertEqual(2, len(savers)) 1753 # Retrieves saver0. Verifies that new_saver0 can restore v0, but not v1. 1754 new_saver0 = savers[0] 1755 new_saver0.restore(sess, saver0_ckpt) 1756 v0 = sess.graph.get_tensor_by_name("v0:0") 1757 v1 = sess.graph.get_tensor_by_name("v1:0") 1758 self.assertAllEqual([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 1759 self.evaluate(v0)) 1760 self.assertEqual([3, 2], v0.get_shape()) 1761 self.assertEqual([], v1.get_shape()) 1762 with self.assertRaisesWithPredicateMatch( 1763 errors_impl.OpError, lambda e: "uninitialized value v1" in e.message): 1764 self.evaluate(v1) 1765 # Retrieves saver1. Verifies that new_saver1 can restore v1. 1766 new_saver1 = savers[1] 1767 new_saver1.restore(sess, saver1_ckpt) 1768 v1 = sess.graph.get_tensor_by_name("v1:0") 1769 self.assertEqual(11.0, self.evaluate(v1)) 1770 1771 @test_util.run_v1_only("b/120545219") 1772 def testMultiSaverCollection(self): 1773 test_dir = self._get_test_dir("saver_collection") 1774 self._testMultiSaverCollectionSave(test_dir) 1775 self._testMultiSaverCollectionRestore(test_dir) 1776 1777 @test_util.run_v1_only("b/120545219") 1778 def testClearExtraneousSavers(self): 1779 test_dir = self._get_test_dir("clear_extraneous_savers") 1780 filename = os.path.join(test_dir, "metafile") 1781 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1782 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") 1783 with self.session(graph=ops_lib.Graph()) as sess: 1784 # Creates a graph. 1785 v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0") 1786 v1 = variables.VariableV1(11.0, name="v1") 1787 1788 # Creates 2 savers. 1789 saver0 = saver_module.Saver({"v0": v0}, name="saver0") 1790 saver1 = saver_module.Saver({"v1": v1}, name="saver1") 1791 ops_lib.add_to_collection("savers", saver0) 1792 ops_lib.add_to_collection("savers", saver1) 1793 self.evaluate(variables.global_variables_initializer()) 1794 1795 # Saves to different checkpoints. 1796 saver0.save(sess, saver0_ckpt) 1797 saver1.save(sess, saver1_ckpt) 1798 1799 # Generates MetaGraphDef. 1800 meta_graph_def = saver_module.export_meta_graph(filename) 1801 meta_graph_def0 = saver0.export_meta_graph() 1802 meta_graph_def1 = saver1.export_meta_graph(clear_extraneous_savers=True) 1803 1804 # Verifies that there is no saver_def in meta_graph_def. 1805 self.assertFalse(meta_graph_def.HasField("saver_def")) 1806 # Verifies that there is saver_def in meta_graph_def0 and 1. 1807 self.assertTrue(meta_graph_def0.HasField("saver_def")) 1808 self.assertTrue(meta_graph_def1.HasField("saver_def")) 1809 1810 # Verifies SAVERS is saved as bytes_list for meta_graph_def. 1811 collection_def = meta_graph_def.collection_def["savers"] 1812 kind = collection_def.WhichOneof("kind") 1813 self.assertEqual(kind, "bytes_list") 1814 1815 # Verifies that there are 2 entries in SAVERS collection. 1816 savers = getattr(collection_def, kind) 1817 self.assertEqual(2, len(savers.value)) 1818 1819 # Verifies SAVERS collection is saved as bytes_list for meta_graph_def1. 1820 collection_def = meta_graph_def1.collection_def["savers"] 1821 kind = collection_def.WhichOneof("kind") 1822 self.assertEqual(kind, "bytes_list") 1823 1824 # Verifies that there is 1 entry in SAVERS collection. 1825 savers = getattr(collection_def, kind) 1826 self.assertEqual(1, len(savers.value)) 1827 1828 # Verifies that saver0 graph nodes are omitted from the saver1 export 1829 self.assertEqual(33, len(meta_graph_def0.graph_def.node)) 1830 self.assertEqual(21, len(meta_graph_def1.graph_def.node)) 1831 1832 @test_util.run_deprecated_v1 1833 def testBinaryAndTextFormat(self): 1834 test_dir = self._get_test_dir("binary_and_text") 1835 filename = os.path.join(test_dir, "metafile") 1836 with self.session(graph=ops_lib.Graph()): 1837 # Creates a graph. 1838 variables.VariableV1(10.0, name="v0") 1839 # Exports the graph as binary format. 1840 saver_module.export_meta_graph(filename, as_text=False) 1841 with self.session(graph=ops_lib.Graph()): 1842 # Imports the binary format graph. 1843 saver = saver_module.import_meta_graph(filename) 1844 self.assertIsNotNone(saver) 1845 # Exports the graph as text format. 1846 saver.export_meta_graph(filename, as_text=True) 1847 with self.session(graph=ops_lib.Graph()): 1848 # Imports the text format graph. 1849 saver_module.import_meta_graph(filename) 1850 # Writes wrong contents to the file. 1851 graph_io.write_graph(saver.as_saver_def(), 1852 os.path.dirname(filename), 1853 os.path.basename(filename)) 1854 with self.session(graph=ops_lib.Graph()): 1855 # Import should fail. 1856 with self.assertRaisesWithPredicateMatch(IOError, 1857 lambda e: "Cannot parse file"): 1858 saver_module.import_meta_graph(filename) 1859 # Deletes the file 1860 gfile.Remove(filename) 1861 with self.assertRaisesWithPredicateMatch(IOError, 1862 lambda e: "does not exist"): 1863 saver_module.import_meta_graph(filename) 1864 1865 @test_util.run_v1_only("b/120545219") 1866 def testSliceVariable(self): 1867 test_dir = self._get_test_dir("slice_saver") 1868 filename = os.path.join(test_dir, "metafile") 1869 with self.cached_session(): 1870 v1 = variables.VariableV1([20.0], name="v1") 1871 v2 = variables.VariableV1([20.0], name="v2") 1872 v2._set_save_slice_info( 1873 variables.Variable.SaveSliceInfo("v1", [1], [0], [1])) 1874 1875 # The names are different and will work. 1876 slice_saver = saver_module.Saver({"first": v1, "second": v2}) 1877 self.evaluate(variables.global_variables_initializer()) 1878 # Exports to meta_graph 1879 meta_graph_def = slice_saver.export_meta_graph(filename) 1880 1881 with ops_lib.Graph().as_default(): 1882 # Restores from MetaGraphDef. 1883 new_saver = saver_module.import_meta_graph(filename) 1884 self.assertIsNotNone(new_saver) 1885 # Generates a new MetaGraphDef. 1886 new_meta_graph_def = new_saver.export_meta_graph() 1887 # It should be the same as the original. 1888 test_util.assert_meta_graph_protos_equal(self, meta_graph_def, 1889 new_meta_graph_def) 1890 1891 def _testGraphExtensionSave(self, test_dir): 1892 filename = os.path.join(test_dir, "metafile") 1893 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1894 # Creates an inference graph. 1895 # Hidden 1 1896 images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28]) 1897 with ops_lib.name_scope("hidden1"): 1898 weights = variables.VariableV1( 1899 random_ops.truncated_normal( 1900 [28, 128], stddev=1.0 / math.sqrt(float(28))), 1901 name="weights") 1902 # The use of control_flow_ops.cond here is purely for adding test coverage 1903 # the save and restore of control flow context (which doesn't make any 1904 # sense here from a machine learning perspective). The typical biases is 1905 # a simple Variable without the conditions. 1906 biases = variables.VariableV1( 1907 control_flow_ops.cond( 1908 math_ops.less(random.random(), 0.5), 1909 lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])), 1910 name="biases") 1911 hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases) 1912 # Hidden 2 1913 with ops_lib.name_scope("hidden2"): 1914 weights = variables.VariableV1( 1915 random_ops.truncated_normal( 1916 [128, 32], stddev=1.0 / math.sqrt(float(128))), 1917 name="weights") 1918 1919 # The use of control_flow_ops.while_loop here is purely for adding test 1920 # coverage the save and restore of control flow context (which doesn't 1921 # make any sense here from a machine learning perspective). The typical 1922 # biases is a simple Variable without the conditions. 1923 def loop_cond(it, _): 1924 return it < 2 1925 1926 def loop_body(it, biases): 1927 biases += constant_op.constant(0.1, shape=[32]) 1928 return it + 1, biases 1929 1930 _, biases = control_flow_ops.while_loop( 1931 loop_cond, loop_body, 1932 [constant_op.constant(0), 1933 variables.VariableV1(array_ops.zeros([32]))]) 1934 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases) 1935 # Linear 1936 with ops_lib.name_scope("softmax_linear"): 1937 weights = variables.VariableV1( 1938 random_ops.truncated_normal( 1939 [32, 10], stddev=1.0 / math.sqrt(float(32))), 1940 name="weights") 1941 biases = variables.VariableV1(array_ops.zeros([10]), name="biases") 1942 logits = math_ops.matmul(hidden2, weights) + biases 1943 ops_lib.add_to_collection("logits", logits) 1944 init_all_op = variables.global_variables_initializer() 1945 1946 with self.cached_session() as sess: 1947 # Initializes all the variables. 1948 self.evaluate(init_all_op) 1949 # Runs to logit. 1950 self.evaluate(logits) 1951 # Creates a saver. 1952 saver0 = saver_module.Saver() 1953 saver0.save(sess, saver0_ckpt) 1954 # Generates MetaGraphDef. 1955 saver0.export_meta_graph(filename) 1956 1957 def _testGraphExtensionRestore(self, test_dir): 1958 filename = os.path.join(test_dir, "metafile") 1959 train_filename = os.path.join(test_dir, "train_metafile") 1960 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1961 with self.session(graph=ops_lib.Graph()) as sess: 1962 # Restores from MetaGraphDef. 1963 new_saver = saver_module.import_meta_graph(filename) 1964 # Generates a new MetaGraphDef. 1965 new_saver.export_meta_graph() 1966 # Restores from checkpoint. 1967 new_saver.restore(sess, saver0_ckpt) 1968 # Adds loss and train. 1969 labels = constant_op.constant(0, dtypes.int32, shape=[100], name="labels") 1970 batch_size = array_ops.size(labels) 1971 labels = array_ops.expand_dims(labels, 1) 1972 indices = array_ops.expand_dims(math_ops.range(0, batch_size), 1) 1973 concated = array_ops.concat([indices, labels], 1) 1974 onehot_labels = sparse_ops.sparse_to_dense( 1975 concated, array_ops.stack([batch_size, 10]), 1.0, 0.0) 1976 logits = ops_lib.get_collection("logits")[0] 1977 cross_entropy = nn_ops.softmax_cross_entropy_with_logits( 1978 labels=onehot_labels, logits=logits, name="xentropy") 1979 loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean") 1980 1981 summary.scalar("loss", loss) 1982 # Creates the gradient descent optimizer with the given learning rate. 1983 optimizer = gradient_descent.GradientDescentOptimizer(0.01) 1984 1985 # Runs train_op. 1986 train_op = optimizer.minimize(loss) 1987 ops_lib.add_to_collection("train_op", train_op) 1988 1989 # Runs train_op. 1990 self.evaluate(train_op) 1991 1992 # Generates MetaGraphDef. 1993 saver_module.export_meta_graph(train_filename) 1994 1995 def _testRestoreFromTrainGraphWithControlContext(self, test_dir): 1996 train_filename = os.path.join(test_dir, "train_metafile") 1997 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 1998 with self.session(graph=ops_lib.Graph()) as sess: 1999 # Restores from MetaGraphDef. 2000 new_saver = saver_module.import_meta_graph(train_filename) 2001 # Restores from checkpoint. 2002 new_saver.restore(sess, saver0_ckpt) 2003 train_op = ops_lib.get_collection("train_op")[0] 2004 self.evaluate(train_op) 2005 2006 @test_util.run_deprecated_v1 2007 def testGraphExtension(self): 2008 test_dir = self._get_test_dir("graph_extension") 2009 self._testGraphExtensionSave(test_dir) 2010 self._testGraphExtensionRestore(test_dir) 2011 self._testRestoreFromTrainGraphWithControlContext(test_dir) 2012 2013 def _testGradientSerDes(self, graph_fn): 2014 """Tests that gradients can be computed after exporting and importing. 2015 2016 Builds a graph, exports it, and verifies that it can be imported and the 2017 gradient can be built and run correctly. 2018 2019 Args: 2020 graph_fn: takes a single float Tensor argument as input, outputs a single 2021 Tensor 2022 """ 2023 test_dir = self._get_test_dir("nested_control_flow") 2024 filename = os.path.join(test_dir, "metafile") 2025 saver_ckpt = os.path.join(test_dir, "saver.ckpt") 2026 2027 # Create while loop using `outer_body_fn`. 2028 with ops_lib.Graph().as_default(): 2029 var = variables.VariableV1(0.0) 2030 var_name = var.name 2031 output = graph_fn(var) 2032 output_name = output.name 2033 init_op = variables.global_variables_initializer() 2034 2035 # Generate a MetaGraphDef containing the while loop. 2036 with session.Session() as sess: 2037 self.evaluate(init_op) 2038 self.evaluate(output) 2039 saver = saver_module.Saver() 2040 saver.save(sess, saver_ckpt) 2041 saver.export_meta_graph(filename) 2042 2043 # Build and run the gradients of the while loop. We use this below to 2044 # verify that the gradients are correct with an imported MetaGraphDef. 2045 grad = gradients_impl.gradients([output], [var]) 2046 # Turn off constant folding to avoid breaking testNestedControlFlowSerDes. 2047 # It appears that a missing control dependency in the gradient graph 2048 # causes the fetch node to not be triggered. 2049 no_constfold_config = config_pb2.ConfigProto() 2050 no_constfold_config.graph_options.rewrite_options.constant_folding = ( 2051 rewriter_config_pb2.RewriterConfig.OFF) 2052 with session.Session(config=no_constfold_config) as sess: 2053 self.evaluate(init_op) 2054 expected_grad_value = self.evaluate(grad) 2055 2056 # Restore the MetaGraphDef into a new Graph. 2057 with ops_lib.Graph().as_default(): 2058 with session.Session() as sess: 2059 saver = saver_module.import_meta_graph(filename) 2060 saver.restore(sess, saver_ckpt) 2061 2062 # Make sure we can still build gradients and get the same result. 2063 var = ops_lib.get_default_graph().get_tensor_by_name(var_name) 2064 output = ops_lib.get_default_graph().get_tensor_by_name(output_name) 2065 grad = gradients_impl.gradients([output], [var]) 2066 2067 init_op = variables.global_variables_initializer() 2068 2069 with session.Session(config=no_constfold_config) as sess: 2070 self.evaluate(init_op) 2071 actual_grad_value = self.evaluate(grad) 2072 self.assertEqual(expected_grad_value, actual_grad_value) 2073 2074 def _testWhileLoopAndGradientSerDes(self, outer_body_fn): 2075 # Build a while loop with `outer_body_fn`, export it, and verify that it can 2076 # be imported and the gradient can be built and run correctly. 2077 # pylint: disable=g-long-lambda 2078 return self._testGradientSerDes( 2079 lambda x: control_flow_ops.while_loop( 2080 lambda i, y: i < 5, outer_body_fn, [0, x])[1]) 2081 # pylint: enable=g-long-lambda 2082 2083 def testNestedWhileLoopsSerDes(self): 2084 # Test two simple nested while loops. 2085 def body(i, x): 2086 _, r = control_flow_ops.while_loop(lambda j, y: j < 3, 2087 lambda j, y: (j + 1, y + x), 2088 [0, 0.0]) 2089 return i + 1, x + r 2090 self._testWhileLoopAndGradientSerDes(body) 2091 2092 def testNestedControlFlowSerDes(self): 2093 # Test while loop in a cond in a while loop. 2094 # pylint: disable=g-long-lambda 2095 def body(i, x): 2096 cond_result = control_flow_ops.cond( 2097 i > 0, 2098 lambda: control_flow_ops.while_loop( 2099 lambda j, y: j < 3, 2100 lambda j, y: (j + 1, y + x), 2101 [0, 0.0])[1], 2102 lambda: x) 2103 return i + 1, cond_result 2104 # pylint: enable=g-long-lambda 2105 self._testWhileLoopAndGradientSerDes(body) 2106 2107 def testNestedCondsSerDes(self): 2108 # Test conds in a cond. 2109 # pylint: disable=g-long-lambda 2110 self._testGradientSerDes(lambda x: control_flow_ops.cond( 2111 x > 0, 2112 lambda: control_flow_ops.cond(x > 3, 2113 lambda: array_ops.identity(x), 2114 lambda: math_ops.multiply(x, 2.0)), 2115 lambda: control_flow_ops.cond(x < -3, 2116 lambda: constant_op.constant(1.0), 2117 lambda: math_ops.multiply(x, -1.0)))) 2118 # pylint: enable=g-long-lambda 2119 2120 @test_util.run_v1_only("b/120545219") 2121 def testStrippedOpListDef(self): 2122 with self.cached_session(): 2123 # Creates a graph. 2124 v0 = variables.VariableV1(0.0) 2125 var = variables.VariableV1(10.0) 2126 math_ops.add(v0, var) 2127 2128 @function.Defun(dtypes.float32) 2129 def minus_one(x): 2130 return x - 1 2131 2132 minus_one(array_ops.identity(v0)) 2133 save = saver_module.Saver({"v0": v0}) 2134 variables.global_variables_initializer() 2135 2136 # Generates MetaGraphDef. 2137 meta_graph_def = save.export_meta_graph() 2138 ops = [o.name for o in meta_graph_def.meta_info_def.stripped_op_list.op] 2139 if save._write_version is saver_pb2.SaverDef.V1: 2140 self.assertEqual(ops, [ 2141 "Add", "Assign", "Const", "Identity", "NoOp", 2142 "PlaceholderWithDefault", "RestoreV2", "SaveSlices", "Sub", 2143 "VariableV2" 2144 ]) 2145 else: 2146 self.assertEqual(ops, [ 2147 "Add", "Assign", "Const", "Identity", "NoOp", 2148 "PlaceholderWithDefault", "RestoreV2", "SaveV2", "Sub", "VariableV2" 2149 ]) 2150 2151 # Test calling stripped_op_list_for_graph directly 2152 op_list = meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def) 2153 self.assertEqual(ops, [o.name for o in op_list.op]) 2154 for o in op_list.op: 2155 self.assertEqual(o.summary, "") 2156 self.assertEqual(o.description, "") 2157 2158 @test_util.run_deprecated_v1 2159 def testStripDefaultValuedAttrs(self): 2160 """Verifies that default valued attrs are stripped, unless disabled.""" 2161 2162 # With strip_default_attrs enabled, attributes "T" (float32) and "Tout" 2163 # (complex64) in the "Complex" op must be removed. 2164 with self.cached_session(): 2165 real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real") 2166 imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag") 2167 math_ops.complex(real_num, imag_num, name="complex") 2168 2169 save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num}) 2170 variables.global_variables_initializer() 2171 2172 meta_graph_def = save.export_meta_graph(strip_default_attrs=True) 2173 node_def = test_util.get_node_def_from_graph("complex", 2174 meta_graph_def.graph_def) 2175 self.assertNotIn("T", node_def.attr) 2176 self.assertNotIn("Tout", node_def.attr) 2177 2178 # With strip_default_attrs disabled, attributes "T" (float32) and "Tout" 2179 # (complex64) in the "Complex" op must *not* be removed, even if they map 2180 # to their defaults. 2181 with self.session(graph=ops_lib.Graph()): 2182 real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real") 2183 imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag") 2184 math_ops.complex(real_num, imag_num, name="complex") 2185 2186 save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num}) 2187 variables.global_variables_initializer() 2188 2189 meta_graph_def = save.export_meta_graph(strip_default_attrs=False) 2190 node_def = test_util.get_node_def_from_graph("complex", 2191 meta_graph_def.graph_def) 2192 self.assertIn("T", node_def.attr) 2193 self.assertIn("Tout", node_def.attr) 2194 2195 @test_util.run_deprecated_v1 2196 def testImportIntoNamescope(self): 2197 # Test that we can import a meta graph into a namescope. 2198 test_dir = self._get_test_dir("import_into_namescope") 2199 filename = os.path.join(test_dir, "ckpt") 2200 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image") 2201 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label") 2202 with session.Session() as sess: 2203 weights = variables.VariableV1( 2204 random_ops.random_uniform([784, 10]), name="weights") 2205 bias = variables.VariableV1(array_ops.zeros([10]), name="bias") 2206 logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits") 2207 nn_ops.softmax(logit, name="prediction") 2208 cost = nn_ops.softmax_cross_entropy_with_logits(labels=label, 2209 logits=logit, name="cost") 2210 adam.AdamOptimizer().minimize(cost, name="optimize") 2211 saver = saver_module.Saver() 2212 self.evaluate(variables.global_variables_initializer()) 2213 saver.save(sess, filename) 2214 2215 graph = ops_lib.Graph() 2216 with session.Session(graph=graph) as sess: 2217 new_saver = saver_module.import_meta_graph( 2218 filename + ".meta", graph=graph, import_scope="new_model") 2219 new_saver.restore(sess, filename) 2220 sess.run(["new_model/optimize"], { 2221 "new_model/image:0": np.random.random([1, 784]), 2222 "new_model/label:0": np.random.randint( 2223 10, size=[1, 10]) 2224 }) 2225 2226 def testImportIntoNamescopeWithoutVariables(self): 2227 # Save a simple graph that contains no variables into a checkpoint. 2228 test_dir = self._get_test_dir("no_vars_graph") 2229 filename = os.path.join(test_dir, "ckpt") 2230 graph_1 = ops_lib.Graph() 2231 with session.Session(graph=graph_1) as sess: 2232 constant_op.constant([1, 2, 3], name="x") 2233 constant_op.constant([1, 2, 3], name="y") 2234 saver = saver_module.Saver(allow_empty=True) 2235 saver.save(sess, filename) 2236 2237 # Create a fresh graph. 2238 graph_2 = ops_lib.Graph() 2239 with session.Session(graph=graph_2) as sess: 2240 # Restore the above checkpoint under scope "subgraph_1". 2241 new_saver_1 = saver_module.import_meta_graph( 2242 filename + ".meta", graph=graph_2, import_scope="subgraph_1") 2243 # There are no variables to restore, so import_meta_graph should not 2244 # return a Saver. 2245 self.assertIsNone(new_saver_1) 2246 2247 # Create a variable in graph_2 under scope "my_scope". 2248 variables.VariableV1(array_ops.zeros([10]), name="my_scope/my_var") 2249 self.evaluate(variables.global_variables_initializer()) 2250 # Restore the checkpoint into a different scope "subgraph_2". 2251 new_saver_2 = saver_module.import_meta_graph( 2252 filename + ".meta", graph=graph_2, import_scope="subgraph_2") 2253 # Because the variable does not live in scope "subgraph_2", 2254 # import_meta_graph should not attempt to restore the variable. So, 2255 # import_meta_graph still won't return a Saver instance. 2256 self.assertIsNone(new_saver_2) 2257 2258 # However, if we restore the checkpoint under scope "my_scope", 2259 # import_meta_graph will detect the variable and return a Saver for 2260 # restoring it. This should happen even when the variable does not 2261 # originate from graph_1. 2262 new_saver_3 = saver_module.import_meta_graph( 2263 filename + ".meta", graph=graph_2, import_scope="my_scope") 2264 self.assertIsInstance(new_saver_3, saver_module.Saver) 2265 2266 @test_util.run_deprecated_v1 2267 def testImportIntoImplicitNamescope(self): 2268 # Test that we can import a meta graph into an implicit namescope. 2269 test_dir = self._get_test_dir("import_into_namescope") 2270 filename = os.path.join(test_dir, "ckpt") 2271 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image") 2272 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label") 2273 with session.Session() as sess: 2274 weights = variables.VariableV1( 2275 random_ops.random_uniform([784, 10]), name="weights") 2276 bias = variables.VariableV1(array_ops.zeros([10]), name="bias") 2277 logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits") 2278 nn_ops.softmax(logit, name="prediction") 2279 cost = nn_ops.softmax_cross_entropy_with_logits(labels=label, 2280 logits=logit, name="cost") 2281 adam.AdamOptimizer().minimize(cost, name="optimize") 2282 saver = saver_module.Saver() 2283 self.evaluate(variables.global_variables_initializer()) 2284 saver.save(sess, filename) 2285 2286 graph = ops_lib.Graph() 2287 with session.Session(graph=graph) as sess: 2288 with ops_lib.name_scope("new_model"): 2289 new_saver = saver_module.import_meta_graph( 2290 filename + ".meta", graph=graph) 2291 2292 new_saver.restore(sess, filename) 2293 sess.run(["new_model/optimize"], { 2294 "new_model/image:0": np.random.random([1, 784]), 2295 "new_model/label:0": np.random.randint( 2296 10, size=[1, 10]) 2297 }) 2298 2299 def testClearDevicesOnImport(self): 2300 # Test that we import a graph without its devices and run successfully. 2301 with ops_lib.Graph().as_default(): 2302 with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"): 2303 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image") 2304 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label") 2305 weights = variables.VariableV1( 2306 random_ops.random_uniform([784, 10]), name="weights") 2307 bias = variables.VariableV1(array_ops.zeros([10]), name="bias") 2308 logit = nn_ops.relu(math_ops.matmul(image, weights) + bias) 2309 nn_ops.softmax(logit, name="prediction") 2310 cost = nn_ops.softmax_cross_entropy_with_logits(labels=label, 2311 logits=logit) 2312 adam.AdamOptimizer().minimize(cost, name="optimize") 2313 meta_graph_def = saver_module.export_meta_graph() 2314 2315 with session.Session(graph=ops_lib.Graph()) as sess: 2316 saver_module.import_meta_graph( 2317 meta_graph_def, clear_devices=False, import_scope="new_model") 2318 # Device refers to GPU, which is not available here. 2319 with self.assertRaises(errors_impl.InvalidArgumentError): 2320 self.evaluate(variables.global_variables_initializer()) 2321 2322 with session.Session(graph=ops_lib.Graph()) as sess: 2323 saver_module.import_meta_graph( 2324 meta_graph_def, clear_devices=True, import_scope="new_model") 2325 self.evaluate(variables.global_variables_initializer()) 2326 sess.run(["new_model/optimize"], { 2327 "new_model/image:0": np.random.random([1, 784]), 2328 "new_model/label:0": np.random.randint( 2329 10, size=[1, 10]) 2330 }) 2331 2332 def testClearDevicesOnExport(self): 2333 # Test that we export a graph without its devices and run successfully. 2334 with ops_lib.Graph().as_default(): 2335 with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"): 2336 image = array_ops.placeholder(dtypes.float32, [None, 784], name="image") 2337 label = array_ops.placeholder(dtypes.float32, [None, 10], name="label") 2338 weights = variables.VariableV1( 2339 random_ops.random_uniform([784, 10]), name="weights") 2340 bias = variables.VariableV1(array_ops.zeros([10]), name="bias") 2341 logit = nn_ops.relu(math_ops.matmul(image, weights) + bias) 2342 nn_ops.softmax(logit, name="prediction") 2343 cost = nn_ops.softmax_cross_entropy_with_logits(labels=label, 2344 logits=logit) 2345 adam.AdamOptimizer().minimize(cost, name="optimize") 2346 meta_graph_def = saver_module.export_meta_graph(clear_devices=True) 2347 graph_io.write_graph(meta_graph_def, self.get_temp_dir(), 2348 "meta_graph.pbtxt") 2349 2350 with session.Session(graph=ops_lib.Graph()) as sess: 2351 saver_module.import_meta_graph(meta_graph_def, import_scope="new_model") 2352 self.evaluate(variables.global_variables_initializer()) 2353 sess.run(["new_model/optimize"], { 2354 "new_model/image:0": np.random.random([1, 784]), 2355 "new_model/label:0": np.random.randint( 2356 10, size=[1, 10]) 2357 }) 2358 2359 def testPreserveDatasetAndFunctions(self): 2360 with ops_lib.Graph().as_default() as g: 2361 dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x) 2362 iterator = dataset_ops.make_one_shot_iterator(dataset) 2363 next_element = iterator.get_next() 2364 _ = array_ops.identity(next_element, name="output") 2365 2366 # Generate three MetaGraphDef protos using different code paths. 2367 meta_graph_def_simple = saver_module.export_meta_graph() 2368 meta_graph_def_devices_cleared = saver_module.export_meta_graph( 2369 clear_devices=True) 2370 meta_graph_def_from_graph_def = saver_module.export_meta_graph( 2371 clear_devices=True, graph_def=g.as_graph_def()) 2372 2373 for meta_graph_def in [meta_graph_def_simple, 2374 meta_graph_def_devices_cleared, 2375 meta_graph_def_from_graph_def]: 2376 with session.Session(graph=ops_lib.Graph()) as sess: 2377 saver_module.import_meta_graph(meta_graph_def, import_scope="new_model") 2378 self.evaluate(variables.global_variables_initializer()) 2379 for i in range(10): 2380 self.assertEqual(i * i, sess.run("new_model/output:0")) 2381 with self.assertRaises(errors.OutOfRangeError): 2382 sess.run("new_model/output:0") 2383 2384 2385class CheckpointReaderTest(test.TestCase): 2386 2387 _WRITE_VERSION = saver_pb2.SaverDef.V1 2388 2389 @test_util.run_deprecated_v1 2390 def testDebugString(self): 2391 # Builds a graph. 2392 v0 = variables.VariableV1( 2393 [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") 2394 v1 = variables.VariableV1( 2395 [[[1], [2]], [[3], [4]], [[5], [6]]], dtype=dtypes.float32, name="v1") 2396 init_all_op = variables.global_variables_initializer() 2397 save = saver_module.Saver( 2398 { 2399 "v0": v0, 2400 "v1": v1 2401 }, write_version=self._WRITE_VERSION) 2402 save_path = os.path.join(self.get_temp_dir(), 2403 "ckpt_for_debug_string" + str(self._WRITE_VERSION)) 2404 with self.cached_session() as sess: 2405 self.evaluate(init_all_op) 2406 # Saves a checkpoint. 2407 save.save(sess, save_path) 2408 2409 # Creates a reader. 2410 reader = pywrap_tensorflow.NewCheckpointReader(save_path) 2411 # Verifies that the tensors exist. 2412 self.assertTrue(reader.has_tensor("v0")) 2413 self.assertTrue(reader.has_tensor("v1")) 2414 debug_string = reader.debug_string() 2415 # Verifies that debug string contains the right strings. 2416 self.assertTrue(compat.as_bytes("v0 (DT_FLOAT) [2,3]") in debug_string) 2417 self.assertTrue(compat.as_bytes("v1 (DT_FLOAT) [3,2,1]") in debug_string) 2418 # Verifies get_variable_to_shape_map() returns the correct information. 2419 var_map = reader.get_variable_to_shape_map() 2420 self.assertEqual([2, 3], var_map["v0"]) 2421 self.assertEqual([3, 2, 1], var_map["v1"]) 2422 # Verifies get_tensor() returns the tensor value. 2423 v0_tensor = reader.get_tensor("v0") 2424 v1_tensor = reader.get_tensor("v1") 2425 self.assertAllEqual(v0.eval(), v0_tensor) 2426 self.assertAllEqual(v1.eval(), v1_tensor) 2427 # Verifies get_tensor() fails for non-existent tensors. 2428 with self.assertRaisesRegexp(errors.NotFoundError, 2429 "v3 not found in checkpoint"): 2430 reader.get_tensor("v3") 2431 2432 def testNonexistentPath(self): 2433 with self.assertRaisesRegexp(errors.NotFoundError, 2434 "Unsuccessful TensorSliceReader"): 2435 pywrap_tensorflow.NewCheckpointReader("non-existent") 2436 2437 2438class CheckpointReaderForV2Test(CheckpointReaderTest): 2439 _WRITE_VERSION = saver_pb2.SaverDef.V2 2440 2441 2442class WriteGraphTest(test.TestCase): 2443 2444 def _get_test_dir(self, dirname): 2445 test_dir = os.path.join(self.get_temp_dir(), dirname) 2446 gfile.MakeDirs(test_dir) 2447 return test_dir 2448 2449 def testWriteGraph(self): 2450 test_dir = self._get_test_dir("write_graph_dir") 2451 variables.VariableV1( 2452 [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") 2453 path = graph_io.write_graph(ops_lib.get_default_graph(), 2454 os.path.join(test_dir, "l1"), "graph.pbtxt") 2455 truth = os.path.join(test_dir, "l1", "graph.pbtxt") 2456 self.assertEqual(path, truth) 2457 self.assertTrue(os.path.exists(path)) 2458 2459 def testRecursiveCreate(self): 2460 test_dir = self._get_test_dir("deep_dir") 2461 variables.VariableV1( 2462 [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0") 2463 path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(), 2464 os.path.join(test_dir, "l1", "l2", "l3"), 2465 "graph.pbtxt") 2466 truth = os.path.join(test_dir, "l1", "l2", "l3", "graph.pbtxt") 2467 self.assertEqual(path, truth) 2468 self.assertTrue(os.path.exists(path)) 2469 2470 2471class ScopedGraphTest(test.TestCase): 2472 2473 def _get_test_dir(self, dirname): 2474 test_dir = os.path.join(self.get_temp_dir(), dirname) 2475 gfile.MakeDirs(test_dir) 2476 return test_dir 2477 2478 def _testScopedSave(self, test_dir, exported_filename, ckpt_filename): 2479 graph = ops_lib.Graph() 2480 with graph.as_default(): 2481 # Creates an inference graph. 2482 # Hidden 1 2483 images = constant_op.constant( 2484 1.2, dtypes.float32, shape=[100, 28], name="images") 2485 with ops_lib.name_scope("hidden1"): 2486 weights1 = variables.VariableV1( 2487 random_ops.truncated_normal( 2488 [28, 128], stddev=1.0 / math.sqrt(float(28))), 2489 name="weights") 2490 # The use of control_flow_ops.cond here is purely for adding test 2491 # coverage the save and restore of control flow context (which doesn't 2492 # make any sense here from a machine learning perspective). The typical 2493 # biases is a simple Variable without the conditions. 2494 biases1 = variables.VariableV1( 2495 control_flow_ops.cond( 2496 math_ops.less(random.random(), 0.5), 2497 lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])), 2498 name="biases") 2499 hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1) 2500 2501 # Hidden 2 2502 with ops_lib.name_scope("hidden2"): 2503 weights2 = variables.VariableV1( 2504 random_ops.truncated_normal( 2505 [128, 32], stddev=1.0 / math.sqrt(float(128))), 2506 name="weights") 2507 2508 # The use of control_flow_ops.while_loop here is purely for adding test 2509 # coverage the save and restore of control flow context (which doesn't 2510 # make any sense here from a machine learning perspective). The typical 2511 # biases is a simple Variable without the conditions. 2512 def loop_cond(it, _): 2513 return it < 2 2514 2515 def loop_body(it, biases2): 2516 biases2 += constant_op.constant(0.1, shape=[32]) 2517 return it + 1, biases2 2518 2519 _, biases2 = control_flow_ops.while_loop(loop_cond, loop_body, [ 2520 constant_op.constant(0), variables.VariableV1(array_ops.zeros([32])) 2521 ]) 2522 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2) 2523 # Linear 2524 with ops_lib.name_scope("softmax_linear"): 2525 weights3 = variables.VariableV1( 2526 random_ops.truncated_normal( 2527 [32, 10], stddev=1.0 / math.sqrt(float(32))), 2528 name="weights") 2529 biases3 = variables.VariableV1(array_ops.zeros([10]), name="biases") 2530 logits = math_ops.matmul(hidden2, weights3) + biases3 2531 ops_lib.add_to_collection("logits", logits) 2532 2533 # Adds user_defined proto in three formats: string, bytes and Any. 2534 # Any proto should just pass through. 2535 queue_runner = queue_runner_pb2.QueueRunnerDef(queue_name="test_queue") 2536 ops_lib.add_to_collection("user_defined_string_collection", 2537 str(queue_runner)) 2538 ops_lib.add_to_collection("user_defined_bytes_collection", 2539 queue_runner.SerializeToString()) 2540 any_buf = Any() 2541 any_buf.Pack(queue_runner) 2542 ops_lib.add_to_collection("user_defined_any_collection", any_buf) 2543 2544 _, var_list = meta_graph.export_scoped_meta_graph( 2545 filename=os.path.join(test_dir, exported_filename), 2546 graph=ops_lib.get_default_graph(), 2547 export_scope="hidden1") 2548 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 2549 2550 with self.session(graph=graph) as sess: 2551 self.evaluate(variables.global_variables_initializer()) 2552 saver = saver_module.Saver(var_list=var_list, max_to_keep=1) 2553 saver.save(sess, os.path.join(test_dir, ckpt_filename), write_state=False) 2554 2555 def _testScopedRestore(self, test_dir, exported_filename, 2556 new_exported_filename, ckpt_filename): 2557 graph = ops_lib.Graph() 2558 # Create all the missing inputs. 2559 with graph.as_default(): 2560 new_image = constant_op.constant( 2561 1.2, dtypes.float32, shape=[100, 28], name="images") 2562 var_list = meta_graph.import_scoped_meta_graph( 2563 os.path.join(test_dir, exported_filename), 2564 graph=graph, 2565 input_map={"$unbound_inputs_images": new_image}, 2566 import_scope="new_hidden1") 2567 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 2568 hidden1 = graph.as_graph_element("new_hidden1/Relu:0") 2569 weights1 = graph.as_graph_element("new_hidden1/weights:0") 2570 biases1 = graph.as_graph_element("new_hidden1/biases:0") 2571 2572 with graph.as_default(): 2573 # Hidden 2 2574 with ops_lib.name_scope("hidden2"): 2575 weights = variables.VariableV1( 2576 random_ops.truncated_normal( 2577 [128, 32], stddev=1.0 / math.sqrt(float(128))), 2578 name="weights") 2579 2580 # The use of control_flow_ops.while_loop here is purely for adding test 2581 # coverage the save and restore of control flow context (which doesn't 2582 # make any sense here from a machine learning perspective). The typical 2583 # biases is a simple Variable without the conditions. 2584 def loop_cond(it, _): 2585 return it < 2 2586 2587 def loop_body(it, biases): 2588 biases += constant_op.constant(0.1, shape=[32]) 2589 return it + 1, biases 2590 2591 _, biases = control_flow_ops.while_loop(loop_cond, loop_body, [ 2592 constant_op.constant(0), variables.VariableV1(array_ops.zeros([32])) 2593 ]) 2594 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases) 2595 # Linear 2596 with ops_lib.name_scope("softmax_linear"): 2597 weights = variables.VariableV1( 2598 random_ops.truncated_normal( 2599 [32, 10], stddev=1.0 / math.sqrt(float(32))), 2600 name="weights") 2601 biases = variables.VariableV1(array_ops.zeros([10]), name="biases") 2602 logits = math_ops.matmul(hidden2, weights) + biases 2603 ops_lib.add_to_collection("logits", logits) 2604 2605 # The rest of the variables. 2606 rest_variables = list( 2607 set(variables.global_variables()) - set(var_list.keys())) 2608 init_rest_op = variables.variables_initializer(rest_variables) 2609 2610 with self.session(graph=graph) as sess: 2611 saver = saver_module.Saver(var_list=var_list, max_to_keep=1) 2612 saver.restore(sess, os.path.join(test_dir, ckpt_filename)) 2613 # Verify that we have restored weights1 and biases1. 2614 self.evaluate([weights1, biases1]) 2615 # Initialize the rest of the variables and run logits. 2616 self.evaluate(init_rest_op) 2617 self.evaluate(logits) 2618 2619 # Verifies that we can save the subgraph under "hidden1" and restore it 2620 # into "new_hidden1" in the new graph. 2621 @test_util.run_deprecated_v1 2622 def testScopedSaveAndRestore(self): 2623 test_dir = self._get_test_dir("scoped_export_import") 2624 ckpt_filename = "ckpt" 2625 self._testScopedSave(test_dir, "exported_hidden1.pbtxt", ckpt_filename) 2626 self._testScopedRestore(test_dir, "exported_hidden1.pbtxt", 2627 "exported_new_hidden1.pbtxt", ckpt_filename) 2628 2629 # Verifies that we can copy the subgraph under "hidden1" and copy it 2630 # to different name scope in the same graph or different graph. 2631 @test_util.run_deprecated_v1 2632 def testCopyScopedGraph(self): 2633 test_dir = self._get_test_dir("scoped_copy") 2634 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 2635 graph1 = ops_lib.Graph() 2636 with graph1.as_default(): 2637 with ops_lib.name_scope("hidden1"): 2638 images = constant_op.constant( 2639 1.0, dtypes.float32, shape=[3, 2], name="images") 2640 weights1 = variables.VariableV1( 2641 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights") 2642 biases1 = variables.VariableV1([0.1] * 3, name="biases") 2643 nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu") 2644 2645 # Run the graph and save scoped checkpoint. 2646 with self.session(graph=graph1) as sess: 2647 self.evaluate(variables.global_variables_initializer()) 2648 _, var_list_1 = meta_graph.export_scoped_meta_graph( 2649 export_scope="hidden1") 2650 saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1) 2651 saver.save(sess, saver0_ckpt, write_state=False) 2652 2653 expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3)) 2654 2655 # Verifies copy to the same graph with the same name fails. 2656 with graph1.as_default(): 2657 with self.assertRaisesWithPredicateMatch( 2658 ValueError, lambda e: "need to be different" in str(e)): 2659 meta_graph.copy_scoped_meta_graph( 2660 from_scope="hidden1", to_scope="hidden1") 2661 2662 # Verifies copy to the same graph. 2663 with graph1.as_default(): 2664 var_list_2 = meta_graph.copy_scoped_meta_graph( 2665 from_scope="hidden1", to_scope="hidden2") 2666 2667 with self.session(graph=graph1) as sess: 2668 saver1 = saver_module.Saver(var_list=var_list_1, max_to_keep=1) 2669 saver1.restore(sess, saver0_ckpt) 2670 saver2 = saver_module.Saver(var_list=var_list_2, max_to_keep=1) 2671 saver2.restore(sess, saver0_ckpt) 2672 self.assertAllClose(expected, sess.run("hidden1/relu:0")) 2673 self.assertAllClose(expected, sess.run("hidden2/relu:0")) 2674 2675 # Verifies copy to differen graph. 2676 graph2 = ops_lib.Graph() 2677 new_var_list_1 = meta_graph.copy_scoped_meta_graph( 2678 from_scope="hidden1", 2679 to_scope="new_hidden1", 2680 from_graph=graph1, 2681 to_graph=graph2) 2682 2683 with self.session(graph=graph2) as sess: 2684 saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1) 2685 saver3.restore(sess, saver0_ckpt) 2686 self.assertAllClose(expected, sess.run("new_hidden1/relu:0")) 2687 2688 @test_util.run_deprecated_v1 2689 def testExportGraphDefWithScope(self): 2690 test_dir = self._get_test_dir("export_graph_def") 2691 saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") 2692 graph1 = ops_lib.Graph() 2693 with graph1.as_default(): 2694 with ops_lib.name_scope("hidden1"): 2695 images = constant_op.constant( 2696 1.0, dtypes.float32, shape=[3, 2], name="images") 2697 weights1 = variables.VariableV1( 2698 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights") 2699 biases1 = variables.VariableV1([0.1] * 3, name="biases") 2700 nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu") 2701 2702 # Run the graph and save scoped checkpoint. 2703 with self.session(graph=graph1) as sess: 2704 self.evaluate(variables.global_variables_initializer()) 2705 _, var_list_1 = meta_graph.export_scoped_meta_graph( 2706 graph_def=graph1.as_graph_def(), export_scope="hidden1") 2707 saver = saver_module.Saver(var_list=var_list_1, max_to_keep=1) 2708 saver.save(sess, saver0_ckpt, write_state=False) 2709 2710 expected = np.reshape([[5.0999999, 7.0999999, 9.10000038] * 3], (3, 3)) 2711 2712 # Verifies that we can run successfully after restoring. 2713 graph2 = ops_lib.Graph() 2714 new_var_list_1 = meta_graph.copy_scoped_meta_graph( 2715 from_scope="hidden1", 2716 to_scope="new_hidden1", 2717 from_graph=graph1, 2718 to_graph=graph2) 2719 2720 with self.session(graph=graph2) as sess: 2721 saver3 = saver_module.Saver(var_list=new_var_list_1, max_to_keep=1) 2722 saver3.restore(sess, saver0_ckpt) 2723 self.assertAllClose(expected, sess.run("new_hidden1/relu:0")) 2724 2725 @test_util.run_deprecated_v1 2726 def testSerializeSaverWithScope(self): 2727 test_dir = self._get_test_dir("export_graph_def") 2728 saver1_ckpt = os.path.join(test_dir, "saver1.ckpt") 2729 saver2_ckpt = os.path.join(test_dir, "saver2.ckpt") 2730 graph = ops_lib.Graph() 2731 with graph.as_default(): 2732 with ops_lib.name_scope("hidden1"): 2733 variable1 = variables.VariableV1([1.0], name="variable1") 2734 saver1 = saver_module.Saver(var_list=[variable1]) 2735 graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1) 2736 2737 with ops_lib.name_scope("hidden2"): 2738 variable2 = variables.VariableV1([2.0], name="variable2") 2739 saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/") 2740 graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2) 2741 2742 with self.session(graph=graph) as sess: 2743 self.evaluate(variables.global_variables_initializer()) 2744 saver1.save(sess, saver1_ckpt, write_state=False) 2745 saver2.save(sess, saver2_ckpt, write_state=False) 2746 2747 graph1 = ops_lib.Graph() 2748 var_dict1 = meta_graph.copy_scoped_meta_graph( 2749 from_scope="hidden1", 2750 to_scope="new_hidden1", 2751 from_graph=graph, 2752 to_graph=graph1) 2753 self.assertEqual(1, len(var_dict1)) 2754 2755 saver_list1 = graph1.get_collection(ops_lib.GraphKeys.SAVERS) 2756 self.assertEqual(1, len(saver_list1)) 2757 2758 with self.session(graph=graph1) as sess: 2759 saver_list1[0].restore(sess, saver1_ckpt) 2760 self.assertEqual(1.0, self.evaluate(var_dict1["variable1:0"])) 2761 2762 graph2 = ops_lib.Graph() 2763 var_dict2 = meta_graph.copy_scoped_meta_graph( 2764 from_scope="hidden2", 2765 to_scope="new_hidden2", 2766 from_graph=graph, 2767 to_graph=graph2) 2768 self.assertEqual(1, len(var_dict2)) 2769 2770 saver_list2 = graph2.get_collection(ops_lib.GraphKeys.SAVERS) 2771 self.assertEqual(1, len(saver_list2)) 2772 2773 with self.session(graph=graph2) as sess: 2774 saver_list2[0].restore(sess, saver2_ckpt) 2775 self.assertEqual(2.0, self.evaluate(var_dict2["variable2:0"])) 2776 2777 2778class _OwnsAVariableSimple(trackable_base.Trackable): 2779 """A Trackable object which can be saved using a tf.train.Saver.""" 2780 2781 def __init__(self): 2782 self.non_dep_variable = variable_scope.get_variable( 2783 name="non_dep_variable", initializer=6., use_resource=True) 2784 2785 def _gather_saveables_for_checkpoint(self): 2786 return {trackable_base.VARIABLE_VALUE_KEY: self.non_dep_variable} 2787 2788 # The Saver sorts by name before parsing, so we need a name property. 2789 @property 2790 def name(self): 2791 return self.non_dep_variable.name 2792 2793 2794class _MirroringSaveable( 2795 saver_module.BaseSaverBuilder.ResourceVariableSaveable): 2796 2797 def __init__(self, primary_variable, mirrored_variable, name): 2798 self._primary_variable = primary_variable 2799 self._mirrored_variable = mirrored_variable 2800 super(_MirroringSaveable, self).__init__( 2801 self._primary_variable, "", name) 2802 2803 def restore(self, restored_tensors, restored_shapes): 2804 """Restore the same value into both variables.""" 2805 tensor, = restored_tensors 2806 return control_flow_ops.group( 2807 self._primary_variable.assign(tensor), 2808 self._mirrored_variable.assign(tensor)) 2809 2810 2811class _OwnsMirroredVariables(trackable_base.Trackable): 2812 """A Trackable object which returns a more complex SaveableObject.""" 2813 2814 def __init__(self): 2815 self.non_dep_variable = variable_scope.get_variable( 2816 name="non_dep_variable", initializer=6., use_resource=True) 2817 self.mirrored = variable_scope.get_variable( 2818 name="mirrored", initializer=15., use_resource=True) 2819 2820 def _gather_saveables_for_checkpoint(self): 2821 def _saveable_factory(name=self.non_dep_variable.name): 2822 return _MirroringSaveable( 2823 primary_variable=self.non_dep_variable, 2824 mirrored_variable=self.mirrored, 2825 name=name) 2826 return {trackable_base.VARIABLE_VALUE_KEY: _saveable_factory} 2827 2828 # The Saver sorts by name before parsing, so we need a name property. 2829 @property 2830 def name(self): 2831 return self.non_dep_variable.name 2832 2833 2834class NonLayerTrackable(trackable_tracking.AutoTrackable): 2835 2836 def __init__(self): 2837 super(NonLayerTrackable, self).__init__() 2838 self.a_variable = trackable_utils.add_variable( 2839 self, name="a_variable", shape=[]) 2840 2841 2842class MyModel(training.Model): 2843 """A concrete Model for testing.""" 2844 2845 def __init__(self): 2846 super(MyModel, self).__init__() 2847 self._named_dense = core.Dense(1, use_bias=True) 2848 self._second = core.Dense(1, use_bias=False) 2849 # We can still track Trackables which aren't Layers. 2850 self._non_layer = NonLayerTrackable() 2851 2852 def call(self, values): 2853 ret = self._second(self._named_dense(values)) 2854 return ret 2855 2856 2857class TrackableCompatibilityTests(test.TestCase): 2858 2859 # TODO(allenl): Track down python3 reference cycles in these tests. 2860 @test_util.run_in_graph_and_eager_modes 2861 def testNotSaveableButIsTrackable(self): 2862 v = _OwnsAVariableSimple() 2863 test_dir = self.get_temp_dir() 2864 prefix = os.path.join(test_dir, "ckpt") 2865 for saver in (saver_module.Saver(var_list=[v]), 2866 saver_module.Saver(var_list={"v": v})): 2867 with self.cached_session() as sess: 2868 self.evaluate(v.non_dep_variable.assign(42.)) 2869 save_path = saver.save(sess, prefix) 2870 self.evaluate(v.non_dep_variable.assign(43.)) 2871 saver.restore(sess, save_path) 2872 self.assertEqual(42., self.evaluate(v.non_dep_variable)) 2873 2874 @test_util.run_in_graph_and_eager_modes 2875 def testMoreComplexSaveableReturned(self): 2876 v = _OwnsMirroredVariables() 2877 test_dir = self.get_temp_dir() 2878 prefix = os.path.join(test_dir, "ckpt") 2879 self.evaluate(v.non_dep_variable.assign(42.)) 2880 for saver in (saver_module.Saver(var_list=[v]), 2881 saver_module.Saver(var_list={"v": v})): 2882 with self.cached_session() as sess: 2883 save_path = saver.save(sess, prefix) 2884 self.evaluate(v.non_dep_variable.assign(43.)) 2885 self.evaluate(v.mirrored.assign(44.)) 2886 saver.restore(sess, save_path) 2887 self.assertEqual(42., self.evaluate(v.non_dep_variable)) 2888 self.assertEqual(42., self.evaluate(v.mirrored)) 2889 2890 def testSingleTensorEvaluation(self): 2891 2892 class _CountingSaveable(saver_module.BaseSaverBuilder.SaveableObject): 2893 2894 def __init__(self, name): 2895 self.eval_count = 0 2896 def _tensor(): 2897 self.eval_count += 1 2898 return constant_op.constant([1.]) 2899 dummy_op = constant_op.constant([2.]) 2900 super(_CountingSaveable, self).__init__( 2901 dummy_op, 2902 [saver_module.BaseSaverBuilder.SaveSpec( 2903 _tensor, "", name, dtype=dummy_op.dtype)], 2904 name) 2905 2906 def restore(self, restored_tensors, restored_shapes): 2907 """Restore the same value into both variables.""" 2908 pass 2909 2910 with context.eager_mode(): 2911 v = _CountingSaveable("foo") 2912 saver = saver_module.Saver(var_list=[v]) 2913 test_dir = self.get_temp_dir() 2914 prefix = os.path.join(test_dir, "ckpt") 2915 with self.cached_session() as sess: 2916 save_path = saver.save(sess, prefix) 2917 self.assertEqual(1, v.eval_count) 2918 saver.restore(sess, save_path) 2919 self.assertEqual(1, v.eval_count) 2920 2921 def _initialized_model(self): 2922 input_value = constant_op.constant([[3.]]) 2923 model = MyModel() 2924 optimizer = adam.AdamOptimizer(0.001) 2925 optimizer_step = training_util.get_or_create_global_step() 2926 root_trackable = trackable_utils.Checkpoint( 2927 optimizer=optimizer, model=model, optimizer_step=optimizer_step) 2928 train_op = optimizer.minimize( 2929 functools.partial(model, input_value), 2930 global_step=optimizer_step) 2931 self.evaluate(trackable_utils.gather_initializers( 2932 root_trackable)) 2933 self.evaluate(train_op) 2934 # A regular variable, a slot variable, and a non-slot Optimizer variable 2935 # with known values to check when loading. 2936 self.evaluate(model._named_dense.bias.assign([1.])) 2937 self.evaluate(optimizer.get_slot( 2938 var=model._named_dense.bias, name="m").assign([2.])) 2939 beta1_power, _ = optimizer._get_beta_accumulators() 2940 self.evaluate(beta1_power.assign(3.)) 2941 return root_trackable 2942 2943 def _set_sentinels(self, root_trackable): 2944 self.evaluate(root_trackable.model._named_dense.bias.assign([101.])) 2945 self.evaluate( 2946 root_trackable.optimizer.get_slot( 2947 var=root_trackable.model._named_dense.bias, name="m") 2948 .assign([102.])) 2949 beta1_power, _ = root_trackable.optimizer._get_beta_accumulators() 2950 self.evaluate(beta1_power.assign(103.)) 2951 2952 def _check_sentinels(self, root_trackable): 2953 self.assertAllEqual( 2954 [1.], self.evaluate(root_trackable.model._named_dense.bias)) 2955 self.assertAllEqual([2.], self.evaluate( 2956 root_trackable.optimizer.get_slot( 2957 var=root_trackable.model._named_dense.bias, name="m"))) 2958 beta1_power, _ = root_trackable.optimizer._get_beta_accumulators() 2959 self.assertAllEqual(3., self.evaluate(beta1_power)) 2960 2961 def testVariableNotFoundErrorRaised(self): 2962 # Restore does some tricky exception handling to figure out if it should 2963 # load an object-based checkpoint. Tests that the exception handling isn't 2964 # too broad. 2965 checkpoint_directory = self.get_temp_dir() 2966 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 2967 2968 a = resource_variable_ops.ResourceVariable(1., name="a") 2969 b = resource_variable_ops.ResourceVariable(1., name="b") 2970 a_saver = saver_module.Saver([a]) 2971 b_saver = saver_module.Saver([b]) 2972 with self.cached_session() as sess: 2973 self.evaluate(a.initializer) 2974 save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix) 2975 with self.assertRaisesRegexp( 2976 errors.NotFoundError, "Key b not found in checkpoint"): 2977 b_saver.restore(sess=sess, save_path=save_path) 2978 2979 with self.assertRaises(errors.NotFoundError) as cs: 2980 b_saver.restore(sess=sess, save_path=save_path) 2981 2982 # Make sure we don't have a confusing "During handling of the above 2983 # exception" block in Python 3. 2984 self.assertNotIn("NewCheckpointReader", cs.exception.message) 2985 2986 @test_util.run_v1_only("b/120545219") 2987 def testGraphChangedForRestoreErrorRaised(self): 2988 checkpoint_directory = self.get_temp_dir() 2989 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 2990 2991 with ops_lib.Graph().as_default() as g: 2992 a = variables.VariableV1(1., name="a") 2993 a_saver = saver_module.Saver([a]) 2994 2995 with self.session(graph=g) as sess: 2996 self.evaluate(a.initializer) 2997 save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix) 2998 2999 with ops_lib.Graph().as_default() as g: 3000 a = variables.VariableV1([1.], name="a") 3001 a_saver = saver_module.Saver([a]) 3002 with self.session(graph=g) as sess: 3003 with self.assertRaisesRegexp( 3004 errors.InvalidArgumentError, 3005 "a mismatch between the current graph and the graph"): 3006 a_saver.restore(sess=sess, save_path=save_path) 3007 3008 def testLoadFromObjectBasedGraph(self): 3009 checkpoint_directory = self.get_temp_dir() 3010 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 3011 3012 save_graph = ops_lib.Graph() 3013 with save_graph.as_default(), self.session(graph=save_graph) as sess: 3014 root = self._initialized_model() 3015 object_saver = trackable_utils.Checkpoint(root=root) 3016 save_path = object_saver.save(file_prefix=checkpoint_prefix) 3017 3018 # An incompatible object-based checkpoint to check error messages 3019 var = resource_variable_ops.ResourceVariable(1., name="a") 3020 self.evaluate(var.initializer) 3021 second_saver = trackable_utils.Checkpoint(v=var) 3022 second_path = second_saver.save(file_prefix=os.path.join( 3023 checkpoint_directory, "second")) 3024 3025 restore_graph = ops_lib.Graph() 3026 with restore_graph.as_default(), self.session( 3027 graph=restore_graph) as sess: 3028 root = self._initialized_model() 3029 self._set_sentinels(root) 3030 saver = saver_module.Saver() 3031 saver.restore(sess=sess, save_path=save_path) 3032 self._check_sentinels(root) 3033 before_second_restore_ops = restore_graph.get_operations() 3034 # Test that multiple restores do not pollute the graph 3035 saver.restore(sess=sess, save_path=save_path) 3036 self.assertEqual(before_second_restore_ops, 3037 restore_graph.get_operations()) 3038 with self.assertRaisesRegexp(errors.NotFoundError, 3039 "Could not find some variables"): 3040 saver.restore(sess=sess, save_path=second_path) 3041 3042 def testLoadFromObjectBasedEager(self): 3043 checkpoint_directory = self.get_temp_dir() 3044 checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 3045 3046 save_graph = ops_lib.Graph() 3047 with save_graph.as_default(), self.session(graph=save_graph): 3048 root = self._initialized_model() 3049 object_saver = trackable_utils.Checkpoint(root=root) 3050 save_path = object_saver.save(file_prefix=checkpoint_prefix) 3051 3052 with context.eager_mode(): 3053 root = self._initialized_model() 3054 self._set_sentinels(root) 3055 saver = saver_module.Saver( 3056 root.model.variables + root.optimizer.variables()) 3057 saver.restore(sess=None, save_path=save_path) 3058 self._check_sentinels(root) 3059 3060 3061if __name__ == "__main__": 3062 test.main() 3063