1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Monitors tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import shutil
23import tempfile
24import time
25
26from six.moves import xrange  # pylint: disable=redefined-builtin
27
28from tensorflow.contrib import testing
29from tensorflow.contrib.framework.python.framework import checkpoint_utils
30from tensorflow.contrib.learn.python import learn
31from tensorflow.contrib.learn.python.learn import estimators
32from tensorflow.python.client import session as session_lib
33from tensorflow.python.estimator import estimator as core_estimator
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import state_ops
38from tensorflow.python.ops import variables
39from tensorflow.python.platform import test
40from tensorflow.python.platform import tf_logging as logging
41from tensorflow.python.summary import summary
42from tensorflow.python.training import checkpoint_management
43from tensorflow.python.training import gradient_descent
44from tensorflow.python.training import monitored_session
45from tensorflow.python.training import training_util
46
47
class _MyEveryN(learn.monitors.EveryN):
  """`EveryN` monitor that keeps a history of the steps each hook fired on.

  Every hook first delegates to the `EveryN` implementation and then appends
  the step it was given to a per-hook list, which tests read through the
  read-only properties below.
  """

  def __init__(self, every_n_steps=100, first_n_steps=1):
    super(_MyEveryN, self).__init__(
        every_n_steps=every_n_steps, first_n_steps=first_n_steps)
    # Histories for every_n_step_begin, every_n_step_end and
    # every_n_post_step respectively, in call order.
    self._steps_begun, self._steps_ended, self._post_steps = [], [], []

  @property
  def steps_begun(self):
    """List of steps passed to `every_n_step_begin`, in call order."""
    return self._steps_begun

  @property
  def steps_ended(self):
    """List of steps passed to `every_n_step_end`, in call order."""
    return self._steps_ended

  @property
  def post_steps(self):
    """List of steps passed to `every_n_post_step`, in call order."""
    return self._post_steps

  def every_n_step_begin(self, step):
    """Records `step` after the base hook runs; requests no tensors."""
    super(_MyEveryN, self).every_n_step_begin(step)
    self._steps_begun.append(step)
    requested = []
    return requested

  def every_n_step_end(self, step, outputs):
    """Records `step` after the base hook runs; never asks to stop."""
    super(_MyEveryN, self).every_n_step_end(step, outputs)
    self._steps_ended.append(step)
    stop_requested = False
    return stop_requested

  def every_n_post_step(self, step, session):
    """Records `step` after the base hook runs and returns False."""
    super(_MyEveryN, self).every_n_post_step(step, session)
    self._post_steps.append(step)
    stop_requested = False
    return stop_requested
