1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Tests for training.py."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import glob
23import json
24import os
25import random
26import shutil
27import tempfile
28import time
29
30import numpy as np
31
32from tensorflow.python.estimator import estimator as estimator_lib
33from tensorflow.python.estimator import exporter as exporter_lib
34from tensorflow.python.estimator import run_config as run_config_lib
35from tensorflow.python.estimator import training
36from tensorflow.python.estimator.canned import dnn
37from tensorflow.python.estimator.canned import prediction_keys
38from tensorflow.python.estimator.export import export as export_lib
39from tensorflow.python.estimator.inputs import numpy_io
40from tensorflow.python.feature_column import feature_column
41from tensorflow.python.framework import ops
42from tensorflow.python.ops import control_flow_ops
43from tensorflow.python.platform import gfile
44from tensorflow.python.platform import test
45from tensorflow.python.platform import tf_logging as logging
46from tensorflow.python.summary import summary_iterator
47from tensorflow.python.summary.writer import writer_cache
48from tensorflow.python.training import basic_session_run_hooks
49from tensorflow.python.training import monitored_session
50from tensorflow.python.training import server_lib
51from tensorflow.python.training import session_run_hook
52from tensorflow.python.util import compat
53
# Default values mirrored from training.py; the assertions in this file fail
# loudly if either side drifts.
_DEFAULT_EVAL_STEPS = 100
_DEFAULT_EVAL_DELAY_SECS = 120
_DEFAULT_EVAL_THROTTLE_SECS = 600
_DELAY_SECS_PER_WORKER = 5
_GLOBAL_STEP_KEY = ops.GraphKeys.GLOBAL_STEP
# Regex patterns matched (via assertRaisesRegexp) against the error messages
# raised by the module under test.
_INVALID_INPUT_FN_MSG = '`input_fn` must be callable'
_INVALID_HOOK_MSG = 'All hooks must be `SessionRunHook` instances'
_INVALID_MAX_STEPS_MSG = 'Must specify max_steps > 0'
_INVALID_STEPS_MSG = 'Must specify steps > 0'
_INVALID_NAME_MSG = '`name` must be string'
_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify start_delay_secs >= 0'
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`'
_STALE_CHECKPOINT_MSG = 'There was no new checkpoint after the training.'
_INVALID_EXPORTER_MSG = '`exporters` must be an Exporter'
_INVALID_EXPORTER_NAME_TYPE_MSG = 'An Exporter must have a string name'
_DUPLICATE_EXPORTER_NAMES_MSG = '`exporters` must have unique names.'
_NONE_EXPORTER_NAME_MSG = (
    'An Exporter cannot have a name that is `None` or empty.')
_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'
_INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`'
_INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`'
_INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG'
_INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`'
_INVALID_TASK_TYPE = '`estimator.config` must have task_type set.'
# The message should NOT contain the word 'local'. Because (?!word) is only a
# lookahead, the $ (end-of-string) anchor is required; otherwise the pattern
# would match a prefix of the message and succeed.
_INVALID_TASK_TO_RUN = (
    'Task type .* is not supported. Supported task types are ((?!local).)*$')
_INVALID_EMPTY_EVAL_RESULT_ERR = (
    'Internal error: `Estimator.evaluate` should never return empty metrics')
_INVALID_EVAL_RESULT_TYPE_ERR = '`Estimator.evaluate` should return dict.'
_MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR = (
    'Internal error: `Estimator.evaluate` result should have `global_step`')
_INVALID_EVAL_TASK_ID_ERR = (
    'there can only be one `evaluator` task .*with task id 0')

# Canned TF_CONFIG environment dicts, one per task type exercised below. All
# share the same fake cluster layout (except the master variant, which swaps
# the chief for a master node).
_TF_CONFIG_FOR_CHIEF = {
    'cluster': {
        run_config_lib.TaskType.CHIEF: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.CHIEF,
        'index': 0
    }
}

_TF_CONFIG_FOR_MASTER = {
    'cluster': {
        run_config_lib.TaskType.MASTER: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.MASTER,
        'index': 0
    }
}

_TF_CONFIG_FOR_WORKER = {
    'cluster': {
        run_config_lib.TaskType.CHIEF: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.WORKER,
        'index': 1
    }
}

_TF_CONFIG_FOR_PS = {
    'cluster': {
        run_config_lib.TaskType.CHIEF: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.PS,
        'index': 1
    }
}

_TF_CONFIG_FOR_EVALUATOR = {
    'cluster': {
        run_config_lib.TaskType.CHIEF: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
    },
    'task': {
        'type': run_config_lib.TaskType.EVALUATOR,
        'index': 0
    }
}

# A TF_CONFIG with no cluster at all, only the Google environment marker.
_TF_CONFIG_FOR_GOOGLE = {'environment': 'google'}
153
154
class _FakeHook(session_run_hook.SessionRunHook):
  """Fake implementation of `SessionRunHook`; a valid, do-nothing hook."""
157
158
class _InvalidHook(object):
  """Invalid hook (not a subclass of `SessionRunHook`); trips type checks."""
161
162
def _create_exporter(name):
  """Builds a minimal no-op `Exporter` carrying the given `name`."""

  class _StubExporter(exporter_lib.Exporter):
    """Exporter stub: remembers its name and ignores export calls."""

    def __init__(self, exporter_name):
      self._exporter_name = exporter_name

    @property
    def name(self):
      return self._exporter_name

    def export(self, *args, **kwargs):
      # Deliberately a no-op; tests only inspect the `name` property.
      del args, kwargs

  return _StubExporter(name)
177
178
def _create_run_config_with_cluster_spec(tf_config):
  """Builds a `RunConfig` as if `tf_config` were set in the environment."""
  serialized_config = json.dumps(tf_config)
  patched_env = {'TF_CONFIG': serialized_config}
  with test.mock.patch.dict('os.environ', patched_env):
    return run_config_lib.RunConfig()
182
183
class TrainSpecTest(test.TestCase):
  """Tests TrainSpec."""

  def testRequiredArgumentsSet(self):
    """Tests that no errors are raised when all required arguments are set."""
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    self.assertIsNone(train_spec.max_steps)
    self.assertEqual(0, len(train_spec.hooks))
    self.assertEqual(1, train_spec.input_fn())

  def testAllArgumentsSet(self):
    """Tests that no errors are raised when all arguments are set."""
    fake_hooks = [_FakeHook()]
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=fake_hooks)
    self.assertEqual(1, train_spec.input_fn())
    self.assertEqual(2, train_spec.max_steps)
    self.assertEqual(tuple(fake_hooks), train_spec.hooks)

  def testInvalidInputFn(self):
    """A non-callable input_fn is rejected with a TypeError."""
    with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):
      training.TrainSpec(input_fn='invalid')

  def testInvalidMaxStep(self):
    """max_steps must be strictly positive."""
    with self.assertRaisesRegexp(ValueError, _INVALID_MAX_STEPS_MSG):
      training.TrainSpec(input_fn=lambda: 1, max_steps=0)

  def testInvalidHook(self):
    """Every hook must be a SessionRunHook instance."""
    with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
      training.TrainSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])
