1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Monitors tests.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import shutil 23import tempfile 24import time 25 26from six.moves import xrange # pylint: disable=redefined-builtin 27 28from tensorflow.contrib import testing 29from tensorflow.contrib.framework.python.framework import checkpoint_utils 30from tensorflow.contrib.learn.python import learn 31from tensorflow.contrib.learn.python.learn import estimators 32from tensorflow.python.client import session as session_lib 33from tensorflow.python.estimator import estimator as core_estimator 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import ops 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops import state_ops 38from tensorflow.python.ops import variables 39from tensorflow.python.platform import test 40from tensorflow.python.platform import tf_logging as logging 41from tensorflow.python.summary import summary 42from tensorflow.python.training import checkpoint_management 43from tensorflow.python.training import gradient_descent 44from tensorflow.python.training import monitored_session 45from tensorflow.python.training import training_util 46 47 48class _MyEveryN(learn.monitors.EveryN): 49 50 def __init__(self, every_n_steps=100, first_n_steps=1): 51 super(_MyEveryN, self).__init__( 52 every_n_steps=every_n_steps, first_n_steps=first_n_steps) 53 self._steps_begun = [] 54 self._steps_ended = [] 55 self._post_steps = [] 56 57 @property 58 def steps_begun(self): 59 return self._steps_begun 60 61 @property 62 def steps_ended(self): 63 return self._steps_ended 64 65 @property 66 def post_steps(self): 67 return self._post_steps 68 69 def every_n_step_begin(self, step): 70 super(_MyEveryN, self).every_n_step_begin(step) 71 self._steps_begun.append(step) 72 return [] 73 74 def every_n_step_end(self, step, outputs): 75 super(_MyEveryN, self).every_n_step_end(step, outputs) 76 self._steps_ended.append(step) 77 return False 78 79 def every_n_post_step(self, step, session): 80 super(_MyEveryN, self).every_n_post_step(step, session) 81 self._post_steps.append(step) 82 return False 83 84 85class MonitorsTest(test.TestCase): 86 """Monitors tests.""" 87 88 def setUp(self): 89 # Mock out logging calls so we can verify whether correct tensors are being 90 # monitored. 91 self._actual_log = logging.info 92 93 def mockLog(*args, **kwargs): # pylint: disable=invalid-name 94 self.logged_message = args 95 self._actual_log(*args, **kwargs) 96 97 logging.info = mockLog 98 99 def tearDown(self): 100 logging.info = self._actual_log 101 102 def _run_monitor(self, 103 monitor, 104 num_epochs=3, 105 num_steps_per_epoch=10, 106 pass_max_steps=True): 107 if pass_max_steps: 108 max_steps = num_epochs * num_steps_per_epoch - 1 109 else: 110 max_steps = None 111 monitor.begin(max_steps=max_steps) 112 for epoch in xrange(num_epochs): 113 monitor.epoch_begin(epoch) 114 should_stop = False 115 step = epoch * num_steps_per_epoch 116 next_epoch_step = step + num_steps_per_epoch 117 while (not should_stop) and (step < next_epoch_step): 118 tensors = monitor.step_begin(step) 119 output = ops.get_default_session().run(tensors) if tensors else {} 120 output = dict( 121 zip([t.name if isinstance(t, ops.Tensor) else t for t in tensors], 122 output)) 123 should_stop = monitor.step_end(step=step, output=output) 124 monitor.post_step(step=step, session=None) 125 step += 1 126 monitor.epoch_end(epoch) 127 monitor.end() 128 129 def test_base_monitor(self): 130 with ops.Graph().as_default() as g, self.session(g): 131 self._run_monitor(learn.monitors.BaseMonitor()) 132 133 def test_every_0(self): 134 monitor = _MyEveryN(every_n_steps=0, first_n_steps=-1) 135 with ops.Graph().as_default() as g, self.session(g): 136 self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10) 137 expected_steps = list(range(30)) 138 self.assertAllEqual(expected_steps, monitor.steps_begun) 139 self.assertAllEqual(expected_steps, monitor.steps_ended) 140 self.assertAllEqual(expected_steps, monitor.post_steps) 141 142 def test_every_1(self): 143 monitor = _MyEveryN(every_n_steps=1, first_n_steps=-1) 144 with ops.Graph().as_default() as g, self.session(g): 145 self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10) 146 expected_steps = list(range(1, 30)) 147 self.assertEqual(expected_steps, monitor.steps_begun) 148 self.assertEqual(expected_steps, monitor.steps_ended) 149 self.assertEqual(expected_steps, monitor.post_steps) 150 151 def test_every_2(self): 152 monitor = _MyEveryN(every_n_steps=2, first_n_steps=-1) 153 with ops.Graph().as_default() as g, self.session(g): 154 self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10) 155 expected_steps = list(range(2, 29, 2)) + [29] 156 self.assertEqual(expected_steps, monitor.steps_begun) 157 self.assertEqual(expected_steps, monitor.steps_ended) 158 self.assertEqual(expected_steps, monitor.post_steps) 159 160 def test_every_8(self): 161 monitor = _MyEveryN(every_n_steps=8, first_n_steps=2) 162 with ops.Graph().as_default() as g, self.session(g): 163 self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10) 164 expected_steps = [0, 1, 2, 10, 18, 26, 29] 165 self.assertEqual(expected_steps, monitor.steps_begun) 166 self.assertEqual(expected_steps, monitor.steps_ended) 167 self.assertEqual(expected_steps, monitor.post_steps) 168 169 def test_every_8_no_max_steps(self): 170 monitor = _MyEveryN(every_n_steps=8, first_n_steps=2) 171 with ops.Graph().as_default() as g, self.session(g): 172 self._run_monitor( 173 monitor, num_epochs=3, num_steps_per_epoch=10, pass_max_steps=False) 174 begin_end_steps = [0, 1, 2, 10, 18, 26] 175 post_steps = [0, 1, 2, 10, 18, 26, 29] 176 self.assertEqual(begin_end_steps, monitor.steps_begun) 177 self.assertEqual(begin_end_steps, monitor.steps_ended) 178 self.assertEqual(post_steps, monitor.post_steps) 179 180 def test_every_8_recovered_after_step_begin(self): 181 monitor = _MyEveryN(every_n_steps=8) 182 with ops.Graph().as_default() as g, self.session(g): 183 for step in [8, 16]: 184 monitor.step_begin(step) 185 monitor.step_begin(step) 186 monitor.step_end(step, output=None) 187 monitor.post_step(step, session=None) 188 # It should call begin again since, end was not called 189 self.assertEqual([8, 8, 16, 16], monitor.steps_begun) 190 self.assertEqual([8, 16], monitor.steps_ended) 191 self.assertEqual([8, 16], monitor.post_steps) 192 193 def test_every_8_recovered_after_step_end(self): 194 monitor = _MyEveryN(every_n_steps=8) 195 with ops.Graph().as_default() as g, self.session(g): 196 for step in [8, 16]: 197 monitor.step_begin(step) 198 monitor.step_end(step, output=None) 199 monitor.post_step(step, session=None) 200 monitor.step_begin(step) 201 monitor.step_end(step, output=None) 202 monitor.post_step(step, session=None) 203 # It should not call begin twice since end was called 204 self.assertEqual([8, 16], monitor.steps_begun) 205 self.assertEqual([8, 16], monitor.steps_ended) 206 self.assertEqual([8, 16], monitor.post_steps) 207 208 def test_every_8_call_post_step_at_the_end(self): 209 monitor = _MyEveryN(every_n_steps=8) 210 with ops.Graph().as_default() as g, self.session(g): 211 monitor.begin() 212 for step in [8, 16]: 213 monitor.step_begin(step) 214 monitor.step_end(step, output=None) 215 monitor.post_step(step, session=None) 216 monitor.step_begin(19) 217 monitor.step_end(19, output=None) 218 monitor.post_step(19, session=None) 219 monitor.end(session=None) 220 # It should not call begin twice since end was called 221 self.assertEqual([8, 16], monitor.steps_begun) 222 self.assertEqual([8, 16], monitor.steps_ended) 223 self.assertEqual([8, 16, 19], monitor.post_steps) 224 225 def test_every_8_call_post_step_should_not_be_called_twice(self): 226 monitor = _MyEveryN(every_n_steps=8) 227 with ops.Graph().as_default() as g, self.session(g): 228 monitor.begin() 229 for step in [8, 16]: 230 monitor.step_begin(step) 231 monitor.step_end(step, output=None) 232 monitor.post_step(step, session=None) 233 monitor.step_begin(16) 234 monitor.step_end(16, output=None) 235 monitor.post_step(16, session=None) 236 monitor.end(session=None) 237 # It should not call begin twice since end was called 238 self.assertEqual([8, 16], monitor.steps_begun) 239 self.assertEqual([8, 16], monitor.steps_ended) 240 self.assertEqual([8, 16], monitor.post_steps) 241 242 def test_print(self): 243 with ops.Graph().as_default() as g, self.session(g): 244 t = constant_op.constant(42.0, name='foo') 245 self._run_monitor(learn.monitors.PrintTensor(tensor_names=[t.name])) 246 self.assertRegexpMatches(str(self.logged_message), t.name) 247 248 def test_logging_trainable(self): 249 with ops.Graph().as_default() as g, self.session(g): 250 var = variables.VariableV1(constant_op.constant(42.0), name='foo') 251 var.initializer.run() 252 cof = constant_op.constant(1.0) 253 loss = math_ops.subtract( 254 math_ops.multiply(var, cof), constant_op.constant(1.0)) 255 train_step = gradient_descent.GradientDescentOptimizer(0.5).minimize(loss) 256 ops.get_default_session().run(train_step) 257 self._run_monitor(learn.monitors.LoggingTrainable('foo')) 258 self.assertRegexpMatches(str(self.logged_message), var.name) 259 260 def test_summary_saver(self): 261 with ops.Graph().as_default() as g, self.session(g): 262 log_dir = 'log/dir' 263 summary_writer = testing.FakeSummaryWriter(log_dir, g) 264 var = variables.VariableV1(0.0) 265 var.initializer.run() 266 tensor = state_ops.assign_add(var, 1.0) 267 summary_op = summary.scalar('my_summary', tensor) 268 self._run_monitor( 269 learn.monitors.SummarySaver( 270 summary_op=summary_op, 271 save_steps=8, 272 summary_writer=summary_writer), 273 num_epochs=3, 274 num_steps_per_epoch=10) 275 summary_writer.assert_summaries( 276 test_case=self, 277 expected_logdir=log_dir, 278 expected_graph=g, 279 expected_summaries={ 280 0: { 281 'my_summary': 1.0 282 }, 283 1: { 284 'my_summary': 2.0 285 }, 286 9: { 287 'my_summary': 3.0 288 }, 289 17: { 290 'my_summary': 4.0 291 }, 292 25: { 293 'my_summary': 5.0 294 }, 295 29: { 296 'my_summary': 6.0 297 }, 298 }) 299 300 def _assert_validation_monitor(self, 301 monitor, 302 expected_early_stopped=False, 303 expected_best_step=None, 304 expected_best_value=None, 305 expected_best_metrics=None): 306 self.assertEqual(expected_early_stopped, monitor.early_stopped) 307 self.assertEqual(expected_best_step, monitor.best_step) 308 self.assertEqual(expected_best_value, monitor.best_value) 309 self.assertEqual(expected_best_metrics, monitor.best_metrics) 310 311 def test_validation_monitor_no_estimator(self): 312 monitor = learn.monitors.ValidationMonitor( 313 x=constant_op.constant(2.0), every_n_steps=0) 314 self._assert_validation_monitor(monitor) 315 with ops.Graph().as_default() as g, self.session(g): 316 with self.assertRaisesRegexp(ValueError, 'set_estimator'): 317 self._run_monitor(monitor) 318 319 @test.mock.patch.object(estimators, 'Estimator', autospec=True) 320 @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') 321 def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint, 322 mock_estimator_class): 323 estimator = mock_estimator_class() 324 model_dir = 'model/dir' 325 estimator.model_dir = model_dir 326 mock_latest_checkpoint.return_value = None 327 328 # Do nothing with no checkpoint. 329 monitor = learn.monitors.ValidationMonitor( 330 x=constant_op.constant(2.0), every_n_steps=0) 331 self._assert_validation_monitor(monitor) 332 monitor.set_estimator(estimator) 333 with ops.Graph().as_default() as g, self.session(g): 334 self._run_monitor(monitor) 335 self._assert_validation_monitor(monitor) 336 mock_latest_checkpoint.assert_called_with(model_dir) 337 338 @test.mock.patch.object(estimators, 'Estimator', autospec=True) 339 @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') 340 def test_validation_monitor_no_early_stopping_rounds(self, 341 mock_latest_checkpoint, 342 mock_estimator_class): 343 estimator = mock_estimator_class() 344 model_dir = 'model/dir' 345 estimator.model_dir = model_dir 346 estimator.evaluate.return_value = {} 347 mock_latest_checkpoint.return_value = '%s/ckpt' % model_dir 348 349 # Do nothing with early_stopping_rounds=None. 350 monitor = learn.monitors.ValidationMonitor( 351 x=constant_op.constant(2.0), every_n_steps=0) 352 self._assert_validation_monitor(monitor) 353 monitor.set_estimator(estimator) 354 with ops.Graph().as_default() as g, self.session(g): 355 self._run_monitor(monitor) 356 self._assert_validation_monitor(monitor) 357 358 @test.mock.patch.object(estimators, 'Estimator', autospec=True) 359 @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') 360 def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint, 361 mock_estimator_class): 362 estimator = mock_estimator_class() 363 model_dir = 'model/dir' 364 estimator.model_dir = model_dir 365 estimator.evaluate.return_value = {} 366 mock_latest_checkpoint.return_value = '%s/ckpt' % model_dir 367 368 # Fail for missing metric. 369 monitor = learn.monitors.ValidationMonitor( 370 x=constant_op.constant(2.0), every_n_steps=0, early_stopping_rounds=1) 371 self._assert_validation_monitor(monitor) 372 monitor.set_estimator(estimator) 373 with ops.Graph().as_default() as g, self.session(g): 374 with self.assertRaisesRegexp(ValueError, 'missing from outputs'): 375 self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1) 376 377 @test.mock.patch.object(estimators, 'Estimator', autospec=True) 378 @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') 379 def test_validation_monitor(self, mock_latest_checkpoint, 380 mock_estimator_class): 381 estimator = mock_estimator_class() 382 model_dir = 'model/dir' 383 estimator.model_dir = model_dir 384 validation_outputs = {'loss': None, 'auc': None} 385 estimator.evaluate.return_value = validation_outputs 386 387 monitor = learn.monitors.ValidationMonitor( 388 x=constant_op.constant(2.0), 389 every_n_steps=0, 390 early_stopping_rounds=2, 391 check_interval_secs=None) 392 393 self._assert_validation_monitor(monitor) 394 monitor.set_estimator(estimator) 395 with ops.Graph().as_default() as g, self.session(g): 396 monitor.begin(max_steps=100) 397 monitor.epoch_begin(epoch=0) 398 self.assertEqual(0, estimator.evaluate.call_count) 399 400 # Step 0, initial loss. 401 step = 0 402 mock_latest_checkpoint.return_value = '%s/ckpt.%s' % (model_dir, step) 403 validation_outputs['loss'] = 42.0 404 validation_outputs['auc'] = 0.5 405 self.assertEqual(0, len(monitor.step_begin(step=step))) 406 self.assertFalse(monitor.step_end(step=step, output={})) 407 self.assertEqual(1, estimator.evaluate.call_count) 408 self._assert_validation_monitor( 409 monitor, expected_best_step=0, expected_best_value=42.0, 410 expected_best_metrics={'loss': 42.0, 'auc': 0.5}) 411 monitor.post_step(step=step, session=None) 412 413 # Step 1, same checkpoint, no eval. 414 step = 1 415 self.assertEqual(0, len(monitor.step_begin(step=step))) 416 self.assertFalse(monitor.step_end(step=step, output={})) 417 self.assertEqual(1, estimator.evaluate.call_count) 418 self._assert_validation_monitor( 419 monitor, expected_best_step=0, expected_best_value=42.0, 420 expected_best_metrics={'loss': 42.0, 'auc': 0.5}) 421 monitor.post_step(step=step, session=None) 422 423 # Step 2, lower loss. 424 step = 2 425 mock_latest_checkpoint.return_value = '%s/ckpt.%s' % (model_dir, step) 426 validation_outputs['loss'] = 40.0 427 validation_outputs['auc'] = 0.6 428 self.assertEqual(0, len(monitor.step_begin(step=step))) 429 self.assertFalse(monitor.step_end(step=step, output={})) 430 self.assertEqual(2, estimator.evaluate.call_count) 431 self._assert_validation_monitor( 432 monitor, expected_best_step=2, expected_best_value=40.0, 433 expected_best_metrics={'loss': 40.0, 'auc': 0.6}) 434 monitor.post_step(step=step, session=None) 435 436 # Step 3, higher loss. 437 step = 3 438 mock_latest_checkpoint.return_value = '%s/ckpt.%s' % (model_dir, step) 439 validation_outputs['loss'] = 44.0 440 validation_outputs['auc'] = 0.7 441 self.assertEqual(0, len(monitor.step_begin(step=step))) 442 self.assertFalse(monitor.step_end(step=step, output={})) 443 self.assertEqual(3, estimator.evaluate.call_count) 444 self._assert_validation_monitor( 445 monitor, expected_best_step=2, expected_best_value=40.0, 446 expected_best_metrics={'loss': 40.0, 'auc': 0.6}) 447 monitor.post_step(step=step, session=None) 448 449 # Step 4, higher loss for 2 steps, early stopping. 450 step = 4 451 mock_latest_checkpoint.return_value = '%s/ckpt.%s' % (model_dir, step) 452 validation_outputs['loss'] = 43.0 453 self.assertEqual(0, len(monitor.step_begin(step=step))) 454 self.assertTrue(monitor.step_end(step=step, output={})) 455 self.assertEqual(4, estimator.evaluate.call_count) 456 self._assert_validation_monitor( 457 monitor, 458 expected_early_stopped=True, 459 expected_best_step=2, 460 expected_best_value=40.0, 461 expected_best_metrics={'loss': 40.0, 'auc': 0.6}) 462 monitor.post_step(step=step, session=None) 463 464 monitor.epoch_end(epoch=0) 465 monitor.end() 466 467 @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') 468 def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint): 469 estimator = test.mock.Mock(spec=core_estimator.Estimator) 470 model_dir = 'model/dir' 471 estimator.model_dir = model_dir 472 validation_outputs = {'loss': None, 'auc': None} 473 estimator.evaluate.return_value = validation_outputs 474 475 monitor = learn.monitors.ValidationMonitor( 476 input_fn=lambda: constant_op.constant(2.0), 477 every_n_steps=0, early_stopping_rounds=2) 478 self._assert_validation_monitor(monitor) 479 monitor.set_estimator(estimator) 480 with ops.Graph().as_default() as g, self.session(g): 481 monitor.begin(max_steps=100) 482 monitor.epoch_begin(epoch=0) 483 self.assertEqual(0, estimator.evaluate.call_count) 484 485 # Step 0, initial loss. 486 step = 0 487 mock_latest_checkpoint.return_value = '%s/ckpt.%s' % (model_dir, step) 488 validation_outputs['loss'] = 42.0 489 validation_outputs['auc'] = 0.5 490 self.assertEqual(0, len(monitor.step_begin(step=step))) 491 self.assertFalse(monitor.step_end(step=step, output={})) 492 self.assertEqual(1, estimator.evaluate.call_count) 493 self._assert_validation_monitor( 494 monitor, expected_best_step=0, expected_best_value=42.0, 495 expected_best_metrics={'loss': 42.0, 'auc': 0.5}) 496 monitor.post_step(step=step, session=None) 497 498 @test.mock.patch.object(checkpoint_management, 'latest_checkpoint') 499 def test_validation_monitor_fail_with_core_estimator_and_metrics( 500 self, mock_latest_checkpoint): 501 estimator = test.mock.Mock(spec=core_estimator.Estimator) 502 model_dir = 'model/dir' 503 estimator.model_dir = model_dir 504 validation_outputs = {'loss': None} 505 estimator.evaluate.return_value = validation_outputs 506 507 monitor = learn.monitors.ValidationMonitor( 508 input_fn=lambda: constant_op.constant(2.0), 509 metrics=constant_op.constant(2.0), 510 every_n_steps=0, early_stopping_rounds=2) 511 monitor.set_estimator(estimator) 512 with ops.Graph().as_default() as g, self.session(g): 513 monitor.begin(max_steps=100) 514 monitor.epoch_begin(epoch=0) 515 516 with self.assertRaisesRegexp( 517 ValueError, 518 'tf.estimator.Estimator does not support .* metrics'): 519 step = 0 520 mock_latest_checkpoint.return_value = '%s/ckpt.%s' % (model_dir, step) 521 validation_outputs['loss'] = 42.0 522 self.assertEqual(0, len(monitor.step_begin(step=step))) 523 self.assertFalse(monitor.step_end(step=step, output={})) 524 525 def test_graph_dump(self): 526 monitor0 = learn.monitors.GraphDump() 527 monitor1 = learn.monitors.GraphDump() 528 with ops.Graph().as_default() as g, self.session(g): 529 const_var = variables.VariableV1(42.0, name='my_const') 530 counter_var = variables.VariableV1(0.0, name='my_counter') 531 assign_add = state_ops.assign_add(counter_var, 1.0, name='my_assign_add') 532 variables.global_variables_initializer().run() 533 534 self._run_monitor(monitor0, num_epochs=3, num_steps_per_epoch=10) 535 self.assertEqual({ 536 step: { 537 const_var.name: 42.0, 538 counter_var.name: step + 1.0, 539 assign_add.name: step + 1.0, 540 } 541 for step in xrange(30) 542 }, monitor0.data) 543 544 self._run_monitor(monitor1, num_epochs=3, num_steps_per_epoch=10) 545 self.assertEqual({ 546 step: { 547 const_var.name: 42.0, 548 counter_var.name: step + 31.0, 549 assign_add.name: step + 31.0, 550 } 551 for step in xrange(30) 552 }, monitor1.data) 553 554 for step in xrange(30): 555 matched, non_matched = monitor1.compare(monitor0, step=step) 556 self.assertEqual([const_var.name], matched) 557 self.assertEqual({ 558 assign_add.name: (step + 31.0, step + 1.0), 559 counter_var.name: (step + 31.0, step + 1.0), 560 }, non_matched) 561 matched, non_matched = monitor0.compare(monitor1, step=step) 562 self.assertEqual([const_var.name], matched) 563 self.assertEqual({ 564 assign_add.name: (step + 1.0, step + 31.0), 565 counter_var.name: (step + 1.0, step + 31.0), 566 }, non_matched) 567 568 def test_capture_variable(self): 569 monitor = learn.monitors.CaptureVariable( 570 var_name='my_assign_add:0', every_n=8, first_n=2) 571 with ops.Graph().as_default() as g, self.session(g): 572 var = variables.VariableV1(0.0, name='my_var') 573 var.initializer.run() 574 state_ops.assign_add(var, 1.0, name='my_assign_add') 575 self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10) 576 self.assertEqual({ 577 0: 1.0, 578 1: 2.0, 579 2: 3.0, 580 10: 4.0, 581 18: 5.0, 582 26: 6.0, 583 29: 7.0, 584 }, monitor.values) 585 586 587class StopAtStepTest(test.TestCase): 588 589 def test_raise_in_both_last_step_and_num_steps(self): 590 with self.assertRaises(ValueError): 591 learn.monitors.StopAtStep(num_steps=10, last_step=20) 592 593 def test_stop_based_on_last_step(self): 594 m = learn.monitors.StopAtStep(last_step=10) 595 m.step_begin(5) 596 self.assertFalse(m.step_end(5, None)) 597 m.step_begin(9) 598 self.assertFalse(m.step_end(9, None)) 599 m.step_begin(10) 600 self.assertTrue(m.step_end(10, None)) 601 m.step_begin(11) 602 self.assertTrue(m.step_end(11, None)) 603 604 def test_stop_based_on_num_step(self): 605 m = learn.monitors.StopAtStep(num_steps=10) 606 m.step_begin(5) 607 self.assertFalse(m.step_end(5, None)) 608 m.step_begin(13) 609 self.assertFalse(m.step_end(13, None)) 610 m.step_begin(14) 611 self.assertTrue(m.step_end(14, None)) 612 m.step_begin(15) 613 self.assertTrue(m.step_end(15, None)) 614 615 616class CheckpointSaverTest(test.TestCase): 617 618 def setUp(self): 619 self.model_dir = tempfile.mkdtemp() 620 self.graph = ops.Graph() 621 with self.graph.as_default(): 622 self.scaffold = monitored_session.Scaffold() 623 self.global_step = training_util.get_or_create_global_step() 624 self.train_op = state_ops.assign_add(self.global_step, 1) 625 626 def tearDown(self): 627 shutil.rmtree(self.model_dir, ignore_errors=True) 628 629 def _run(self, monitor, step, train_op, sess): 630 monitor.step_begin(step) 631 sess.run(train_op) 632 monitor.post_step(step, sess) 633 634 def test_raise_in_both_secs_and_steps(self): 635 with self.assertRaises(ValueError): 636 learn.monitors.CheckpointSaver( 637 self.model_dir, save_secs=10, save_steps=20) 638 639 def test_raise_in_none_secs_and_steps(self): 640 with self.assertRaises(ValueError): 641 learn.monitors.CheckpointSaver(self.model_dir) 642 643 def test_save_secs_saves_in_first_step(self): 644 with self.graph.as_default(): 645 monitor = learn.monitors.CheckpointSaver( 646 self.model_dir, save_secs=2, scaffold=self.scaffold) 647 monitor.begin() 648 self.scaffold.finalize() 649 with session_lib.Session() as sess: 650 sess.run(self.scaffold.init_op) 651 self._run(monitor, 1, self.train_op, sess) 652 self.assertEqual(1, 653 checkpoint_utils.load_variable(self.model_dir, 654 self.global_step.name)) 655 656 # TODO(gunan): Reenable this test after b/32446874 is fixed. 657 def disabled_test_save_secs_saves_periodically(self): 658 with self.graph.as_default(): 659 monitor = learn.monitors.CheckpointSaver( 660 self.model_dir, save_secs=2, scaffold=self.scaffold) 661 monitor.begin() 662 self.scaffold.finalize() 663 with session_lib.Session() as sess: 664 sess.run(self.scaffold.init_op) 665 self._run(monitor, 1, self.train_op, sess) 666 self._run(monitor, 2, self.train_op, sess) 667 # Not saved 668 self.assertEqual(1, 669 checkpoint_utils.load_variable(self.model_dir, 670 self.global_step.name)) 671 time.sleep(2.5) 672 self._run(monitor, 3, self.train_op, sess) 673 # saved 674 self.assertEqual(3, 675 checkpoint_utils.load_variable(self.model_dir, 676 self.global_step.name)) 677 self._run(monitor, 4, self.train_op, sess) 678 self._run(monitor, 5, self.train_op, sess) 679 # Not saved 680 self.assertEqual(3, 681 checkpoint_utils.load_variable(self.model_dir, 682 self.global_step.name)) 683 time.sleep(2.5) 684 self._run(monitor, 6, self.train_op, sess) 685 # saved 686 self.assertEqual(6, 687 checkpoint_utils.load_variable(self.model_dir, 688 self.global_step.name)) 689 690 def test_save_steps_saves_in_first_step(self): 691 with self.graph.as_default(): 692 monitor = learn.monitors.CheckpointSaver( 693 self.model_dir, save_steps=2, scaffold=self.scaffold) 694 monitor.begin() 695 self.scaffold.finalize() 696 with session_lib.Session() as sess: 697 sess.run(self.scaffold.init_op) 698 self._run(monitor, 1, self.train_op, sess) 699 self.assertEqual(1, 700 checkpoint_utils.load_variable(self.model_dir, 701 self.global_step.name)) 702 703 def test_save_steps_saves_periodically(self): 704 with self.graph.as_default(): 705 monitor = learn.monitors.CheckpointSaver( 706 self.model_dir, save_steps=2, scaffold=self.scaffold) 707 monitor.begin() 708 self.scaffold.finalize() 709 with session_lib.Session() as sess: 710 sess.run(self.scaffold.init_op) 711 self._run(monitor, 1, self.train_op, sess) 712 self._run(monitor, 2, self.train_op, sess) 713 # Not saved 714 self.assertEqual(1, 715 checkpoint_utils.load_variable(self.model_dir, 716 self.global_step.name)) 717 self._run(monitor, 3, self.train_op, sess) 718 # saved 719 self.assertEqual(3, 720 checkpoint_utils.load_variable(self.model_dir, 721 self.global_step.name)) 722 self._run(monitor, 4, self.train_op, sess) 723 # Not saved 724 self.assertEqual(3, 725 checkpoint_utils.load_variable(self.model_dir, 726 self.global_step.name)) 727 self._run(monitor, 5, self.train_op, sess) 728 # saved 729 self.assertEqual(5, 730 checkpoint_utils.load_variable(self.model_dir, 731 self.global_step.name)) 732 733 def test_save_saves_at_end(self): 734 with self.graph.as_default(): 735 monitor = learn.monitors.CheckpointSaver( 736 self.model_dir, save_secs=2, scaffold=self.scaffold) 737 monitor.begin() 738 self.scaffold.finalize() 739 with session_lib.Session() as sess: 740 sess.run(self.scaffold.init_op) 741 self._run(monitor, 1, self.train_op, sess) 742 self._run(monitor, 2, self.train_op, sess) 743 monitor.end(sess) 744 self.assertEqual(2, 745 checkpoint_utils.load_variable(self.model_dir, 746 self.global_step.name)) 747 748 749class FakeMonitor(learn.monitors.BaseMonitor): 750 751 def __init__(self): 752 learn.monitors.BaseMonitor.__init__(self) 753 self.should_stop = False 754 self.requested_tensors = [] 755 self.call_counter = collections.Counter() 756 self.last_begin_step = None 757 self.last_end_step = None 758 self.last_post_step = None 759 760 def begin(self, max_steps): 761 self.call_counter['begin'] += 1 762 763 def end(self, session): 764 self.call_counter['end'] += 1 765 766 def step_begin(self, step): 767 self.call_counter['step_begin'] += 1 768 self.last_begin_step = step 769 return self.requested_tensors 770 771 def step_end(self, step, output): 772 self.call_counter['step_end'] += 1 773 self.last_end_step = step 774 self.output = output 775 return self.should_stop 776 777 def post_step(self, step, session): 778 self.call_counter['post_step'] += 1 779 self.last_post_step = step 780 self.session = session 781 782 783class RunHookAdapterForMonitorsTest(test.TestCase): 784 785 def test_calls_and_steps(self): 786 with ops.Graph().as_default(), session_lib.Session() as sess: 787 global_step_tensor = training_util.create_global_step() 788 inc_5 = state_ops.assign_add(global_step_tensor, 5) 789 mock_mon = FakeMonitor() 790 mock_mon2 = FakeMonitor() 791 792 hook = learn.monitors.RunHookAdapterForMonitors([mock_mon, mock_mon2]) 793 hook.begin() 794 for mon in [mock_mon, mock_mon2]: 795 self.assertEqual(mon.call_counter['begin'], 1) 796 797 sess.run(variables.global_variables_initializer()) 798 sess.run(global_step_tensor.assign(10)) 799 800 mon_sess = monitored_session._HookedSession(sess=sess, hooks=[hook]) 801 802 mon_sess.run(inc_5) 803 for mon in [mock_mon, mock_mon2]: 804 self.assertEqual(mon.output, {}) 805 self.assertEqual(mon.last_begin_step, 11) 806 self.assertEqual(mon.last_end_step, 11) 807 self.assertEqual(mon.last_post_step, 11) 808 self.assertEqual(mon.call_counter['step_end'], 1) 809 self.assertEqual(mon.call_counter['step_begin'], 1) 810 self.assertEqual(mon.call_counter['post_step'], 1) 811 812 mon_sess.run(inc_5) 813 for mon in [mock_mon, mock_mon2]: 814 self.assertEqual(mon.output, {}) 815 self.assertEqual(mon.last_begin_step, 16) 816 self.assertEqual(mon.last_end_step, 16) 817 self.assertEqual(mon.last_post_step, 16) 818 self.assertEqual(mon.call_counter['step_end'], 2) 819 self.assertEqual(mon.call_counter['step_begin'], 2) 820 self.assertEqual(mon.call_counter['post_step'], 2) 821 822 hook.end(sess) 823 for mon in [mock_mon, mock_mon2]: 824 self.assertEqual(mon.call_counter['end'], 1) 825 826 def test_requests(self): 827 with ops.Graph().as_default(), session_lib.Session() as sess: 828 training_util.create_global_step() 829 mock_mon = FakeMonitor() 830 mock_mon2 = FakeMonitor() 831 832 hook = learn.monitors.RunHookAdapterForMonitors([mock_mon, mock_mon2]) 833 hook.begin() 834 835 mon_sess = monitored_session._HookedSession(sess=sess, hooks=[hook]) 836 837 a_tensor = constant_op.constant([0], name='a_tensor') 838 constant_op.constant([5], name='another_tensor') 839 constant_op.constant([10], name='third_tensor') 840 mock_mon.requested_tensors = ['another_tensor'] 841 mock_mon2.requested_tensors = ['third_tensor'] 842 sess.run(variables.global_variables_initializer()) 843 844 output = mon_sess.run(a_tensor) 845 self.assertEqual(output, [0]) 846 self.assertEqual(mock_mon.output['another_tensor'], [5]) 847 self.assertEqual(mock_mon2.output['third_tensor'], [10]) 848 849 850if __name__ == '__main__': 851 test.main() 852