83
84
class MonitorsTest(test.TestCase):
  """Monitors tests."""

  def setUp(self):
    """Patches `logging.info` so tests can inspect the most recent log call."""
    super(MonitorsTest, self).setUp()
    # Start from a known value so a test that expects a log line fails with a
    # clear assertion message (not an AttributeError) if nothing was logged.
    self.logged_message = None
    # Mock out logging calls so we can verify whether correct tensors are being
    # monitored. The real function is kept so messages are still emitted.
    self._actual_log = logging.info

    def mockLog(*args, **kwargs):  # pylint: disable=invalid-name
      self.logged_message = args
      self._actual_log(*args, **kwargs)

    logging.info = mockLog

  def tearDown(self):
    """Restores the real `logging.info` replaced in `setUp`."""
    logging.info = self._actual_log
    super(MonitorsTest, self).tearDown()

  def _run_monitor(self,
                   monitor,
                   num_epochs=3,
                   num_steps_per_epoch=10,
                   pass_max_steps=True):
    """Drives `monitor` through a simulated training loop.

    Steps are numbered consecutively across epochs, starting at 0. Tensors
    requested by `step_begin` are evaluated in the default session and passed
    to `step_end` as a dict keyed by tensor name. When `step_end` returns
    True, the remaining steps of the current epoch are skipped.

    Args:
      monitor: The monitor under test.
      num_epochs: Number of epochs to simulate.
      num_steps_per_epoch: Number of steps in each epoch.
      pass_max_steps: If True, `monitor.begin` receives the index of the last
        step the loop can reach as `max_steps`; otherwise it receives None.
    """
    if pass_max_steps:
      # Index of the final step, since steps start at 0.
      max_steps = num_epochs * num_steps_per_epoch - 1
    else:
      max_steps = None
    monitor.begin(max_steps=max_steps)
    for epoch in xrange(num_epochs):
      monitor.epoch_begin(epoch)
      should_stop = False
      step = epoch * num_steps_per_epoch
      next_epoch_step = step + num_steps_per_epoch
      while (not should_stop) and (step < next_epoch_step):
        # Treat a None request the same as "no tensors requested".
        tensors = monitor.step_begin(step) or []
        if tensors:
          values = ops.get_default_session().run(tensors)
          names = [t.name if isinstance(t, ops.Tensor) else t for t in tensors]
          output = dict(zip(names, values))
        else:
          output = {}
        should_stop = monitor.step_end(step=step, output=output)
        monitor.post_step(step=step, session=None)
        step += 1
      monitor.epoch_end(epoch)
    monitor.end()

  def test_base_monitor(self):
    """BaseMonitor runs through the full loop without errors."""
    with ops.Graph().as_default() as g, self.session(g):
      self._run_monitor(learn.monitors.BaseMonitor())

  def test_every_0(self):
    """every_n_steps=0 fires the hooks on every step."""
    monitor = _MyEveryN(every_n_steps=0, first_n_steps=-1)
    with ops.Graph().as_default() as g, self.session(g):
      self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
      expected_steps = list(range(30))
      self.assertAllEqual(expected_steps, monitor.steps_begun)
      self.assertAllEqual(expected_steps, monitor.steps_ended)
      self.assertAllEqual(expected_steps, monitor.post_steps)

  def test_every_1(self):
    """every_n_steps=1 fires the hooks on every step except step 0."""
    monitor = _MyEveryN(every_n_steps=1, first_n_steps=-1)
    with ops.Graph().as_default() as g, self.session(g):
      self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
      expected_steps = list(range(1, 30))
      self.assertEqual(expected_steps, monitor.steps_begun)
      self.assertEqual(expected_steps, monitor.steps_ended)
      self.assertEqual(expected_steps, monitor.post_steps)

  def test_every_2(self):
    """every_n_steps=2 fires on even steps plus the final step."""
    monitor = _MyEveryN(every_n_steps=2, first_n_steps=-1)
    with ops.Graph().as_default() as g, self.session(g):
      self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
      expected_steps = list(range(2, 29, 2)) + [29]
      self.assertEqual(expected_steps, monitor.steps_begun)
      self.assertEqual(expected_steps, monitor.steps_ended)
      self.assertEqual(expected_steps, monitor.post_steps)

  def test_every_8(self):
    """first_n_steps and every_n_steps combine, plus the final step."""
    monitor = _MyEveryN(every_n_steps=8, first_n_steps=2)
    with ops.Graph().as_default() as g, self.session(g):
      self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
      expected_steps = [0, 1, 2, 10, 18, 26, 29]
      self.assertEqual(expected_steps, monitor.steps_begun)
      self.assertEqual(expected_steps, monitor.steps_ended)
      self.assertEqual(expected_steps, monitor.post_steps)

  def test_every_8_no_max_steps(self):
    """Without max_steps, only post_step still reports the final step."""
    monitor = _MyEveryN(every_n_steps=8, first_n_steps=2)
    with ops.Graph().as_default() as g, self.session(g):
      self._run_monitor(
          monitor, num_epochs=3, num_steps_per_epoch=10, pass_max_steps=False)
      begin_end_steps = [0, 1, 2, 10, 18, 26]
      post_steps = [0, 1, 2, 10, 18, 26, 29]
      self.assertEqual(begin_end_steps, monitor.steps_begun)
      self.assertEqual(begin_end_steps, monitor.steps_ended)
      self.assertEqual(post_steps, monitor.post_steps)

  def test_every_8_recovered_after_step_begin(self):
    """A repeated step_begin without step_end re-runs the begin hook."""
    monitor = _MyEveryN(every_n_steps=8)
    with ops.Graph().as_default() as g, self.session(g):
      for step in [8, 16]:
        monitor.step_begin(step)
        monitor.step_begin(step)
        monitor.step_end(step, output=None)
        monitor.post_step(step, session=None)
      # The begin hook runs again because step_end had not been called yet.
      self.assertEqual([8, 8, 16, 16], monitor.steps_begun)
      self.assertEqual([8, 16], monitor.steps_ended)
      self.assertEqual([8, 16], monitor.post_steps)

  def test_every_8_recovered_after_step_end(self):
    """Repeating a completed step does not re-run any hook."""
    monitor = _MyEveryN(every_n_steps=8)
    with ops.Graph().as_default() as g, self.session(g):
      for step in [8, 16]:
        monitor.step_begin(step)
        monitor.step_end(step, output=None)
        monitor.post_step(step, session=None)
        monitor.step_begin(step)
        monitor.step_end(step, output=None)
        monitor.post_step(step, session=None)
      # It should not call begin twice since end was called.
      self.assertEqual([8, 16], monitor.steps_begun)
      self.assertEqual([8, 16], monitor.steps_ended)
      self.assertEqual([8, 16], monitor.post_steps)

  def test_every_8_call_post_step_at_the_end(self):
    """The final step is reported to the post-step hook when the run ends."""
    monitor = _MyEveryN(every_n_steps=8)
    with ops.Graph().as_default() as g, self.session(g):
      monitor.begin()
      for step in [8, 16]:
        monitor.step_begin(step)
        monitor.step_end(step, output=None)
        monitor.post_step(step, session=None)
      monitor.step_begin(19)
      monitor.step_end(19, output=None)
      monitor.post_step(19, session=None)
      monitor.end(session=None)
      # Step 19 is not an every-8 step, so its begin/end hooks are skipped,
      # but it is still reported as a post step once the monitor ends.
      self.assertEqual([8, 16], monitor.steps_begun)
      self.assertEqual([8, 16], monitor.steps_ended)
      self.assertEqual([8, 16, 19], monitor.post_steps)

  def test_every_8_call_post_step_should_not_be_called_twice(self):
    """A final step that was already reported is not reported again."""
    monitor = _MyEveryN(every_n_steps=8)
    with ops.Graph().as_default() as g, self.session(g):
      monitor.begin()
      for step in [8, 16]:
        monitor.step_begin(step)
        monitor.step_end(step, output=None)
        monitor.post_step(step, session=None)
      monitor.step_begin(16)
      monitor.step_end(16, output=None)
      monitor.post_step(16, session=None)
      monitor.end(session=None)
      # Neither repeating step 16 nor end() may record step 16 a second time.
      self.assertEqual([8, 16], monitor.steps_begun)
      self.assertEqual([8, 16], monitor.steps_ended)
      self.assertEqual([8, 16], monitor.post_steps)

  def test_print(self):
    """PrintTensor logs the monitored tensor's name."""
    with ops.Graph().as_default() as g, self.session(g):
      t = constant_op.constant(42.0, name='foo')
      self._run_monitor(learn.monitors.PrintTensor(tensor_names=[t.name]))
      self.assertRegexpMatches(str(self.logged_message), t.name)

  def test_logging_trainable(self):
    """LoggingTrainable logs the matching trainable variable's name."""
    with ops.Graph().as_default() as g, self.session(g):
      var = variables.VariableV1(constant_op.constant(42.0), name='foo')
      var.initializer.run()
      cof = constant_op.constant(1.0)
      loss = math_ops.subtract(
          math_ops.multiply(var, cof), constant_op.constant(1.0))
      train_step = gradient_descent.GradientDescentOptimizer(0.5).minimize(loss)
      ops.get_default_session().run(train_step)
      self._run_monitor(learn.monitors.LoggingTrainable('foo'))
      self.assertRegexpMatches(str(self.logged_message), var.name)

  def test_summary_saver(self):
    """SummarySaver writes summaries on first, every-n and final steps."""
    with ops.Graph().as_default() as g, self.session(g):
      log_dir = 'log/dir'
      summary_writer = testing.FakeSummaryWriter(log_dir, g)
      var = variables.VariableV1(0.0)
      var.initializer.run()
      tensor = state_ops.assign_add(var, 1.0)
      summary_op = summary.scalar('my_summary', tensor)
      self._run_monitor(
          learn.monitors.SummarySaver(
              summary_op=summary_op,
              save_steps=8,
              summary_writer=summary_writer),
          num_epochs=3,
          num_steps_per_epoch=10)
      summary_writer.assert_summaries(
          test_case=self,
          expected_logdir=log_dir,
          expected_graph=g,
          expected_summaries={
              0: {
                  'my_summary': 1.0
              },
              1: {
                  'my_summary': 2.0
              },
              9: {
                  'my_summary': 3.0
              },
              17: {
                  'my_summary': 4.0
              },
              25: {
                  'my_summary': 5.0
              },
              29: {
                  'my_summary': 6.0
              },
          })

  def _assert_validation_monitor(self,
                                 monitor,
                                 expected_early_stopped=False,
                                 expected_best_step=None,
                                 expected_best_value=None,
                                 expected_best_metrics=None):
    """Asserts the early-stopping and best-result state of a ValidationMonitor.

    Args:
      monitor: The `ValidationMonitor` to inspect.
      expected_early_stopped: Expected value of `monitor.early_stopped`.
      expected_best_step: Expected value of `monitor.best_step`.
      expected_best_value: Expected value of `monitor.best_value`.
      expected_best_metrics: Expected value of `monitor.best_metrics`.
    """
    self.assertEqual(expected_early_stopped, monitor.early_stopped)
    self.assertEqual(expected_best_step, monitor.best_step)
    self.assertEqual(expected_best_value, monitor.best_value)
    self.assertEqual(expected_best_metrics, monitor.best_metrics)

  def test_validation_monitor_no_estimator(self):
    """Running a ValidationMonitor without an estimator raises ValueError."""
    monitor = learn.monitors.ValidationMonitor(
        x=constant_op.constant(2.0), every_n_steps=0)
    self._assert_validation_monitor(monitor)
    with ops.Graph().as_default() as g, self.session(g):
      with self.assertRaisesRegexp(ValueError, 'set_estimator'):
        self._run_monitor(monitor)

  @test.mock.patch.object(estimators, 'Estimator', autospec=True)
  @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
  def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint,
                                      mock_estimator_class):
    """Without a checkpoint the monitor only queries for one."""
    estimator = mock_estimator_class()
    model_dir = 'model/dir'
    estimator.model_dir = model_dir
    mock_latest_checkpoint.return_value = None

    # Do nothing with no checkpoint.
    monitor = learn.monitors.ValidationMonitor(
        x=constant_op.constant(2.0), every_n_steps=0)
    self._assert_validation_monitor(monitor)
    monitor.set_estimator(estimator)
    with ops.Graph().as_default() as g, self.session(g):
      self._run_monitor(monitor)
      self._assert_validation_monitor(monitor)
      mock_latest_checkpoint.assert_called_with(model_dir)

  @test.mock.patch.object(estimators, 'Estimator', autospec=True)
  @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
  def test_validation_monitor_no_early_stopping_rounds(self,
                                                       mock_latest_checkpoint,
                                                       mock_estimator_class):
    """Without early_stopping_rounds no best result is tracked."""
    estimator = mock_estimator_class()
    model_dir = 'model/dir'
    estimator.model_dir = model_dir
    estimator.evaluate.return_value = {}
    mock_latest_checkpoint.return_value = '%s/ckpt' % model_dir

    # Do nothing with early_stopping_rounds=None.
    monitor = learn.monitors.ValidationMonitor(
        x=constant_op.constant(2.0), every_n_steps=0)
    self._assert_validation_monitor(monitor)
    monitor.set_estimator(estimator)
    with ops.Graph().as_default() as g, self.session(g):
      self._run_monitor(monitor)
      self._assert_validation_monitor(monitor)

  @test.mock.patch.object(estimators, 'Estimator', autospec=True)
  @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
  def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint,
                                             mock_estimator_class):
    """A missing early-stopping metric in the eval outputs raises."""
    estimator = mock_estimator_class()
    model_dir = 'model/dir'
    estimator.model_dir = model_dir
    estimator.evaluate.return_value = {}
    mock_latest_checkpoint.return_value = '%s/ckpt' % model_dir

    # Fail for missing metric.
    monitor = learn.monitors.ValidationMonitor(
        x=constant_op.constant(2.0), every_n_steps=0, early_stopping_rounds=1)
    self._assert_validation_monitor(monitor)
    monitor.set_estimator(estimator)
    with ops.Graph().as_default() as g, self.session(g):
      with self.assertRaisesRegexp(ValueError, 'missing from outputs'):
        self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1)

  @test.mock.patch.object(estimators, 'Estimator', autospec=True)
  @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
  def test_validation_monitor(self, mock_latest_checkpoint,
                              mock_estimator_class):
    """Tracks the best loss per checkpoint and stops after 2 worse evals."""
    estimator = mock_estimator_class()
    model_dir = 'model/dir'
    estimator.model_dir = model_dir
    validation_outputs = {'loss': None, 'auc': None}
    estimator.evaluate.return_value = validation_outputs

    monitor = learn.monitors.ValidationMonitor(
        x=constant_op.constant(2.0),
        every_n_steps=0,
        early_stopping_rounds=2,
        check_interval_secs=None)

    self._assert_validation_monitor(monitor)
    monitor.set_estimator(estimator)
    with ops.Graph().as_default() as g, self.session(g):
      monitor.begin(max_steps=100)
      monitor.epoch_begin(epoch=0)
      self.assertEqual(0, estimator.evaluate.call_count)

      # Step 0, initial loss.
      step = 0
      mock_latest_checkpoint.return_value = '%s/ckpt.%s' % (model_dir, step)
      validation_outputs['loss'] = 42.0
      validation_outputs['auc'] = 0.5
      self.assertEqual(0, len(monitor.step_begin(step=step)))
      self.assertFalse(monitor.step_end(step=step, output={}))
      self.assertEqual(1, estimator.evaluate.call_count)
      self._assert_validation_monitor(
          monitor, expected_best_step=0, expected_best_value=42.0,
          expected_best_metrics={'loss': 42.0, 'auc': 0.5})
      monitor.post_step(step=step, session=None)

      # Step 1, same checkpoint, no eval.
      step = 1
      self.assertEqual(0, len(monitor.step_begin(step=step)))
      self.assertFalse(monitor.step_end(step=step, output={}))
      self.assertEqual(1, estimator.evaluate.call_count)
      self._assert_validation_monitor(
          monitor, expected_best_step=0, expected_best_value=42.0,
          expected_best_metrics={'loss': 42.0, 'auc': 0.5})
      monitor.post_step(step=step, session=None)

      # Step 2, lower loss.
      step = 2
      mock_latest_checkpoint.return_value = '%s/ckpt.%s' % (model_dir, step)
      validation_outputs['loss'] = 40.0
      validation_outputs['auc'] = 0.6
      self.assertEqual(0, len(monitor.step_begin(step=step)))
      self.assertFalse(monitor.step_end(step=step, output={}))
      self.assertEqual(2, estimator.evaluate.call_count)
      self._assert_validation_monitor(
          monitor, expected_best_step=2, expected_best_value=40.0,
          expected_best_metrics={'loss': 40.0, 'auc': 0.6})
      monitor.post_step(step=step, session=None)

      # Step 3, higher loss.
      step = 3
      mock_latest_checkpoint.return_value = '%s/ckpt.%s' % (model_dir, step)
      validation_outputs['loss'] = 44.0
      validation_outputs['auc'] = 0.7
      self.assertEqual(0, len(monitor.step_begin(step=step)))
      self.assertFalse(monitor.step_end(step=step, output={}))
      self.assertEqual(3, estimator.evaluate.call_count)
      self._assert_validation_monitor(
          monitor, expected_best_step=2, expected_best_value=40.0,
          expected_best_metrics={'loss': 40.0, 'auc': 0.6})
      monitor.post_step(step=step, session=None)

      # Step 4, higher loss for 2 steps, early stopping.
      step = 4
      mock_latest_checkpoint.return_value = '%s/ckpt.%s' % (model_dir, step)
      validation_outputs['loss'] = 43.0
      self.assertEqual(0, len(monitor.step_begin(step=step)))
      self.assertTrue(monitor.step_end(step=step, output={}))
      self.assertEqual(4, estimator.evaluate.call_count)
      self._assert_validation_monitor(
          monitor,
          expected_early_stopped=True,
          expected_best_step=2,
          expected_best_value=40.0,
          expected_best_metrics={'loss': 40.0, 'auc': 0.6})
      monitor.post_step(step=step, session=None)

      monitor.epoch_end(epoch=0)
      monitor.end()

  @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
  def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint):
    """ValidationMonitor also works with a core tf.estimator.Estimator."""
    estimator = test.mock.Mock(spec=core_estimator.Estimator)
    model_dir = 'model/dir'
    estimator.model_dir = model_dir
    validation_outputs = {'loss': None, 'auc': None}
    estimator.evaluate.return_value = validation_outputs

    monitor = learn.monitors.ValidationMonitor(
        input_fn=lambda: constant_op.constant(2.0),
        every_n_steps=0, early_stopping_rounds=2)
    self._assert_validation_monitor(monitor)
    monitor.set_estimator(estimator)
    with ops.Graph().as_default() as g, self.session(g):
      monitor.begin(max_steps=100)
      monitor.epoch_begin(epoch=0)
      self.assertEqual(0, estimator.evaluate.call_count)

      # Step 0, initial loss.
      step = 0
      mock_latest_checkpoint.return_value = '%s/ckpt.%s' % (model_dir, step)
      validation_outputs['loss'] = 42.0
      validation_outputs['auc'] = 0.5
      self.assertEqual(0, len(monitor.step_begin(step=step)))
      self.assertFalse(monitor.step_end(step=step, output={}))
      self.assertEqual(1, estimator.evaluate.call_count)
      self._assert_validation_monitor(
          monitor, expected_best_step=0, expected_best_value=42.0,
          expected_best_metrics={'loss': 42.0, 'auc': 0.5})
      monitor.post_step(step=step, session=None)

  @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
  def test_validation_monitor_fail_with_core_estimator_and_metrics(
      self, mock_latest_checkpoint):
    """Passing `metrics` with a core Estimator raises on evaluation."""
    estimator = test.mock.Mock(spec=core_estimator.Estimator)
    model_dir = 'model/dir'
    estimator.model_dir = model_dir
    validation_outputs = {'loss': None}
    estimator.evaluate.return_value = validation_outputs

    monitor = learn.monitors.ValidationMonitor(
        input_fn=lambda: constant_op.constant(2.0),
        metrics=constant_op.constant(2.0),
        every_n_steps=0, early_stopping_rounds=2)
    monitor.set_estimator(estimator)
    with ops.Graph().as_default() as g, self.session(g):
      monitor.begin(max_steps=100)
      monitor.epoch_begin(epoch=0)

      # Setup that cannot raise stays outside the assertRaises scope so the
      # expected error can only come from the monitor's step hooks.
      step = 0
      mock_latest_checkpoint.return_value = '%s/ckpt.%s' % (model_dir, step)
      validation_outputs['loss'] = 42.0
      with self.assertRaisesRegexp(
          ValueError,
          'tf.estimator.Estimator does not support .* metrics'):
        self.assertEqual(0, len(monitor.step_begin(step=step)))
        # Expected to raise; its return value is never inspected.
        monitor.step_end(step=step, output={})

  def test_graph_dump(self):
    """GraphDump captures values per step and compares two dumps."""
    monitor0 = learn.monitors.GraphDump()
    monitor1 = learn.monitors.GraphDump()
    with ops.Graph().as_default() as g, self.session(g):
      const_var = variables.VariableV1(42.0, name='my_const')
      counter_var = variables.VariableV1(0.0, name='my_counter')
      assign_add = state_ops.assign_add(counter_var, 1.0, name='my_assign_add')
      variables.global_variables_initializer().run()

      self._run_monitor(monitor0, num_epochs=3, num_steps_per_epoch=10)
      self.assertEqual({
          step: {
              const_var.name: 42.0,
              counter_var.name: step + 1.0,
              assign_add.name: step + 1.0,
          }
          for step in xrange(30)
      }, monitor0.data)

      self._run_monitor(monitor1, num_epochs=3, num_steps_per_epoch=10)
      self.assertEqual({
          step: {
              const_var.name: 42.0,
              counter_var.name: step + 31.0,
              assign_add.name: step + 31.0,
          }
          for step in xrange(30)
      }, monitor1.data)

      for step in xrange(30):
        matched, non_matched = monitor1.compare(monitor0, step=step)
        self.assertEqual([const_var.name], matched)
        self.assertEqual({
            assign_add.name: (step + 31.0, step + 1.0),
            counter_var.name: (step + 31.0, step + 1.0),
        }, non_matched)
        matched, non_matched = monitor0.compare(monitor1, step=step)
        self.assertEqual([const_var.name], matched)
        self.assertEqual({
            assign_add.name: (step + 1.0, step + 31.0),
            counter_var.name: (step + 1.0, step + 31.0),
        }, non_matched)

  def test_capture_variable(self):
    """CaptureVariable records the value on first-n, every-n, final steps."""
    monitor = learn.monitors.CaptureVariable(
        var_name='my_assign_add:0', every_n=8, first_n=2)
    with ops.Graph().as_default() as g, self.session(g):
      var = variables.VariableV1(0.0, name='my_var')
      var.initializer.run()
      state_ops.assign_add(var, 1.0, name='my_assign_add')
      self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
      self.assertEqual({
          0: 1.0,
          1: 2.0,
          2: 3.0,
          10: 4.0,
          18: 5.0,
          26: 6.0,
          29: 7.0,
      }, monitor.values)