213
214
class EvalSpecTest(test.TestCase):
  """Tests EvalSpec."""

  def testRequiredArgumentsSet(self):
    """Tests that no errors are raised when all required arguments are set."""
    eval_spec = training.EvalSpec(input_fn=lambda: 1)
    self.assertEqual(1, eval_spec.input_fn())
    self.assertIsNone(eval_spec.name)
    self.assertEqual(0, len(eval_spec.hooks))
    self.assertEqual(0, len(eval_spec.exporters))
    # Unspecified arguments fall back to the module defaults.
    self.assertEqual(_DEFAULT_EVAL_STEPS, eval_spec.steps)
    self.assertEqual(_DEFAULT_EVAL_DELAY_SECS, eval_spec.start_delay_secs)
    self.assertEqual(_DEFAULT_EVAL_THROTTLE_SECS, eval_spec.throttle_secs)

  def testAllArgumentsSet(self):
    """Tests that no errors are raised when all arguments are set."""
    fake_hooks = [_FakeHook()]
    single_exporter = _create_exporter('a')

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        name='name',
        hooks=fake_hooks,
        exporters=single_exporter,
        start_delay_secs=3,
        throttle_secs=4)

    self.assertEqual(1, eval_spec.input_fn())
    self.assertEqual(2, eval_spec.steps)
    self.assertEqual('name', eval_spec.name)
    self.assertEqual(tuple(fake_hooks), eval_spec.hooks)
    # A single exporter is normalized into a one-element tuple.
    self.assertEqual((single_exporter,), eval_spec.exporters)
    self.assertEqual(3, eval_spec.start_delay_secs)
    self.assertEqual(4, eval_spec.throttle_secs)

  def testListOfExporters(self):
    """Tests that no errors are raised with multiple exporters."""
    exporter_list = [_create_exporter('a'), _create_exporter('b')]

    eval_spec = training.EvalSpec(input_fn=lambda: 1, exporters=exporter_list)
    self.assertEqual(tuple(exporter_list), eval_spec.exporters)
    self.assertEqual(1, eval_spec.input_fn())

  def testInvalidInputFn(self):
    """A non-callable input_fn is rejected with a TypeError."""
    with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):
      training.EvalSpec(input_fn='invalid')

  def testInvalidMaxStep(self):
    """steps must be strictly positive."""
    with self.assertRaisesRegexp(ValueError, _INVALID_STEPS_MSG):
      training.EvalSpec(input_fn=lambda: 1, steps=0)

  def testInvalidName(self):
    """name must be a string."""
    with self.assertRaisesRegexp(TypeError, _INVALID_NAME_MSG):
      training.EvalSpec(input_fn=lambda: 1, name=123)

  def testInvalidHook(self):
    """Every hook must be a SessionRunHook instance."""
    with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
      training.EvalSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])

  def testInvalidDelaySecs(self):
    """start_delay_secs must be non-negative."""
    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_DELAY_SECS_MSG):
      training.EvalSpec(input_fn=lambda: 1, start_delay_secs=-1)

  def testInvalidThrottleSecs(self):
    """throttle_secs must be non-negative."""
    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_THROTTLE_SECS_MSG):
      training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1)

  def testInvalidTypeOfListOfExporters(self):
    """A list mixing exporters with non-exporters is rejected."""
    mixed_exporters = [_create_exporter('a'), _FakeHook()]
    with self.assertRaisesRegexp(TypeError, _INVALID_EXPORTER_MSG):
      training.EvalSpec(input_fn=lambda: 1, exporters=mixed_exporters)

  def testInvalidTypeOfIndividualExporter(self):
    """A single non-Exporter value is rejected."""
    with self.assertRaisesRegexp(TypeError, _INVALID_EXPORTER_MSG):
      training.EvalSpec(input_fn=lambda: 1, exporters=_FakeHook())

  def testInvalidTypeOfExporterName(self):
    """An exporter whose name is not a string is rejected."""
    non_string_named = _create_exporter(name=123)
    with self.assertRaisesRegexp(ValueError, _INVALID_EXPORTER_NAME_TYPE_MSG):
      training.EvalSpec(input_fn=lambda: 1, exporters=non_string_named)

  def testMultipleExportersWithTheSameName(self):
    """Exporter names must be unique within one EvalSpec."""
    duplicated_names = [_create_exporter('a'), _create_exporter('a')]
    with self.assertRaisesRegexp(ValueError, _DUPLICATE_EXPORTER_NAMES_MSG):
      training.EvalSpec(input_fn=lambda: 1, exporters=duplicated_names)

  def testMultipleExportersAndOneWithoutAName(self):
    """A None-named exporter is rejected even alongside valid ones."""
    partly_unnamed = [_create_exporter('a'), _create_exporter(None)]
    with self.assertRaisesRegexp(ValueError, _NONE_EXPORTER_NAME_MSG):
      training.EvalSpec(input_fn=lambda: 1, exporters=partly_unnamed)

  def testSingleExporterWithoutAName(self):
    """A lone None-named exporter is rejected."""
    unnamed = _create_exporter(None)
    with self.assertRaisesRegexp(ValueError, _NONE_EXPORTER_NAME_MSG):
      training.EvalSpec(input_fn=lambda: 1, exporters=unnamed)
313
314
class TrainAndEvaluateTest(test.TestCase):
  """Tests for the train_and_evaluate entry point."""

  def test_run_task(self):
    """train_and_evaluate builds a _TrainingExecutor and runs it."""
    estimator = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor:
      executor_instance = test.mock.Mock()
      mock_executor.return_value = executor_instance
      training.train_and_evaluate(estimator, train_spec, eval_spec)
      mock_executor.assert_called_with(
          estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
      self.assertTrue(executor_instance.run.called)

  def test_error_out_if_evaluator_task_id_is_non_zero(self):
    """An evaluator task with a non-zero index must be rejected."""
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
        },
        'task': {
            'type': run_config_lib.TaskType.EVALUATOR,
            'index': 1
        }
    }

    estimator = test.mock.Mock(spec=estimator_lib.Estimator)
    estimator.config = _create_run_config_with_cluster_spec(tf_config)
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR):
      training.train_and_evaluate(estimator, train_spec, eval_spec)

  def test_invalid_estimator(self):
    """A non-Estimator first argument raises a TypeError."""
    not_an_estimator = object()
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
      training.train_and_evaluate(not_an_estimator, train_spec, eval_spec)
358
359
class TrainingExecutorConstructorTest(test.TestCase):
  """Tests constructor of _TrainingExecutor."""

  def testRequiredArgumentsSet(self):
    """A valid estimator/train_spec/eval_spec triple constructs cleanly."""
    est = estimator_lib.Estimator(model_fn=lambda features: features)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1)

    executor = training._TrainingExecutor(est, train_spec, eval_spec)
    self.assertEqual(est, executor.estimator)

  def test_invalid_estimator(self):
    """A non-Estimator estimator argument raises a TypeError."""
    not_an_estimator = object()
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1)

    with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
      training._TrainingExecutor(not_an_estimator, train_spec, eval_spec)

  def test_invalid_train_spec(self):
    """A non-TrainSpec train_spec argument raises a TypeError."""
    est = estimator_lib.Estimator(model_fn=lambda features: features)
    not_a_train_spec = object()
    eval_spec = training.EvalSpec(input_fn=lambda: 1)

    with self.assertRaisesRegexp(TypeError, _INVALID_TRAIN_SPEC_MSG):
      training._TrainingExecutor(est, not_a_train_spec, eval_spec)

  def test_invalid_eval_spec(self):
    """A non-EvalSpec eval_spec argument raises a TypeError."""
    est = estimator_lib.Estimator(model_fn=lambda features: features)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    not_an_eval_spec = object()

    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
      training._TrainingExecutor(est, train_spec, not_an_eval_spec)

  def test_invalid_train_hooks(self):
    """train_hooks entries must be SessionRunHook instances."""
    est = estimator_lib.Estimator(model_fn=lambda features: features)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1)
    bad_hooks = [object()]

    with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
      training._TrainingExecutor(
          est, train_spec, eval_spec, train_hooks=bad_hooks)

  def test_invalid_continuous_eval_listener(self):
    """continuous_eval_listener must be a _ContinuousEvalListener."""
    est = estimator_lib.Estimator(model_fn=lambda features: features)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1)
    bad_listener = object()

    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_LISTENER_MSG):
      training._TrainingExecutor(
          est,
          train_spec,
          eval_spec,
          continuous_eval_listener=bad_listener)