585
586
class StopAtStepTest(test.TestCase):
  """Tests for `learn.monitors.StopAtStep`."""

  def _check_schedule(self, monitor, schedule):
    """Feeds (step, should_stop) pairs through `monitor` in order.

    For every pair, `step_begin` is called first, then the value returned by
    `step_end` is asserted truthy when `should_stop` is set, falsy otherwise.
    """
    for step, should_stop in schedule:
      monitor.step_begin(step)
      check = self.assertTrue if should_stop else self.assertFalse
      check(monitor.step_end(step, None))

  def test_raise_in_both_last_step_and_num_steps(self):
    with self.assertRaises(ValueError):
      learn.monitors.StopAtStep(num_steps=10, last_step=20)

  def test_stop_based_on_last_step(self):
    stopper = learn.monitors.StopAtStep(last_step=10)
    self._check_schedule(
        stopper, [(5, False), (9, False), (10, True), (11, True)])

  def test_stop_based_on_num_step(self):
    stopper = learn.monitors.StopAtStep(num_steps=10)
    self._check_schedule(
        stopper, [(5, False), (13, False), (14, True), (15, True)])
614
615
class CheckpointSaverTest(test.TestCase):
  """Tests for `learn.monitors.CheckpointSaver`."""

  def setUp(self):
    """Builds a graph holding a global step and an op that increments it."""
    super(CheckpointSaverTest, self).setUp()
    self.model_dir = tempfile.mkdtemp()
    self.graph = ops.Graph()
    with self.graph.as_default():
      self.scaffold = monitored_session.Scaffold()
      self.global_step = training_util.get_or_create_global_step()
      self.train_op = state_ops.assign_add(self.global_step, 1)

  def tearDown(self):
    """Deletes the temporary model directory created in `setUp`."""
    shutil.rmtree(self.model_dir, ignore_errors=True)
    super(CheckpointSaverTest, self).tearDown()

  def _run(self, monitor, step, train_op, sess):
    """Runs one training step wrapped in the monitor's step hooks."""
    monitor.step_begin(step)
    sess.run(train_op)
    monitor.post_step(step, sess)

  def _checkpointed_global_step(self):
    """Returns the global step value saved under `self.model_dir`."""
    return checkpoint_utils.load_variable(self.model_dir,
                                          self.global_step.name)

  def test_raise_in_both_secs_and_steps(self):
    with self.assertRaises(ValueError):
      learn.monitors.CheckpointSaver(
          self.model_dir, save_secs=10, save_steps=20)

  def test_raise_in_none_secs_and_steps(self):
    with self.assertRaises(ValueError):
      learn.monitors.CheckpointSaver(self.model_dir)

  def test_save_secs_saves_in_first_step(self):
    with self.graph.as_default():
      monitor = learn.monitors.CheckpointSaver(
          self.model_dir, save_secs=2, scaffold=self.scaffold)
      monitor.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        self._run(monitor, 1, self.train_op, sess)
        self.assertEqual(1, self._checkpointed_global_step())

  # TODO(gunan): Reenable this test after b/32446874 is fixed.
  def disabled_test_save_secs_saves_periodically(self):
    with self.graph.as_default():
      monitor = learn.monitors.CheckpointSaver(
          self.model_dir, save_secs=2, scaffold=self.scaffold)
      monitor.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        self._run(monitor, 1, self.train_op, sess)
        self._run(monitor, 2, self.train_op, sess)
        # Not saved
        self.assertEqual(1, self._checkpointed_global_step())
        time.sleep(2.5)
        self._run(monitor, 3, self.train_op, sess)
        # saved
        self.assertEqual(3, self._checkpointed_global_step())
        self._run(monitor, 4, self.train_op, sess)
        self._run(monitor, 5, self.train_op, sess)
        # Not saved
        self.assertEqual(3, self._checkpointed_global_step())
        time.sleep(2.5)
        self._run(monitor, 6, self.train_op, sess)
        # saved
        self.assertEqual(6, self._checkpointed_global_step())

  def test_save_steps_saves_in_first_step(self):
    with self.graph.as_default():
      monitor = learn.monitors.CheckpointSaver(
          self.model_dir, save_steps=2, scaffold=self.scaffold)
      monitor.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        self._run(monitor, 1, self.train_op, sess)
        self.assertEqual(1, self._checkpointed_global_step())

  def test_save_steps_saves_periodically(self):
    with self.graph.as_default():
      monitor = learn.monitors.CheckpointSaver(
          self.model_dir, save_steps=2, scaffold=self.scaffold)
      monitor.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        self._run(monitor, 1, self.train_op, sess)
        self._run(monitor, 2, self.train_op, sess)
        # Not saved
        self.assertEqual(1, self._checkpointed_global_step())
        self._run(monitor, 3, self.train_op, sess)
        # saved
        self.assertEqual(3, self._checkpointed_global_step())
        self._run(monitor, 4, self.train_op, sess)
        # Not saved
        self.assertEqual(3, self._checkpointed_global_step())
        self._run(monitor, 5, self.train_op, sess)
        # saved
        self.assertEqual(5, self._checkpointed_global_step())

  def test_save_saves_at_end(self):
    with self.graph.as_default():
      monitor = learn.monitors.CheckpointSaver(
          self.model_dir, save_secs=2, scaffold=self.scaffold)
      monitor.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        self._run(monitor, 1, self.train_op, sess)
        self._run(monitor, 2, self.train_op, sess)
        monitor.end(sess)
        self.assertEqual(2, self._checkpointed_global_step())