417
418
class _TrainingExecutorTrainingTest(object):
  """Tests training of _TrainingExecutor.

  This is a mixin: concrete subclasses combine it with `test.TestCase` and
  supply a `run_config` whose task type selects which `run_<task_type>`
  method `_run_task` invokes.
  """

  def __init__(self, run_config):
    # The run config determines the task type (e.g. 'worker' or 'chief')
    # that `_run_task` exercises.
    self._run_config = run_config

  def _run_task(self, executor):
    """Invokes `executor.run_<task_type>` for the configured task type."""
    # We should not call executor.run as the test here is intended to test
    # run_foo explicitly (foo is the task type).
    return getattr(executor, 'run_' + self._run_config.task_type)()

  # NOTE: stacked patch decorators apply bottom-up, so the Server patch is
  # the first mock argument and the sleep patch the second.
  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
    """The task starts a (non-started) server, then trains per TrainSpec."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    mock_server_instance = mock_server.return_value

    executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)
    self._run_task(executor)

    # The server must be constructed from the run config and started
    # explicitly (start=False at construction time).
    mock_server.assert_called_with(
        mock_est.config.cluster_spec,
        job_name=mock_est.config.task_type,
        task_index=mock_est.config.task_id,
        config=test.mock.ANY,
        start=False)

    self.assertTrue(mock_server_instance.start.called)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=test.mock.ANY)
    # Training-only tasks never evaluate or export.
    mock_est.evaluate.assert_not_called()
    mock_est.export_savedmodel.assert_not_called()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_hooks(self, unused_mock_server, unused_mock_sleep):
    """Extra train_hooks are appended to the TrainSpec hooks."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    extra_hooks = [_FakeHook()]

    executor = training._TrainingExecutor(
        mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)
    self._run_task(executor)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks) + extra_hooks,
        saving_listeners=test.mock.ANY)

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
    """No TF server is launched when TF_CONFIG marks the Google env."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
    with test.mock.patch.dict('os.environ', tf_config):
      self._run_task(executor)
      mock_server.assert_not_called()

  def test_fail_with_empty_cluster_spec(self):
    """A missing cluster spec makes the std server startup fail."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = None
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))

  def test_fail_with_empty_master(self):
    """An empty master address in a multi-node cluster fails."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    # Two workers, so this is not the single-node special case below.
    mock_est.config.cluster_spec = server_lib.ClusterSpec(
        {'worker': ['dummy', 'dummy1']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_single_worker_node_with_empty_tf_master(
      self, mock_server, unused_mock_sleep):
    """A single-node cluster with empty master trains without a server."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    # Single node cluster.
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                              mock_eval_spec))
    self.assertTrue(mock_est.train.called)
    mock_server.assert_not_called()

  def test_fail_with_empty_task_type(self):
    """An empty task type makes the std server startup fail."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = ''
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))

  def test_fail_with_none_task_id(self):
    """A None task id makes the std server startup fail."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = None

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))
579
580
class TrainingExecutorRunWorkerTest(_TrainingExecutorTrainingTest,
                                    test.TestCase):
  """Tests run_worker of _TrainingExecutor."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    worker_config = _create_run_config_with_cluster_spec(
        _TF_CONFIG_FOR_WORKER)
    _TrainingExecutorTrainingTest.__init__(self, run_config=worker_config)

  @test.mock.patch.object(server_lib, 'Server')
  def test_delay_for_worker(self, _):
    """A worker sleeps (task_id + 1) * _DELAY_SECS_PER_WORKER seconds."""
    estimator = test.mock.Mock(spec=estimator_lib.Estimator)
    estimator.config = self._run_config
    train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(estimator, train_spec, eval_spec)

    expected_secs = (self._run_config.task_id + 1) * _DELAY_SECS_PER_WORKER
    with test.mock.patch.object(time, 'sleep') as mock_sleep:
      # Verify the delay at the moment sleep is invoked.
      mock_sleep.side_effect = lambda secs: self.assertEqual(
          expected_secs, secs)
      self._run_task(executor)
      self.assertTrue(mock_sleep.called)
606
607
class TrainingExecutorRunChiefTest(_TrainingExecutorTrainingTest,
                                   test.TestCase):
  """Tests run_chief of _TrainingExecutor."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    chief_config = _create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF)
    _TrainingExecutorTrainingTest.__init__(self, run_config=chief_config)

  @test.mock.patch.object(server_lib, 'Server')
  def test_no_delay_for_chief(self, _):
    """The chief starts training without any startup sleep."""
    estimator = test.mock.Mock(spec=estimator_lib.Estimator)
    estimator.config = self._run_config
    train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(estimator, train_spec, eval_spec)

    with test.mock.patch.object(time, 'sleep') as mock_sleep:
      self._run_task(executor)
      mock_sleep.assert_not_called()
631
632
633class TrainingExecutorRunMasterTest(test.TestCase):
634  """Tests run_chief of _TrainingExecutor."""
635
636  def setUp(self):
637    self._run_config = _create_run_config_with_cluster_spec(
638        _TF_CONFIG_FOR_MASTER)
639
640  @test.mock.patch.object(server_lib, 'Server')
641  def test_no_delay_for_master(self, _):
642    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
643    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
644    mock_est.config = self._run_config
645    mock_train_spec = test.mock.Mock(
646        spec=training.TrainSpec, max_steps=123, hooks=[])
647    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
648
649    executor = training._TrainingExecutor(mock_est, mock_train_spec,
650                                          mock_eval_spec)
651
652    with test.mock.patch.object(time, 'sleep') as mock_sleep:
653      executor.run_master()
654      mock_sleep.assert_not_called()
655
  # Stacked patch decorators apply bottom-up: Server is the first mock arg.
  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
    """run_master starts the server and trains, but never exports."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    # Stub evaluate to return a result containing global_step, which the
    # executor expects from any evaluation.
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
    mock_server_instance = mock_server.return_value

    executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)
    executor.run_master()

    # The server is built from the run config and started explicitly.
    mock_server.assert_called_with(
        mock_est.config.cluster_spec,
        job_name=mock_est.config.task_type,
        task_index=mock_est.config.task_id,
        config=test.mock.ANY,
        start=False)

    self.assertTrue(mock_server_instance.start.called)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=test.mock.ANY)
    mock_est.export_savedmodel.assert_not_called()
685
686  @test.mock.patch.object(time, 'sleep')
687  @test.mock.patch.object(server_lib, 'Server')
688  def test_train_with_train_hooks(self, mock_server, unused_mock_sleep):
689    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
690    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
691    mock_est.config = self._run_config
692    train_spec = training.TrainSpec(
693        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
694    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
695    extra_hooks = [_FakeHook()]
696
697    executor = training._TrainingExecutor(
698        mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)
699    executor.run_master()
700
701    mock_est.train.assert_called_with(
702        input_fn=train_spec.input_fn,
703        max_steps=train_spec.max_steps,
704        hooks=list(train_spec.hooks) + extra_hooks,
705        saving_listeners=test.mock.ANY)
706
707  @test.mock.patch.object(time, 'sleep')
708  @test.mock.patch.object(server_lib, 'Server')
709  def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
710    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
711    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
712    mock_est.config = self._run_config
713    mock_train_spec = test.mock.Mock(
714        spec=training.TrainSpec, max_steps=123, hooks=[])
715    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
716
717    executor = training._TrainingExecutor(mock_est, mock_train_spec,
718                                          mock_eval_spec)
719    tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
720    with test.mock.patch.dict('os.environ', tf_config):
721      executor.run_master()
722      mock_server.assert_not_called()
723
724  def test_fail_with_empty_cluster_spec(self):
725    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
726    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
727    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
728
729    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
730    mock_est.config.cluster_spec = None
731    mock_est.config.master = 'grpc://...'
732    mock_est.config.task_type = 'master'
733    mock_est.config.task_id = 2
734
735    with self.assertRaisesRegexp(RuntimeError,
736                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
737      training._TrainingExecutor(
738          mock_est, mock_train_spec, mock_eval_spec).run_master()
739
740  def test_fail_with_empty_master(self):
741    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
742    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
743    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
744
745    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
746    mock_est.config.cluster_spec = server_lib.ClusterSpec(
747        {'master': ['dummy'], 'worker': ['dummy1']})
748    mock_est.config.master = ''
749    mock_est.config.task_type = 'master'
750    mock_est.config.task_id = 0
751
752    with self.assertRaisesRegexp(RuntimeError,
753                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
754      training._TrainingExecutor(
755          mock_est, mock_train_spec, mock_eval_spec).run_master()
756
757  @test.mock.patch.object(time, 'sleep')
758  @test.mock.patch.object(server_lib, 'Server')
759  def test_single_master_node_with_empty_tf_master(
760      self, mock_server, unused_mock_sleep):
761    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
762    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
763
764    mock_train_spec = test.mock.Mock(
765        spec=training.TrainSpec, max_steps=123, hooks=[])
766    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
767
768    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
769    mock_est.config.cluster_spec = server_lib.ClusterSpec(
770        {'master': ['dummy']})
771    mock_est.config.master = ''
772    mock_est.config.task_type = 'master'
773    mock_est.config.task_id = 0
774
775    executor = training._TrainingExecutor(
776        mock_est, mock_train_spec, mock_eval_spec)
777    executor.run_master()
778
779    mock_server.assert_not_called()
780    self.assertTrue(mock_est.train.called)
781
782  def test_fail_with_empty_task_type(self):
783    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
784    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
785    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
786
787    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
788    mock_est.config.cluster_spec = server_lib.ClusterSpec({'master': ['dummy']})
789    mock_est.config.master = 'grpc://...'
790    mock_est.config.task_type = ''
791    mock_est.config.task_id = 2
792
793    with self.assertRaisesRegexp(RuntimeError,
794                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
795      training._TrainingExecutor(
796          mock_est, mock_train_spec, mock_eval_spec).run_master()
797
798  def test_fail_with_none_task_id(self):
799    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
800    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
801    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
802
803    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
804    mock_est.config.cluster_spec = server_lib.ClusterSpec({'master': ['dummy']})
805    mock_est.config.master = 'grpc://...'
806    mock_est.config.task_type = 'master'
807    mock_est.config.task_id = None
808
809    with self.assertRaisesRegexp(RuntimeError,
810                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
811      training._TrainingExecutor(
812          mock_est, mock_train_spec, mock_eval_spec).run_master()
813
  @test.mock.patch.object(server_lib, 'Server')
  def test_run_master_triggers_evaluate_and_export(self, _):
    """A checkpoint save on master triggers one evaluate and one export."""

    def estimator_train(saving_listeners, *args, **kwargs):
      # There must be a saving_listener; Estimator is going to call
      # `after_save` on it, which is what drives evaluate/export here.
      del args, kwargs
      saving_listeners[0].begin()
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est = test.mock.Mock(
        spec=estimator_lib.Estimator, model_dir='path/', train=estimator_train)
    mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
    mock_est.config = self._run_config

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter)
    # Reporting max_steps as the global step ends the master's eval loop.
    eval_result = {_GLOBAL_STEP_KEY: train_spec.max_steps}
    mock_est.evaluate.return_value = eval_result

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    # Evaluate must use the EvalSpec fields plus the latest checkpoint.
    mock_est.evaluate.assert_called_with(
        name=eval_spec.name,
        input_fn=eval_spec.input_fn,
        steps=eval_spec.steps,
        checkpoint_path='checkpoint_path/',
        hooks=eval_spec.hooks)
    self.assertEqual(1, exporter.export.call_count)
    exporter.export.assert_called_with(
        estimator=mock_est,
        export_path=os.path.join('path/', 'export', exporter.name),
        checkpoint_path='checkpoint_path/',
        eval_result=eval_result,
        is_the_final_export=True)
854
  @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer')
  @test.mock.patch.object(server_lib, 'Server')
  def test_run_master_throttle_eval(self, _, mock_timer_class):
    """Eval/export run only for saves where the throttle timer triggers."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')

    mock_timer = test.mock.Mock()
    mock_timer_class.return_value = mock_timer

    def estimator_train(saving_listeners, *args, **kwargs):
      # Simulates three checkpoint saves; the throttle timer triggers for
      # the first and third save only.
      del args, kwargs
      saving_listeners[0].begin()

      # Call three times.
      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

      mock_timer.should_trigger_for_step.return_value = False
      saving_listeners[0].after_save(session=None, global_step_value=None)

      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est.train = estimator_train
    mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
    mock_est.config = self._run_config

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)

    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps //2},
        {_GLOBAL_STEP_KEY: train_spec.max_steps}
    ]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    # Only the two timer-triggered saves lead to evaluate/export calls.
    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, exporter.export.call_count)

    # Only the last export is flagged as the final one.
    is_final_export_list = [call[1]['is_the_final_export']
                            for call in exporter.export.call_args_list]
    self.assertEqual([False, True], is_final_export_list)