747
748
class FakeMonitor(learn.monitors.BaseMonitor):
  """Test double that records every monitor callback it receives.

  Each callback increments `call_counter[<callback name>]` and stores the
  arguments it was given, so tests can inspect exactly what the code driving
  the monitor passed in.
  """

  def __init__(self):
    super(FakeMonitor, self).__init__()
    # Returned from step_end(); set to True to request that training stop.
    self.should_stop = False
    # Returned from step_begin() as the list of tensors to fetch.
    self.requested_tensors = []
    # Invocation count per callback name ('begin', 'end', 'step_begin',
    # 'step_end', 'post_step').
    self.call_counter = collections.Counter()
    # Step arguments from the most recent step_begin / step_end / post_step.
    self.last_begin_step = None
    self.last_end_step = None
    self.last_post_step = None
    # Most recent `output` given to step_end() and `session` given to
    # post_step(). They are initialized here so they exist (as None) before
    # the first step instead of raising AttributeError.
    self.output = None
    self.session = None

  def begin(self, max_steps=None):
    """Counts the call; `max_steps` is accepted but ignored."""
    self.call_counter['begin'] += 1

  def end(self, session=None):
    """Counts the call; `session` is accepted but ignored."""
    self.call_counter['end'] += 1

  def step_begin(self, step):
    """Records `step` and returns the configured tensor requests."""
    self.call_counter['step_begin'] += 1
    self.last_begin_step = step
    return self.requested_tensors

  def step_end(self, step, output):
    """Records `step` and `output`; returns the configured stop flag."""
    self.call_counter['step_end'] += 1
    self.last_end_step = step
    self.output = output
    return self.should_stop

  def post_step(self, step, session):
    """Records `step` and the session used for the step."""
    self.call_counter['post_step'] += 1
    self.last_post_step = step
    self.session = session