902
  @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer')
  @test.mock.patch.object(server_lib, 'Server')
  def test_run_master_throttle_eval_which_skips_final_ckpt(
      self, _, mock_timer_class):
    """A final checkpoint skipped by the timer still gets a final export."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')

    mock_timer = test.mock.Mock()
    mock_timer_class.return_value = mock_timer

    def estimator_train(saving_listeners, *args, **kwargs):
      # Simulates two checkpoint saves; the timer skips the second one.
      del args, kwargs
      saving_listeners[0].begin()

      # Call two times.
      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

      # The final ckpt is skipped by the timer. It will be picked up by the
      # final export check in the code.
      mock_timer.should_trigger_for_step.return_value = False
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est.train = estimator_train
    mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
    mock_est.config = self._run_config

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)

    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps //2},
        {_GLOBAL_STEP_KEY: train_spec.max_steps}
    ]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    # Both checkpoints are evaluated/exported; the skipped one is handled by
    # the final-export pass.
    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, exporter.export.call_count)

    is_final_export_list = [call[1]['is_the_final_export']
                            for call in exporter.export.call_args_list]
    self.assertEqual([False, True], is_final_export_list)
950
951
class TrainingExecutorRunEvaluatorTest(test.TestCase):
  """Tests run_evaluator of _TrainingExecutor."""

  def _set_up_mock_est_to_train_and_evaluate_once(self, mock_est,
                                                  mock_train_spec):
    """Sets global step in eval result to end the while True eval loop."""
    training_max_step = 200
    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: training_max_step}
    mock_train_spec.max_steps = training_max_step

  def test_evaluate_with_evaluate_spec(self):
    """Evaluator forwards the EvalSpec fields to Estimator.evaluate."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.latest_checkpoint.return_value = 'latest_it_is'
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='cont_eval',
        start_delay_secs=0, throttle_secs=0)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_evaluator()

    mock_est.evaluate.assert_called_with(
        name='cont_eval',
        input_fn=eval_spec.input_fn,
        steps=eval_spec.steps,
        checkpoint_path='latest_it_is',
        hooks=eval_spec.hooks)
    self.assertFalse(mock_est.train.called)

  def test_evaluate_with_train_hooks(self):
    """Hooks passed via train_hooks must not run on the evaluator."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.latest_checkpoint.return_value = 'latest_it_is'
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        hooks=[_FakeHook()],
        name='cont_eval',
        start_delay_secs=0,
        throttle_secs=0)

    # The train_hooks will not be called during eval.
    mock_hook = test.mock.Mock(spec=session_run_hook.SessionRunHook)
    executor = training._TrainingExecutor(
        mock_est, mock_train_spec, eval_spec, train_hooks=[mock_hook])
    executor.run_evaluator()

    mock_hook.begin.assert_not_called()

  def test_evaluate_multiple_times(self):
    """Evaluator loops until global step reaches train_spec.max_steps."""
    training_max_step = 200

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: training_max_step // 2},
        {_GLOBAL_STEP_KEY: training_max_step}
    ]
    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_how_many_times_export_is_called'

    mock_est.times_export_was_called = 0
    mock_est.times_final_export_was_called = 0
    def export(estimator, export_path, checkpoint_path, eval_result,
               is_the_final_export):
      del export_path, checkpoint_path, eval_result
      estimator.times_export_was_called += 1
      # The final export only happens at the very end.
      self.assertEqual(0, estimator.times_final_export_was_called)
      if is_the_final_export:
        estimator.times_final_export_was_called += 1

    exporter.export = export

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        start_delay_secs=0,
        throttle_secs=0,
        exporters=exporter)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_evaluator()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, mock_est.times_export_was_called)
    self.assertEqual(1, mock_est.times_final_export_was_called)

  def test_evaluate_listener_before_eval(self):
    """Returning False from before_eval stops the continuous-eval loop."""
    training_max_step = 200

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
    # Without early stopping, this eval will be run twice.
    mock_est.evaluate.side_effect = [{
        _GLOBAL_STEP_KEY: training_max_step // 2
    }, {
        _GLOBAL_STEP_KEY: training_max_step
    }]
    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_train_spec.max_steps = training_max_step

    class _Listener(training._ContinuousEvalListener):
      """Permits only the first evaluation."""

      def __init__(self):
        self.call_count = 0

      def before_eval(self):
        self.call_count += 1
        return  self.call_count == 1

    listener = _Listener()

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)

    training._TrainingExecutor(
        mock_est, mock_train_spec, eval_spec,
        continuous_eval_listener=listener).run_evaluator()

    # Before_eval returns False during the second time, so, evaluate will be
    # called once.
    self.assertEqual(1, mock_est.evaluate.call_count)
    self.assertEqual(2, listener.call_count)

  def test_evaluate_listener_after_eval(self):
    """Returning False from after_eval stops the continuous-eval loop."""
    training_max_step = 200

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
    # Without early stopping, this eval will be run twice.
    expected_eval_metrics = [{
        _GLOBAL_STEP_KEY: training_max_step // 2
    }, {
        _GLOBAL_STEP_KEY: training_max_step
    }]
    mock_est.evaluate.side_effect = expected_eval_metrics
    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    class _Listener(training._ContinuousEvalListener):
      """Records the eval result and asks the loop to stop."""

      def __init__(self):
        self.call_count = 0

      def after_eval(self, eval_result):
        self.call_count += 1
        self.eval_result = eval_result
        return False

    listener = _Listener()

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)

    training._TrainingExecutor(
        mock_est, mock_train_spec, eval_spec,
        continuous_eval_listener=listener).run_evaluator()

    # after_eval returns False during the first time, so, evaluate will be
    # called once.
    self.assertEqual(1, mock_est.evaluate.call_count)
    self.assertEqual(1, listener.call_count)
    self.assertAllEqual(expected_eval_metrics[0], listener.eval_result.metrics)
    self.assertEqual('path_1', listener.eval_result.checkpoint_path)

  def test_final_export_is_true_in_the_end(self):
    """Only the last export call carries is_the_final_export=True."""
    training_max_step = 200

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: training_max_step // 2},
        {_GLOBAL_STEP_KEY: training_max_step}
    ]
    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    mock_est.times_export_fn_was_called = 0
    mock_est.times_the_final_export_was_true = 0
    def export(estimator, export_path, checkpoint_path, eval_result,
               is_the_final_export):
      del export_path, checkpoint_path, eval_result
      estimator.times_export_fn_was_called += 1
      if is_the_final_export:
        estimator.times_the_final_export_was_true += 1

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_how_many_times_export_is_called'
    exporter.export = export

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        start_delay_secs=0,
        throttle_secs=0,
        exporters=exporter)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_evaluator()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, mock_est.times_export_fn_was_called)
    self.assertEqual(1, mock_est.times_the_final_export_was_true)

  def test_skip_evaluation_due_to_ckpt(self):
    """Evaluations are skipped for missing/empty/unchanged checkpoints."""
    training_max_step = 200
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: training_max_step // 2},
        {_GLOBAL_STEP_KEY: training_max_step}
    ]
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    # First two items are invalid, next two items are same.
    mock_est.latest_checkpoint.side_effect = [
        None, '', 'same', 'same', 'path_2'
    ]

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    with test.mock.patch.object(logging, 'warning') as mock_log:
      executor.run_evaluator()

    # Three checkpoint paths are skipped: two invalid, one unchanged.
    self.assertEqual(5, mock_est.latest_checkpoint.call_count)
    self.assertEqual(2, mock_est.evaluate.call_count)

    # Two warning logs are expected (last warning time is reset after a
    # successful evaluation)
    self.assertEqual(2, mock_log.call_count)

  def test_continuous_eval_listener_eval_result(self):
    """after_eval receives an _EvalResult describing each loop iteration."""
    training_max_step = 200
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    expected_eval_metrics = [{
        _GLOBAL_STEP_KEY: training_max_step // 2
    }, {
        _GLOBAL_STEP_KEY: training_max_step
    }]
    mock_est.evaluate.side_effect = expected_eval_metrics
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    class _Listener(training._ContinuousEvalListener):
      """Collects every eval result passed to after_eval."""

      def __init__(self):
        self.eval_results = []

      def after_eval(self, eval_result):
        self.eval_results.append(eval_result)
        return True

    continuous_eval_listener = _Listener()

    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    # First two items are invalid, next two items are same.
    mock_est.latest_checkpoint.side_effect = [
        None, '', 'same', 'same', 'path_2'
    ]
    expected_eval_results = [
        training._EvalResult(training._EvalStatus.MISSING_CHECKPOINT),
        training._EvalResult(training._EvalStatus.MISSING_CHECKPOINT),
        training._EvalResult(
            training._EvalStatus.EVALUATED,
            metrics=expected_eval_metrics[0],
            checkpoint_path='same'),
        training._EvalResult(training._EvalStatus.NO_NEW_CHECKPOINT),
        training._EvalResult(
            training._EvalStatus.EVALUATED,
            metrics=expected_eval_metrics[1],
            checkpoint_path='path_2'),
    ]

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)

    executor = training._TrainingExecutor(
        mock_est,
        mock_train_spec,
        eval_spec,
        continuous_eval_listener=continuous_eval_listener)
    executor.run_evaluator()

    # Three checkpoint paths are skipped: two invalid, one unchanged.
    self.assertEqual(5, mock_est.latest_checkpoint.call_count)
    self.assertEqual(2, mock_est.evaluate.call_count)

    self.assertEqual(5, len(continuous_eval_listener.eval_results))
    for i, result in enumerate(continuous_eval_listener.eval_results):
      self.assertEqual(expected_eval_results[i].status, result.status)
      self.assertAllEqual(expected_eval_results[i].metrics, result.metrics)
      self.assertEqual(expected_eval_results[i].checkpoint_path,
                       result.checkpoint_path)

  def test_sleep_start_delay_secs(self):
    """The evaluator sleeps start_delay_secs before evaluating."""
    training_max_step = 200
    start_delay_secs = 123

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: training_max_step}
    mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='cont_eval',
        start_delay_secs=start_delay_secs, throttle_secs=0)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    with test.mock.patch.object(time, 'sleep') as mock_sleep:
      executor.run_evaluator()
      mock_sleep.assert_called_with(start_delay_secs)
      self.assertTrue(mock_est.evaluate.called)

  @test.mock.patch.object(time, 'time')
  @test.mock.patch.object(time, 'sleep')
  def test_throttle_secs(self, mock_sleep, mock_time):
    """The evaluator sleeps for throttle_secs minus the eval duration."""
    throttle_secs = 123
    operation_secs = 12

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=throttle_secs)

    mock_time.side_effect = [921, 921 + operation_secs]

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    # Disable logging as it calls time.time also.
    with test.mock.patch.object(logging, 'info'):
      executor.run_evaluator()
    mock_sleep.assert_called_with(throttle_secs - operation_secs)
    self.assertTrue(mock_est.evaluate.called)

  def test_that_export_is_called(self):
    """The exporter is invoked with the estimator under evaluation."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    def export(estimator, *args, **kwargs):
      del args, kwargs
      estimator.export_was_called = True

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'
    exporter.export = export

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        start_delay_secs=0,
        throttle_secs=0,
        exporters=exporter)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_evaluator()

    # Verify that export was called on the right estimator.
    self.assertTrue(mock_est.export_was_called)

  def test_errors_out_if_evaluate_returns_empty_dict(self):
    """An empty dict from Estimator.evaluate raises ValueError."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1),
                                  start_delay_secs=0, throttle_secs=0)
    mock_est.evaluate.return_value = {}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR):
      executor.run_evaluator()

  def test_errors_out_if_evaluate_returns_non_dict(self):
    """A non-dict result from Estimator.evaluate raises TypeError."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1),
                                  start_delay_secs=0, throttle_secs=0)
    mock_est.evaluate.return_value = 123

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
      executor.run_evaluator()

  def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
    """An eval result missing the global step key raises ValueError."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1),
                                  start_delay_secs=0, throttle_secs=0)
    mock_est.evaluate.return_value = {'loss': 123}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError,
                                 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):
      executor.run_evaluator()
1367
1368
class TrainingExecutorRunPsTest(test.TestCase):
  """Tests run_ps of _TrainingExecutor."""

  @test.mock.patch.object(server_lib, 'Server')
  def test_std_server(self, mock_server):
    """run_ps builds a standard TF server, starts it, and blocks on join."""
    mock_server_instance = test.mock.Mock()
    mock_server.return_value = mock_server_instance

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = _create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    executor.run_ps()

    # The server must be constructed from the run config's cluster/task info
    # with start=False (it is started explicitly afterwards).
    mock_server.assert_called_with(
        mock_est.config.cluster_spec,
        job_name=mock_est.config.task_type,
        task_index=mock_est.config.task_id,
        config=test.mock.ANY,
        start=False)

    self.assertTrue(mock_server_instance.start.called)
    self.assertTrue(mock_server_instance.join.called)

  def test_fail_with_empty_cluster_spec(self):
    """run_ps raises when the RunConfig has no cluster spec."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = None
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'ps'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(mock_est, mock_train_spec,
                                 mock_eval_spec).run_ps()

  def test_fail_with_empty_master(self):
    """run_ps raises when the master address is empty."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'ps': ['dummy']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'ps'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(mock_est, mock_train_spec,
                                 mock_eval_spec).run_ps()

  def test_fail_with_empty_task_type(self):
    """run_ps raises when the task type is empty."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'ps': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = ''
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(mock_est, mock_train_spec,
                                 mock_eval_spec).run_ps()

  def test_fail_with_none_task_id(self):
    """run_ps raises when the task id is None."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'ps': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'ps'
    mock_est.config.task_id = None

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(mock_est, mock_train_spec,
                                 mock_eval_spec).run_ps()
1459
1460
class StopAtSecsHookTest(test.TestCase):
  """Tests for training._StopAtSecsHook."""

  @test.mock.patch.object(time, 'time')
  def test_stops_after_time(self, mock_time):
    """Exercises the hook against a mocked clock across several run calls."""
    mock_time.return_value = 1484695987.209386
    hook = training._StopAtSecsHook(1000)
    with ops.Graph().as_default():
      no_op = control_flow_ops.no_op()
      # Simulate some delay between hook creation and the start of training.
      mock_time.return_value += 250
      with monitored_session.MonitoredSession(hooks=[hook]) as sess:
        self.assertFalse(sess.should_stop())
        # Advance the fake clock before each step and check whether the hook
        # has requested a stop afterwards.
        for secs_elapsed, stop_expected in [(0, False), (500, False),
                                            (400, False), (200, True)]:
          mock_time.return_value += secs_elapsed
          sess.run(no_op)
          self.assertEqual(stop_expected, sess.should_stop())
1485
1486
class TrainingExecutorRunLocalTest(test.TestCase):
  """Tests run_local of _TrainingExecutor."""

  def unique_checkpoint_every_time_fn(self):
    # Stands in for Estimator.latest_checkpoint. Returning a fresh path on
    # every call makes each evaluation see a "new" checkpoint, so run_local
    # keeps looping until the reported global step reaches max_steps.
    return 'checkpoint_path_%s/' % random.random()

  def test_send_stop_at_secs_to_train(self):
    """run_local appends a _StopAtSecsHook configured with throttle_secs."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_local()

    # The stop hook is expected to be the last hook passed to train().
    stop_hook = mock_est.train.call_args[1]['hooks'][-1]
    self.assertIsInstance(stop_hook, training._StopAtSecsHook)
    self.assertEqual(eval_spec.throttle_secs, stop_hook._stop_after_secs)

  def test_runs_in_a_loop_until_max_steps(self):
    """Train/evaluate/export cycle repeats until global step hits max_steps."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn

    mock_est.times_export_was_called = 0
    mock_est.times_final_export_was_called = 0
    def export(estimator, export_path, checkpoint_path, eval_result,
               is_the_final_export):
      del export_path, checkpoint_path, eval_result
      estimator.times_export_was_called += 1
      # The final export must only happen at the very end, so no earlier
      # export call may have seen is_the_final_export=True.
      self.assertEqual(0, estimator.times_final_export_was_called)
      if is_the_final_export:
        estimator.times_final_export_was_called += 1

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_how_many_times_export_is_called'
    exporter.export = export

    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        hooks=[_FakeHook()],
        throttle_secs=100,
        exporters=exporter)
    # Evaluate should be called 3 times: the global step only reaches
    # max_steps on the third evaluation, which ends the loop.
    mock_est.evaluate.side_effect = [{
        _GLOBAL_STEP_KEY: train_spec.max_steps - 100
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps - 50
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps
    }]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_local()

    self.assertEqual(3, mock_est.train.call_count)
    self.assertEqual(3, mock_est.evaluate.call_count)
    self.assertEqual(3, mock_est.times_export_was_called)
    self.assertEqual(1, mock_est.times_final_export_was_called)

  def test_handles_no_new_checkpoint_found(self):
    """run_local raises when training produces no fresh checkpoint."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    # latest_checkpoint always returns the same path, so the executor sees a
    # stale checkpoint after the first cycle and must error out.
    mock_est.latest_checkpoint.return_value = (
        'no_new_checkpoints_after_the_first_train_step')
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
    # It was going to be called 3 times, but the stale checkpoint stops it.
    mock_est.evaluate.side_effect = [{
        _GLOBAL_STEP_KEY: train_spec.max_steps - 100
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps - 50
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps
    }]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG):
      executor.run_local()

  def test_final_export_is_true_in_the_end(self):
    """is_the_final_export is passed as True exactly once, on the last cycle."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn

    mock_est.times_export_fn_was_called = 0
    mock_est.times_the_final_export_was_true = 0
    def export(estimator, export_path, checkpoint_path, eval_result,
               is_the_final_export):
      del export_path, checkpoint_path, eval_result
      estimator.times_export_fn_was_called += 1
      if is_the_final_export:
        estimator.times_the_final_export_was_true += 1

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_how_many_times_export_is_called'
    exporter.export = export

    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        hooks=[_FakeHook()],
        throttle_secs=100,
        exporters=exporter)
    # Evaluate should be called 3 times; only the last reaches max_steps.
    mock_est.evaluate.side_effect = [{
        _GLOBAL_STEP_KEY: train_spec.max_steps - 100
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps - 50
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps
    }]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_local()

    self.assertEqual(3, mock_est.train.call_count)
    self.assertEqual(3, mock_est.evaluate.call_count)
    self.assertEqual(3, mock_est.times_export_fn_was_called)
    self.assertEqual(1, mock_est.times_the_final_export_was_true)

  def test_train_and_evaluate_args(self):
    """run_local forwards the spec fields to Estimator.train/evaluate."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval')
    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_local()

    mock_est.evaluate.assert_called_with(
        name=eval_spec.name,
        input_fn=eval_spec.input_fn,
        steps=eval_spec.steps,
        checkpoint_path='checkpoint_path/',
        hooks=eval_spec.hooks)

    # train() receives the spec hooks plus an appended stop hook (checked
    # separately in test_send_stop_at_secs_to_train), hence hooks[:-1].
    train_args = mock_est.train.call_args[1]
    self.assertEqual(list(train_spec.hooks), list(train_args['hooks'][:-1]))
    self.assertEqual(train_spec.input_fn, train_args['input_fn'])
    self.assertEqual(train_spec.max_steps, train_args['max_steps'])

  def test_train_hooks(self):
    """Extra hooks passed to the executor are appended to train_spec.hooks."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(input_fn=lambda: 1, steps=2)
    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
    extra_hooks = [_FakeHook()]

    executor = training._TrainingExecutor(
        mock_est, train_spec, eval_spec, train_hooks=extra_hooks)
    executor.run_local()

    # Filter out the internally-added stop hook before comparing.
    train_args = mock_est.train.call_args[1]
    self.assertEqual(
        list(train_spec.hooks) + extra_hooks, [
            h for h in train_args['hooks']
            if not isinstance(h, training._StopAtSecsHook)
        ])

  def test_errors_out_if_throttle_secs_is_zero(self):
    """run_local requires a positive throttle_secs."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=0)

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError, 'throttle_secs'):
      executor.run_local()

  def test_that_export_is_called_with_run_local(self):
    """run_local invokes the configured exporter."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = 200
    mock_est.evaluate.return_value = {
        _GLOBAL_STEP_KEY: mock_train_spec.max_steps
    }
    # _validate_hooks would have made sure that train_spec.hooks is [], when
    # None were passed.
    mock_train_spec.hooks = []

    def export(estimator, *args, **kwargs):
      del args, kwargs
      estimator.export_was_called = True

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'
    exporter.export = export

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        start_delay_secs=0,
        throttle_secs=213,
        exporters=exporter)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_local()

    self.assertTrue(mock_est.export_was_called)

  def test_errors_out_if_evaluate_returns_empty_dict(self):
    """An empty eval-result dict is rejected with ValueError."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
    mock_est.evaluate.return_value = {}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR):
      executor.run_local()

  def test_errors_out_if_evaluate_returns_non_dict(self):
    """A non-dict eval result is rejected with TypeError."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
    mock_est.evaluate.return_value = 123

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
      executor.run_local()

  def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
    """An eval result lacking the global-step key is rejected."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
    mock_est.evaluate.return_value = {'loss': 123}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError,
                                 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):
      executor.run_local()