781
782
class RunHookAdapterForMonitorsTest(test.TestCase):
  """Tests that RunHookAdapterForMonitors forwards hook events to monitors."""

  def _assert_monitors_stepped(self, monitors, step, num_calls):
    """Checks the per-step callbacks seen by every monitor in `monitors`.

    Args:
      monitors: `FakeMonitor` instances wrapped by the hook under test.
      step: step value expected in the latest step_begin / step_end /
        post_step call.
      num_calls: expected number of calls so far to each per-step callback.
    """
    for mon in monitors:
      # No tensors were requested, so step_end receives an empty dict.
      self.assertEqual({}, mon.output)
      self.assertEqual(step, mon.last_begin_step)
      self.assertEqual(step, mon.last_end_step)
      self.assertEqual(step, mon.last_post_step)
      self.assertEqual(num_calls, mon.call_counter['step_end'])
      self.assertEqual(num_calls, mon.call_counter['step_begin'])
      self.assertEqual(num_calls, mon.call_counter['post_step'])

  def test_calls_and_steps(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      global_step_tensor = training_util.create_global_step()
      inc_5 = state_ops.assign_add(global_step_tensor, 5)
      mock_mon = FakeMonitor()
      mock_mon2 = FakeMonitor()
      monitors = [mock_mon, mock_mon2]

      hook = learn.monitors.RunHookAdapterForMonitors(monitors)
      hook.begin()
      for mon in monitors:
        self.assertEqual(1, mon.call_counter['begin'])

      sess.run(variables.global_variables_initializer())
      sess.run(global_step_tensor.assign(10))

      mon_sess = monitored_session._HookedSession(sess=sess, hooks=[hook])

      # Monitors are given the global step read before the run plus one
      # (10 + 1), even though this run advances the global step by 5.
      mon_sess.run(inc_5)
      self._assert_monitors_stepped(monitors, step=11, num_calls=1)

      # The first run left the global step at 15, so the next step is 16.
      mon_sess.run(inc_5)
      self._assert_monitors_stepped(monitors, step=16, num_calls=2)

      hook.end(sess)
      for mon in monitors:
        self.assertEqual(1, mon.call_counter['end'])

  def test_requests(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      training_util.create_global_step()
      mock_mon = FakeMonitor()
      mock_mon2 = FakeMonitor()

      hook = learn.monitors.RunHookAdapterForMonitors([mock_mon, mock_mon2])
      hook.begin()

      mon_sess = monitored_session._HookedSession(sess=sess, hooks=[hook])

      a_tensor = constant_op.constant([0], name='a_tensor')
      constant_op.constant([5], name='another_tensor')
      constant_op.constant([10], name='third_tensor')
      # Each monitor requests a different tensor by name. The fetched value
      # should appear in step_end's output under that name.
      mock_mon.requested_tensors = ['another_tensor']
      mock_mon2.requested_tensors = ['third_tensor']
      sess.run(variables.global_variables_initializer())

      output = mon_sess.run(a_tensor)
      # Fetched values are numpy arrays. Compare element-wise instead of
      # relying on `==` between an array and a list, which only reduces to a
      # single bool for one-element arrays.
      self.assertAllEqual([0], output)
      self.assertAllEqual([5], mock_mon.output['another_tensor'])
      self.assertAllEqual([10], mock_mon2.output['third_tensor'])
848
849
if __name__ == '__main__':
  # Run every test case in this module through TensorFlow's test runner.
  test.main()
852