1728
1729
class TrainAndEvaluateRunTest(test.TestCase):
  """Tests that _TrainingExecutor.run dispatches to the right task runner."""

  def _test_run_task_and_executor(self, run_config):
    """Builds an executor whose run_<task> methods record which one ran.

    Args:
      run_config: The RunConfig whose task type drives the dispatch.

    Returns:
      A _TrainingExecutor with every run_<task> method replaced by a recorder
      that writes into `executor.call_task`.
    """
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = run_config
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    executor.call_task = {}

    def task_fn(name):

      def _fn():
        executor.call_task[name] = 1

      return _fn

    executor.run_chief = task_fn('chief')
    executor.run_master = task_fn('master')
    executor.run_ps = task_fn('ps')
    executor.run_evaluator = task_fn('evaluator')
    executor.run_worker = task_fn('worker')
    executor.run_local = task_fn('local')
    return executor

  def test_run_chief(self):
    """A chief task config dispatches to run_chief."""
    executor = self._test_run_task_and_executor(
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF))
    executor.run()
    self.assertEqual(1, executor.call_task['chief'])

  def test_run_worker(self):
    """A worker task config dispatches to run_worker."""
    executor = self._test_run_task_and_executor(
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER))
    executor.run()
    self.assertEqual(1, executor.call_task['worker'])

  def test_run_ps(self):
    """A ps task config dispatches to run_ps."""
    executor = self._test_run_task_and_executor(
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS))
    executor.run()
    self.assertEqual(1, executor.call_task['ps'])

  def test_run_evaluator(self):
    """An evaluator task config dispatches to run_evaluator."""
    executor = self._test_run_task_and_executor(
        run_config=_create_run_config_with_cluster_spec(
            _TF_CONFIG_FOR_EVALUATOR))
    executor.run()
    self.assertEqual(1, executor.call_task['evaluator'])

  def test_run_local(self):
    """A default (non-distributed) RunConfig dispatches to run_local."""
    executor = self._test_run_task_and_executor(
        run_config=run_config_lib.RunConfig())
    executor.run()
    self.assertEqual(1, executor.call_task['local'])

  def test_invalid_local_task(self):
    """Task type 'local' inside a cluster spec is rejected."""
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
            'local': ['hos1:1'],
        },
        'task': {
            'type': 'local',  # invalid task type.
            'index': 0
        }
    }
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = _create_run_config_with_cluster_spec(tf_config)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_LOCAL_TASK_WITH_CLUSTER):
      executor.run()

  def test_unsupported_task_due_to_missing_run_task(self):
    """A task type with no matching run_<task> method is rejected."""
    unsupported_task = 'alloc'
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
            unsupported_task: ['hos1:1'],
        },
        'task': {
            'type': unsupported_task,
            'index': 0
        }
    }
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = _create_run_config_with_cluster_spec(tf_config)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN):
      executor.run()

  def test_unsupported_task_due_to_not_callable(self):
    """A run_<task> attribute that is not callable is rejected."""
    unsupported_task = 'alloc'
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
            unsupported_task: ['hos1:1'],
        },
        'task': {
            'type': unsupported_task,
            'index': 0
        }
    }
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = _create_run_config_with_cluster_spec(tf_config)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    executor.run_alloc = 123  # not callable
    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN):
      executor.run()

  def test_invalid_task_type(self):
    """An empty task_type with a cluster spec raises ValueError."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    # Fix: the original assigned mock_est.config = test.mock.Mock() twice;
    # the first assignment was dead code and has been removed.
    mock_est.config = test.mock.Mock()
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'1': ['dummy']})
    mock_est.config.task_type = ''

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE):
      executor.run()
1869
1870
class TrainAndEvaluateIntegrationTest(test.TestCase):
  """End-to-end train_and_evaluate flow with a real DNNClassifier."""

  def setUp(self):
    # Fresh model dir per test; removed in tearDown.
    self._model_dir = tempfile.mkdtemp()

  def tearDown(self):
    if self._model_dir:
      shutil.rmtree(self._model_dir)

  def _as_label(self, data_in_float):
    """Rounds float data to the nearest integer class labels (int64)."""
    return np.rint(data_in_float).astype(np.int64)

  def _get_exporter(self, name, fc):
    """Builds a LatestExporter parsing tf.Examples per feature columns fc."""
    feature_spec = feature_column.make_parse_example_spec(fc)
    serving_input_receiver_fn = (
        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))
    return exporter_lib.LatestExporter(
        name, serving_input_receiver_fn=serving_input_receiver_fn)

  def _extract_loss_and_global_step(self, event_folder):
    """Returns the loss and global step in last event.

    Args:
      event_folder: Directory containing `events*` summary files.

    Returns:
      A (loss, global_step) tuple taken from the summary with the highest
      step in the newest event file; (None, None) if no loss summary exists.
    """
    # Fix: glob.glob returns paths in arbitrary filesystem order, so
    # event_paths[-1] was nondeterministic. Event filenames embed a
    # timestamp, so sorting lexicographically picks the newest file.
    event_paths = sorted(glob.glob(os.path.join(event_folder, 'events*')))

    loss = None
    global_step_count = None

    for e in summary_iterator.summary_iterator(event_paths[-1]):
      current_loss = None
      for v in e.summary.value:
        if v.tag == 'loss':
          current_loss = v.simple_value

      # If loss is not found, global step is meaningless.
      if current_loss is None:
        continue

      # Keep the loss from the event with the largest global step.
      current_global_step = e.step
      if global_step_count is None or current_global_step > global_step_count:
        global_step_count = current_global_step
        loss = current_loss

    return (loss, global_step_count)

  def test_complete_flow_with_non_distributed_configuration(self):
    """Trains, evaluates, exports and predicts with a local RunConfig."""
    n_classes = 3
    input_dimension = 2
    batch_size = 10

    eval_name = 'foo'
    exporter_name = 'saved_model_exporter'

    # max_steps should be larger than save_summary_steps
    max_steps = 10
    save_summary_steps = 2

    data = np.linspace(
        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
    x_data = data.reshape(batch_size, input_dimension)
    y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))

    # learn y = x
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        y=y_data,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)

    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        y=y_data,
        batch_size=batch_size,
        num_epochs=1,
        shuffle=False)

    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        batch_size=batch_size,
        shuffle=False)

    feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))]

    est = dnn.DNNClassifier(
        hidden_units=(2, 2),
        feature_columns=feature_columns,
        n_classes=n_classes,
        config=run_config_lib.RunConfig(save_summary_steps=save_summary_steps),
        model_dir=self._model_dir)

    train_spec = training.TrainSpec(input_fn=train_input_fn,
                                    max_steps=max_steps)

    eval_spec = training.EvalSpec(
        name=eval_name, input_fn=eval_input_fn, steps=None,
        exporters=self._get_exporter(exporter_name, feature_columns),
        throttle_secs=2)

    training.train_and_evaluate(est, train_spec, eval_spec)

    # Make sure nothing is stuck in limbo.
    writer_cache.FileWriterCache.clear()

    # Examine the training events. Use a range to check global step to avoid
    # flakyness due to global step race condition.
    training_loss, training_global_step = self._extract_loss_and_global_step(
        est.model_dir)
    self.assertIsNotNone(training_loss)
    self.assertTrue(
        max_steps - save_summary_steps < training_global_step <= max_steps)

    # Examine the eval events. The global step should be accurate.
    eval_loss, eval_global_step = self._extract_loss_and_global_step(
        event_folder=os.path.join(est.model_dir, 'eval_' + eval_name))
    self.assertIsNotNone(eval_loss)
    self.assertEqual(max_steps, eval_global_step)

    # Examine the export folder.
    export_dir = os.path.join(os.path.join(est.model_dir, 'export'),
                              exporter_name)
    self.assertTrue(gfile.Exists(export_dir))

    # Examine the ckpt for predict.
    predicted_proba = np.array([
        x[prediction_keys.PredictionKeys.PROBABILITIES]
        for x in est.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
1999
2000
# Run all test cases in this module via the TensorFlow test runner.
if __name__ == '__main__':
  test.main()
2003