1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Tests for training.py.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import glob 23import json 24import os 25import random 26import shutil 27import tempfile 28import time 29 30import numpy as np 31 32from tensorflow.python.estimator import estimator as estimator_lib 33from tensorflow.python.estimator import exporter as exporter_lib 34from tensorflow.python.estimator import run_config as run_config_lib 35from tensorflow.python.estimator import training 36from tensorflow.python.estimator.canned import dnn 37from tensorflow.python.estimator.canned import prediction_keys 38from tensorflow.python.estimator.export import export as export_lib 39from tensorflow.python.estimator.inputs import numpy_io 40from tensorflow.python.feature_column import feature_column 41from tensorflow.python.framework import ops 42from tensorflow.python.ops import control_flow_ops 43from tensorflow.python.platform import gfile 44from tensorflow.python.platform import test 45from tensorflow.python.platform import tf_logging as logging 46from tensorflow.python.summary import summary_iterator 47from tensorflow.python.summary.writer import writer_cache 48from tensorflow.python.training import basic_session_run_hooks 49from 
tensorflow.python.training import monitored_session 50from tensorflow.python.training import server_lib 51from tensorflow.python.training import session_run_hook 52from tensorflow.python.util import compat 53 54_DEFAULT_EVAL_STEPS = 100 55_DEFAULT_EVAL_DELAY_SECS = 120 56_DEFAULT_EVAL_THROTTLE_SECS = 600 57_DELAY_SECS_PER_WORKER = 5 58_GLOBAL_STEP_KEY = ops.GraphKeys.GLOBAL_STEP 59_INVALID_INPUT_FN_MSG = '`input_fn` must be callable' 60_INVALID_HOOK_MSG = 'All hooks must be `SessionRunHook` instances' 61_INVALID_MAX_STEPS_MSG = 'Must specify max_steps > 0' 62_INVALID_STEPS_MSG = 'Must specify steps > 0' 63_INVALID_NAME_MSG = '`name` must be string' 64_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify start_delay_secs >= 0' 65_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0' 66_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`' 67_STALE_CHECKPOINT_MSG = 'There was no new checkpoint after the training.' 68_INVALID_EXPORTER_MSG = '`exporters` must be an Exporter' 69_INVALID_EXPORTER_NAME_TYPE_MSG = 'An Exporter must have a string name' 70_DUPLICATE_EXPORTER_NAMES_MSG = '`exporters` must have unique names.' 71_NONE_EXPORTER_NAME_MSG = ( 72 'An Exporter cannot have a name that is `None` or empty.') 73_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`' 74_INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`' 75_INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`' 76_INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG' 77_INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`' 78_INVALID_TASK_TYPE = '`estimator.config` must have task_type set.' 79# The message should NOT have 'local' word as part of it. As (?!word) is looking 80# ahead, so, the $ (ending) check is required; otherwise, it will match 81# partially and return successuful. 82_INVALID_TASK_TO_RUN = ( 83 'Task type .* is not supported. 
Supported task types are ((?!local).)*$') 84_INVALID_EMPTY_EVAL_RESULT_ERR = ( 85 'Internal error: `Estimator.evaluate` should never return empty metrics') 86_INVALID_EVAL_RESULT_TYPE_ERR = '`Estimator.evaluate` should return dict.' 87_MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR = ( 88 'Internal error: `Estimator.evaluate` result should have `global_step`') 89_INVALID_EVAL_TASK_ID_ERR = ( 90 'there can only be one `evaluator` task .*with task id 0') 91 92_TF_CONFIG_FOR_CHIEF = { 93 'cluster': { 94 run_config_lib.TaskType.CHIEF: ['host0:0'], 95 run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], 96 run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4'] 97 }, 98 'task': { 99 'type': run_config_lib.TaskType.CHIEF, 100 'index': 0 101 } 102} 103 104_TF_CONFIG_FOR_MASTER = { 105 'cluster': { 106 run_config_lib.TaskType.MASTER: ['host0:0'], 107 run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], 108 run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4'] 109 }, 110 'task': { 111 'type': run_config_lib.TaskType.MASTER, 112 'index': 0 113 } 114} 115 116_TF_CONFIG_FOR_WORKER = { 117 'cluster': { 118 run_config_lib.TaskType.CHIEF: ['host0:0'], 119 run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], 120 run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4'] 121 }, 122 'task': { 123 'type': run_config_lib.TaskType.WORKER, 124 'index': 1 125 } 126} 127 128_TF_CONFIG_FOR_PS = { 129 'cluster': { 130 run_config_lib.TaskType.CHIEF: ['host0:0'], 131 run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], 132 run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4'] 133 }, 134 'task': { 135 'type': run_config_lib.TaskType.PS, 136 'index': 1 137 } 138} 139 140_TF_CONFIG_FOR_EVALUATOR = { 141 'cluster': { 142 run_config_lib.TaskType.CHIEF: ['host0:0'], 143 run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], 144 run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4'] 145 }, 146 'task': { 147 'type': run_config_lib.TaskType.EVALUATOR, 148 'index': 0 149 } 150} 151 152_TF_CONFIG_FOR_GOOGLE = 
class _FakeHook(session_run_hook.SessionRunHook):
  """Fake implementation of `SessionRunHook`."""


class _InvalidHook(object):
  """Invalid hook (not a subclass of `SessionRunHook`)."""


def _create_exporter(name):
  """Returns a no-op `Exporter` with the given `name` for testing."""

  class FakeExporter(exporter_lib.Exporter):

    def __init__(self, name):
      self._name = name

    @property
    def name(self):
      return self._name

    def export(self, *args, **kwargs):
      del args, kwargs

  return FakeExporter(name=name)


def _create_run_config_with_cluster_spec(tf_config):
  """Builds a `RunConfig` as if `tf_config` were set in the environment."""
  with test.mock.patch.dict('os.environ', {'TF_CONFIG': json.dumps(tf_config)}):
    return run_config_lib.RunConfig()


class TrainSpecTest(test.TestCase):
  """Tests TrainSpec."""

  def testRequiredArgumentsSet(self):
    """Tests that no errors are raised when all required arguments are set."""
    spec = training.TrainSpec(input_fn=lambda: 1)
    self.assertEqual(1, spec.input_fn())
    self.assertIsNone(spec.max_steps)
    self.assertEqual(0, len(spec.hooks))

  def testAllArgumentsSet(self):
    """Tests that no errors are raised when all arguments are set."""
    hooks = [_FakeHook()]
    spec = training.TrainSpec(input_fn=lambda: 1, max_steps=2, hooks=hooks)
    self.assertEqual(1, spec.input_fn())
    self.assertEqual(2, spec.max_steps)
    self.assertEqual(tuple(hooks), spec.hooks)

  def testInvalidInputFn(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):
      training.TrainSpec(input_fn='invalid')

  def testInvalidMaxStep(self):
    with self.assertRaisesRegexp(ValueError, _INVALID_MAX_STEPS_MSG):
      training.TrainSpec(input_fn=lambda: 1, max_steps=0)

  def testInvalidHook(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
      training.TrainSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])


class EvalSpecTest(test.TestCase):
  """Tests EvalSpec."""

  def testRequiredArgumentsSet(self):
    """Tests that no errors are raised when all required arguments are set."""
    spec = training.EvalSpec(input_fn=lambda: 1)
    self.assertEqual(1, spec.input_fn())
    self.assertEqual(_DEFAULT_EVAL_STEPS, spec.steps)
    self.assertIsNone(spec.name)
    self.assertEqual(0, len(spec.hooks))
    self.assertEqual(0, len(spec.exporters))
    self.assertEqual(_DEFAULT_EVAL_DELAY_SECS, spec.start_delay_secs)
    self.assertEqual(_DEFAULT_EVAL_THROTTLE_SECS, spec.throttle_secs)

  def testAllArgumentsSet(self):
    """Tests that no errors are raised when all arguments are set."""
    hooks = [_FakeHook()]
    exporter = _create_exporter('a')

    spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        name='name',
        hooks=hooks,
        exporters=exporter,
        start_delay_secs=3,
        throttle_secs=4)
    self.assertEqual(1, spec.input_fn())
    self.assertEqual(2, spec.steps)
    self.assertEqual('name', spec.name)
    self.assertEqual(tuple(hooks), spec.hooks)
    self.assertEqual((exporter,), spec.exporters)
    self.assertEqual(3, spec.start_delay_secs)
    self.assertEqual(4, spec.throttle_secs)

  def testListOfExporters(self):
    """Tests that no errors are raised with multiple exporters."""
    exporters = [_create_exporter('a'), _create_exporter('b')]

    spec = training.EvalSpec(input_fn=lambda: 1, exporters=exporters)
    self.assertEqual(1, spec.input_fn())
    self.assertEqual(tuple(exporters), spec.exporters)

  def testInvalidInputFn(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):
      training.EvalSpec(input_fn='invalid')

  def testInvalidMaxStep(self):
    with self.assertRaisesRegexp(ValueError, _INVALID_STEPS_MSG):
      training.EvalSpec(input_fn=lambda: 1, steps=0)

  def testInvalidName(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_NAME_MSG):
      training.EvalSpec(input_fn=lambda: 1, name=123)

  def testInvalidHook(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
      training.EvalSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])

  def testInvalidDelaySecs(self):
    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_DELAY_SECS_MSG):
      training.EvalSpec(input_fn=lambda: 1, start_delay_secs=-1)

  def testInvalidThrottleSecs(self):
    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_THROTTLE_SECS_MSG):
      training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1)

  def testInvalidTypeOfListOfExporters(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_EXPORTER_MSG):
      training.EvalSpec(
          input_fn=lambda: 1, exporters=[_create_exporter('a'),
                                         _FakeHook()])

  def testInvalidTypeOfIndividualExporter(self):
    with self.assertRaisesRegexp(TypeError, _INVALID_EXPORTER_MSG):
      training.EvalSpec(input_fn=lambda: 1, exporters=_FakeHook())

  def testInvalidTypeOfExporterName(self):
    with self.assertRaisesRegexp(ValueError, _INVALID_EXPORTER_NAME_TYPE_MSG):
      training.EvalSpec(input_fn=lambda: 1,
                        exporters=_create_exporter(name=123))

  def testMultipleExportersWithTheSameName(self):
    with self.assertRaisesRegexp(ValueError, _DUPLICATE_EXPORTER_NAMES_MSG):
      training.EvalSpec(
          input_fn=lambda: 1,
          exporters=[_create_exporter('a'), _create_exporter('a')])

  def testMultipleExportersAndOneWithoutAName(self):
    with self.assertRaisesRegexp(ValueError, _NONE_EXPORTER_NAME_MSG):
      training.EvalSpec(
          input_fn=lambda: 1,
          exporters=[_create_exporter('a'),
                     _create_exporter(None)])

  def testSingleExporterWithoutAName(self):
    with self.assertRaisesRegexp(ValueError, _NONE_EXPORTER_NAME_MSG):
      training.EvalSpec(input_fn=lambda: 1, exporters=_create_exporter(None))
class TrainAndEvaluateTest(test.TestCase):
  """Tests the public `train_and_evaluate` entry point."""

  def test_run_task(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor:
      mock_executor_instance = test.mock.Mock()
      mock_executor.return_value = mock_executor_instance
      training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)
      mock_executor.assert_called_with(estimator=mock_est,
                                       train_spec=mock_train_spec,
                                       eval_spec=mock_eval_spec)
      self.assertTrue(mock_executor_instance.run.called)

  def test_error_out_if_evaluator_task_id_is_non_zero(self):
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
        },
        'task': {
            'type': run_config_lib.TaskType.EVALUATOR,
            'index': 1
        }
    }

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = _create_run_config_with_cluster_spec(tf_config)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR):
      training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)

  def test_invalid_estimator(self):
    invalid_estimator = object()
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
      training.train_and_evaluate(invalid_estimator, mock_train_spec,
                                  mock_eval_spec)


class TrainingExecutorConstructorTest(test.TestCase):
  """Tests constructor of _TrainingExecutor."""

  def testRequiredArgumentsSet(self):
    estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1)

    executor = training._TrainingExecutor(estimator, train_spec, eval_spec)
    self.assertEqual(estimator, executor.estimator)

  def test_invalid_estimator(self):
    invalid_estimator = object()
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1)

    with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
      training._TrainingExecutor(invalid_estimator, train_spec, eval_spec)

  def test_invalid_train_spec(self):
    estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    invalid_train_spec = object()
    eval_spec = training.EvalSpec(input_fn=lambda: 1)

    with self.assertRaisesRegexp(TypeError, _INVALID_TRAIN_SPEC_MSG):
      training._TrainingExecutor(estimator, invalid_train_spec, eval_spec)

  def test_invalid_eval_spec(self):
    estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    invalid_eval_spec = object()

    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
      training._TrainingExecutor(estimator, train_spec, invalid_eval_spec)

  def test_invalid_train_hooks(self):
    estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1)
    invalid_train_hooks = [object()]

    with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):
      training._TrainingExecutor(
          estimator, train_spec, eval_spec, train_hooks=invalid_train_hooks)

  def test_invalid_continuous_eval_listener(self):
    estimator = estimator_lib.Estimator(model_fn=lambda features: features)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(input_fn=lambda: 1)
    invalid_continuous_eval_listener = object()

    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_LISTENER_MSG):
      training._TrainingExecutor(
          estimator,
          train_spec,
          eval_spec,
          continuous_eval_listener=invalid_continuous_eval_listener)
class _TrainingExecutorTrainingTest(object):
  """Tests training of _TrainingExecutor.

  This is a mixin (not itself a TestCase): concrete subclasses supply a
  `run_config` for a particular task type, and `_run_task` dispatches to the
  matching `run_<task_type>` method on the executor.
  """

  def __init__(self, run_config):
    self._run_config = run_config

  def _run_task(self, executor):
    # We should not call executor.run as the test here is intended to test
    # run_foo explicitly (foo is the task type).
    return getattr(executor, 'run_' + self._run_config.task_type)()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    mock_server_instance = mock_server.return_value

    executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)
    self._run_task(executor)

    mock_server.assert_called_with(
        mock_est.config.cluster_spec,
        job_name=mock_est.config.task_type,
        task_index=mock_est.config.task_id,
        config=test.mock.ANY,
        start=False)

    self.assertTrue(mock_server_instance.start.called)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=test.mock.ANY)
    mock_est.evaluate.assert_not_called()
    mock_est.export_savedmodel.assert_not_called()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_hooks(self, unused_mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)
    extra_hooks = [_FakeHook()]

    executor = training._TrainingExecutor(
        mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)
    self._run_task(executor)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks) + extra_hooks,
        saving_listeners=test.mock.ANY)

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
    with test.mock.patch.dict('os.environ', tf_config):
      self._run_task(executor)
      mock_server.assert_not_called()

  def test_fail_with_empty_cluster_spec(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = None
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))

  def test_fail_with_empty_master(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec(
        {'worker': ['dummy', 'dummy1']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_single_worker_node_with_empty_tf_master(
      self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    # Single node cluster.
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                              mock_eval_spec))
    self.assertTrue(mock_est.train.called)
    mock_server.assert_not_called()

  def test_fail_with_empty_task_type(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = ''
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))

  def test_fail_with_none_task_id(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = None

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(training._TrainingExecutor(mock_est, mock_train_spec,
                                                mock_eval_spec))


class TrainingExecutorRunWorkerTest(_TrainingExecutorTrainingTest,
                                    test.TestCase):
  """Tests run_worker of _TrainingExecutor."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    _TrainingExecutorTrainingTest.__init__(
        self,
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER))

  @test.mock.patch.object(server_lib, 'Server')
  def test_delay_for_worker(self, _):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    expected_secs = (self._run_config.task_id + 1) * _DELAY_SECS_PER_WORKER
    with test.mock.patch.object(time, 'sleep') as mock_sleep:
      mock_sleep.side_effect = lambda s: self.assertEqual(expected_secs, s)
      self._run_task(executor)
      self.assertTrue(mock_sleep.called)


class TrainingExecutorRunChiefTest(_TrainingExecutorTrainingTest,
                                   test.TestCase):
  """Tests run_chief of _TrainingExecutor."""

  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    test.TestCase.__init__(self, methodName)
    _TrainingExecutorTrainingTest.__init__(
        self,
        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF))

  @test.mock.patch.object(server_lib, 'Server')
  def test_no_delay_for_chief(self, _):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    with test.mock.patch.object(time, 'sleep') as mock_sleep:
      self._run_task(executor)
      mock_sleep.assert_not_called()
class TrainingExecutorRunMasterTest(test.TestCase):
  """Tests run_master of _TrainingExecutor."""

  def setUp(self):
    self._run_config = _create_run_config_with_cluster_spec(
        _TF_CONFIG_FOR_MASTER)

  @test.mock.patch.object(server_lib, 'Server')
  def test_no_delay_for_master(self, _):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(
        spec=training.TrainSpec, max_steps=123, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    with test.mock.patch.object(time, 'sleep') as mock_sleep:
      executor.run_master()
      mock_sleep.assert_not_called()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
    mock_server_instance = mock_server.return_value

    executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)
    executor.run_master()

    mock_server.assert_called_with(
        mock_est.config.cluster_spec,
        job_name=mock_est.config.task_type,
        task_index=mock_est.config.task_id,
        config=test.mock.ANY,
        start=False)

    self.assertTrue(mock_server_instance.start.called)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=test.mock.ANY)
    mock_est.export_savedmodel.assert_not_called()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_train_with_train_hooks(self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])
    extra_hooks = [_FakeHook()]

    executor = training._TrainingExecutor(
        mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)
    executor.run_master()

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks) + extra_hooks,
        saving_listeners=test.mock.ANY)

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}
    mock_est.config = self._run_config
    mock_train_spec = test.mock.Mock(
        spec=training.TrainSpec, max_steps=123, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
    with test.mock.patch.dict('os.environ', tf_config):
      executor.run_master()
      mock_server.assert_not_called()

  def test_fail_with_empty_cluster_spec(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = None
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(
          mock_est, mock_train_spec, mock_eval_spec).run_master()

  def test_fail_with_empty_master(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec(
        {'master': ['dummy'], 'worker': ['dummy1']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = 0

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(
          mock_est, mock_train_spec, mock_eval_spec).run_master()

  @test.mock.patch.object(time, 'sleep')
  @test.mock.patch.object(server_lib, 'Server')
  def test_single_master_node_with_empty_tf_master(
      self, mock_server, unused_mock_sleep):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {ops.GraphKeys.GLOBAL_STEP: 123}

    mock_train_spec = test.mock.Mock(
        spec=training.TrainSpec, max_steps=123, hooks=[])
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec, exporters=[])

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec(
        {'master': ['dummy']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = 0

    executor = training._TrainingExecutor(
        mock_est, mock_train_spec, mock_eval_spec)
    executor.run_master()

    mock_server.assert_not_called()
    self.assertTrue(mock_est.train.called)

  def test_fail_with_empty_task_type(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'master': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = ''
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(
          mock_est, mock_train_spec, mock_eval_spec).run_master()

  def test_fail_with_none_task_id(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = server_lib.ClusterSpec({'master': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = None

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(
          mock_est, mock_train_spec, mock_eval_spec).run_master()

  @test.mock.patch.object(server_lib, 'Server')
  def test_run_master_triggers_evaluate_and_export(self, _):

    def estimator_train(saving_listeners, *args, **kwargs):
      # There shalt be a saving_listener. Estimator is going to call
      # `after_save`.
      del args, kwargs
      saving_listeners[0].begin()
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est = test.mock.Mock(
        spec=estimator_lib.Estimator, model_dir='path/', train=estimator_train)
    mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
    mock_est.config = self._run_config

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter)
    eval_result = {_GLOBAL_STEP_KEY: train_spec.max_steps}
    mock_est.evaluate.return_value = eval_result

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    mock_est.evaluate.assert_called_with(
        name=eval_spec.name,
        input_fn=eval_spec.input_fn,
        steps=eval_spec.steps,
        checkpoint_path='checkpoint_path/',
        hooks=eval_spec.hooks)
    self.assertEqual(1, exporter.export.call_count)
    exporter.export.assert_called_with(
        estimator=mock_est,
        export_path=os.path.join('path/', 'export', exporter.name),
        checkpoint_path='checkpoint_path/',
        eval_result=eval_result,
        is_the_final_export=True)

  @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer')
  @test.mock.patch.object(server_lib, 'Server')
  def test_run_master_throttle_eval(self, _, mock_timer_class):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')

    mock_timer = test.mock.Mock()
    mock_timer_class.return_value = mock_timer

    def estimator_train(saving_listeners, *args, **kwargs):
      del args, kwargs
      saving_listeners[0].begin()

      # Call three times.
      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

      mock_timer.should_trigger_for_step.return_value = False
      saving_listeners[0].after_save(session=None, global_step_value=None)

      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est.train = estimator_train
    mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
    mock_est.config = self._run_config

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)

    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps // 2},
        {_GLOBAL_STEP_KEY: train_spec.max_steps}
    ]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, exporter.export.call_count)

    is_final_export_list = [call[1]['is_the_final_export']
                            for call in exporter.export.call_args_list]
    self.assertEqual([False, True], is_final_export_list)

  @test.mock.patch.object(basic_session_run_hooks, 'SecondOrStepTimer')
  @test.mock.patch.object(server_lib, 'Server')
  def test_run_master_throttle_eval_which_skips_final_ckpt(
      self, _, mock_timer_class):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')

    mock_timer = test.mock.Mock()
    mock_timer_class.return_value = mock_timer

    def estimator_train(saving_listeners, *args, **kwargs):
      del args, kwargs
      saving_listeners[0].begin()

      # Call two times.
      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

      # The final ckpt is skipped by the timer. It will be picked up the final
      # export check in the code.
      mock_timer.should_trigger_for_step.return_value = False
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est.train = estimator_train
    mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
    mock_est.config = self._run_config

    exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)

    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps // 2},
        {_GLOBAL_STEP_KEY: train_spec.max_steps}
    ]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, exporter.export.call_count)

    is_final_export_list = [call[1]['is_the_final_export']
                            for call in exporter.export.call_args_list]
    self.assertEqual([False, True], is_final_export_list)
{_GLOBAL_STEP_KEY: training_max_step} 960 mock_train_spec.max_steps = training_max_step 961 962 def test_evaluate_with_evaluate_spec(self): 963 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 964 mock_est.latest_checkpoint.return_value = 'latest_it_is' 965 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 966 self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) 967 968 eval_spec = training.EvalSpec( 969 input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='cont_eval', 970 start_delay_secs=0, throttle_secs=0) 971 972 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 973 executor.run_evaluator() 974 975 mock_est.evaluate.assert_called_with( 976 name='cont_eval', 977 input_fn=eval_spec.input_fn, 978 steps=eval_spec.steps, 979 checkpoint_path='latest_it_is', 980 hooks=eval_spec.hooks) 981 self.assertFalse(mock_est.train.called) 982 983 def test_evaluate_with_train_hooks(self): 984 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 985 mock_est.latest_checkpoint.return_value = 'latest_it_is' 986 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 987 self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) 988 989 eval_spec = training.EvalSpec( 990 input_fn=lambda: 1, 991 steps=2, 992 hooks=[_FakeHook()], 993 name='cont_eval', 994 start_delay_secs=0, 995 throttle_secs=0) 996 997 # The train_hooks will not be called during eval. 
998 mock_hook = test.mock.Mock(spec=session_run_hook.SessionRunHook) 999 executor = training._TrainingExecutor( 1000 mock_est, mock_train_spec, eval_spec, train_hooks=[mock_hook]) 1001 executor.run_evaluator() 1002 1003 mock_hook.begin.assert_not_called() 1004 1005 def test_evaluate_multiple_times(self): 1006 training_max_step = 200 1007 1008 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1009 mock_est.model_dir = compat.as_bytes(test.get_temp_dir()) 1010 mock_est.evaluate.side_effect = [ 1011 {_GLOBAL_STEP_KEY: training_max_step // 2}, 1012 {_GLOBAL_STEP_KEY: training_max_step} 1013 ] 1014 mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2'] 1015 1016 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1017 mock_train_spec.max_steps = training_max_step 1018 1019 exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) 1020 exporter.name = 'see_how_many_times_export_is_called' 1021 1022 mock_est.times_export_was_called = 0 1023 mock_est.times_final_export_was_called = 0 1024 def export(estimator, export_path, checkpoint_path, eval_result, 1025 is_the_final_export): 1026 del export_path, checkpoint_path, eval_result 1027 estimator.times_export_was_called += 1 1028 # final_export is happened at the end. 
1029 self.assertEqual(0, estimator.times_final_export_was_called) 1030 if is_the_final_export: 1031 estimator.times_final_export_was_called += 1 1032 1033 exporter.export = export 1034 1035 eval_spec = training.EvalSpec( 1036 input_fn=lambda: 1, 1037 start_delay_secs=0, 1038 throttle_secs=0, 1039 exporters=exporter) 1040 1041 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1042 executor.run_evaluator() 1043 1044 self.assertEqual(2, mock_est.evaluate.call_count) 1045 self.assertEqual(2, mock_est.times_export_was_called) 1046 self.assertEqual(1, mock_est.times_final_export_was_called) 1047 1048 def test_evaluate_listener_before_eval(self): 1049 training_max_step = 200 1050 1051 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1052 mock_est.model_dir = compat.as_bytes(test.get_temp_dir()) 1053 # Without early stopping, this eval will be run twice. 1054 mock_est.evaluate.side_effect = [{ 1055 _GLOBAL_STEP_KEY: training_max_step // 2 1056 }, { 1057 _GLOBAL_STEP_KEY: training_max_step 1058 }] 1059 mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2'] 1060 1061 mock_train_spec = test.mock.Mock(spec=training.TrainSpec, hooks=[]) 1062 mock_train_spec.max_steps = training_max_step 1063 1064 class _Listener(training._ContinuousEvalListener): 1065 1066 def __init__(self): 1067 self.call_count = 0 1068 1069 def before_eval(self): 1070 self.call_count += 1 1071 return self.call_count == 1 1072 1073 listener = _Listener() 1074 1075 eval_spec = training.EvalSpec( 1076 input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) 1077 1078 training._TrainingExecutor( 1079 mock_est, mock_train_spec, eval_spec, 1080 continuous_eval_listener=listener).run_evaluator() 1081 1082 # Before_eval returns False during the second time, so, evaluate will be 1083 # called once. 
1084 self.assertEqual(1, mock_est.evaluate.call_count) 1085 self.assertEqual(2, listener.call_count) 1086 1087 def test_evaluate_listener_after_eval(self): 1088 training_max_step = 200 1089 1090 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1091 mock_est.model_dir = compat.as_bytes(test.get_temp_dir()) 1092 # Without early stopping, this eval will be run twice. 1093 expected_eval_metrics = [{ 1094 _GLOBAL_STEP_KEY: training_max_step // 2 1095 }, { 1096 _GLOBAL_STEP_KEY: training_max_step 1097 }] 1098 mock_est.evaluate.side_effect = expected_eval_metrics 1099 mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2'] 1100 1101 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1102 mock_train_spec.max_steps = training_max_step 1103 1104 class _Listener(training._ContinuousEvalListener): 1105 1106 def __init__(self): 1107 self.call_count = 0 1108 1109 def after_eval(self, eval_result): 1110 self.call_count += 1 1111 self.eval_result = eval_result 1112 return False 1113 1114 listener = _Listener() 1115 1116 eval_spec = training.EvalSpec( 1117 input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) 1118 1119 training._TrainingExecutor( 1120 mock_est, mock_train_spec, eval_spec, 1121 continuous_eval_listener=listener).run_evaluator() 1122 1123 # after_eval returns False during the first time, so, evaluate will be 1124 # called once. 
1125 self.assertEqual(1, mock_est.evaluate.call_count) 1126 self.assertEqual(1, listener.call_count) 1127 self.assertAllEqual(expected_eval_metrics[0], listener.eval_result.metrics) 1128 self.assertEqual('path_1', listener.eval_result.checkpoint_path) 1129 1130 def test_final_export_is_true_in_the_end(self): 1131 training_max_step = 200 1132 1133 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1134 mock_est.model_dir = compat.as_bytes(test.get_temp_dir()) 1135 mock_est.evaluate.side_effect = [ 1136 {_GLOBAL_STEP_KEY: training_max_step // 2}, 1137 {_GLOBAL_STEP_KEY: training_max_step} 1138 ] 1139 mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2'] 1140 1141 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1142 mock_train_spec.max_steps = training_max_step 1143 1144 mock_est.times_export_fn_was_called = 0 1145 mock_est.times_the_final_export_was_true = 0 1146 def export(estimator, export_path, checkpoint_path, eval_result, 1147 is_the_final_export): 1148 del export_path, checkpoint_path, eval_result 1149 estimator.times_export_fn_was_called += 1 1150 if is_the_final_export: 1151 estimator.times_the_final_export_was_true += 1 1152 1153 exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) 1154 exporter.name = 'see_how_many_times_export_is_called' 1155 exporter.export = export 1156 1157 eval_spec = training.EvalSpec( 1158 input_fn=lambda: 1, 1159 start_delay_secs=0, 1160 throttle_secs=0, 1161 exporters=exporter) 1162 1163 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1164 executor.run_evaluator() 1165 1166 self.assertEqual(2, mock_est.evaluate.call_count) 1167 self.assertEqual(2, mock_est.times_export_fn_was_called) 1168 self.assertEqual(1, mock_est.times_the_final_export_was_true) 1169 1170 def test_skip_evaluation_due_to_ckpt(self): 1171 training_max_step = 200 1172 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1173 mock_est.evaluate.side_effect = [ 1174 {_GLOBAL_STEP_KEY: training_max_step // 
2}, 1175 {_GLOBAL_STEP_KEY: training_max_step} 1176 ] 1177 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1178 mock_train_spec.max_steps = training_max_step 1179 1180 self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) 1181 1182 # First two items are invalid, next two items are same. 1183 mock_est.latest_checkpoint.side_effect = [ 1184 None, '', 'same', 'same', 'path_2' 1185 ] 1186 1187 eval_spec = training.EvalSpec( 1188 input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) 1189 1190 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1191 with test.mock.patch.object(logging, 'warning') as mock_log: 1192 executor.run_evaluator() 1193 1194 # Three checkpoint paths are invalid. 1195 self.assertEqual(5, mock_est.latest_checkpoint.call_count) 1196 self.assertEqual(2, mock_est.evaluate.call_count) 1197 1198 # Two warning logs are expected (last warning time is reset after a 1199 # successuful evaluation) 1200 self.assertEqual(2, mock_log.call_count) 1201 1202 def test_continuous_eval_listener_eval_result(self): 1203 training_max_step = 200 1204 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1205 expected_eval_metrics = [{ 1206 _GLOBAL_STEP_KEY: training_max_step // 2 1207 }, { 1208 _GLOBAL_STEP_KEY: training_max_step 1209 }] 1210 mock_est.evaluate.side_effect = expected_eval_metrics 1211 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1212 mock_train_spec.max_steps = training_max_step 1213 1214 class _Listener(training._ContinuousEvalListener): 1215 1216 def __init__(self): 1217 self.eval_results = [] 1218 1219 def after_eval(self, eval_result): 1220 self.eval_results.append(eval_result) 1221 return True 1222 1223 continuous_eval_listener = _Listener() 1224 1225 self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) 1226 1227 # First two items are invalid, next two items are same. 
1228 mock_est.latest_checkpoint.side_effect = [ 1229 None, '', 'same', 'same', 'path_2' 1230 ] 1231 expected_eval_results = [ 1232 training._EvalResult(training._EvalStatus.MISSING_CHECKPOINT), 1233 training._EvalResult(training._EvalStatus.MISSING_CHECKPOINT), 1234 training._EvalResult( 1235 training._EvalStatus.EVALUATED, 1236 metrics=expected_eval_metrics[0], 1237 checkpoint_path='same'), 1238 training._EvalResult(training._EvalStatus.NO_NEW_CHECKPOINT), 1239 training._EvalResult( 1240 training._EvalStatus.EVALUATED, 1241 metrics=expected_eval_metrics[1], 1242 checkpoint_path='path_2'), 1243 ] 1244 1245 eval_spec = training.EvalSpec( 1246 input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0) 1247 1248 executor = training._TrainingExecutor( 1249 mock_est, 1250 mock_train_spec, 1251 eval_spec, 1252 continuous_eval_listener=continuous_eval_listener) 1253 executor.run_evaluator() 1254 1255 # Three checkpoint paths are invalid. 1256 self.assertEqual(5, mock_est.latest_checkpoint.call_count) 1257 self.assertEqual(2, mock_est.evaluate.call_count) 1258 1259 self.assertEqual(5, len(continuous_eval_listener.eval_results)) 1260 for i, result in enumerate(continuous_eval_listener.eval_results): 1261 self.assertEqual(expected_eval_results[i].status, result.status) 1262 self.assertAllEqual(expected_eval_results[i].metrics, result.metrics) 1263 self.assertEqual(expected_eval_results[i].checkpoint_path, 1264 result.checkpoint_path) 1265 1266 def test_sleep_start_delay_secs(self): 1267 training_max_step = 200 1268 start_delay_secs = 123 1269 1270 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1271 mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: training_max_step} 1272 mock_est.model_dir = compat.as_bytes(test.get_temp_dir()) 1273 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1274 mock_train_spec.max_steps = training_max_step 1275 1276 eval_spec = training.EvalSpec( 1277 input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='cont_eval', 1278 
start_delay_secs=start_delay_secs, throttle_secs=0) 1279 1280 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1281 with test.mock.patch.object(time, 'sleep') as mock_sleep: 1282 executor.run_evaluator() 1283 mock_sleep.assert_called_with(start_delay_secs) 1284 self.assertTrue(mock_est.evaluate.called) 1285 1286 @test.mock.patch.object(time, 'time') 1287 @test.mock.patch.object(time, 'sleep') 1288 def test_throttle_secs(self, mock_sleep, mock_time): 1289 throttle_secs = 123 1290 operation_secs = 12 1291 1292 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1293 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1294 self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) 1295 1296 eval_spec = training.EvalSpec( 1297 input_fn=lambda: 1, start_delay_secs=0, throttle_secs=throttle_secs) 1298 1299 mock_time.side_effect = [921, 921 + operation_secs] 1300 1301 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1302 # Disable logging as it calls time.time also. 
1303 with test.mock.patch.object(logging, 'info'): 1304 executor.run_evaluator() 1305 mock_sleep.assert_called_with(throttle_secs - operation_secs) 1306 self.assertTrue(mock_est.evaluate.called) 1307 1308 def test_that_export_is_called(self): 1309 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1310 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1311 self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec) 1312 1313 def export(estimator, *args, **kwargs): 1314 del args, kwargs 1315 estimator.export_was_called = True 1316 1317 exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) 1318 exporter.name = 'see_whether_export_is_called' 1319 exporter.export = export 1320 1321 eval_spec = training.EvalSpec( 1322 input_fn=lambda: 1, 1323 steps=2, 1324 start_delay_secs=0, 1325 throttle_secs=0, 1326 exporters=exporter) 1327 1328 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1329 executor.run_evaluator() 1330 1331 # Verify that export was called on the right estimator. 
1332 self.assertTrue(mock_est.export_was_called) 1333 1334 def test_errors_out_if_evaluate_returns_empty_dict(self): 1335 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1336 train_spec = training.TrainSpec(input_fn=lambda: 1) 1337 eval_spec = training.EvalSpec(input_fn=(lambda: 1), 1338 start_delay_secs=0, throttle_secs=0) 1339 mock_est.evaluate.return_value = {} 1340 1341 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1342 with self.assertRaisesRegexp(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR): 1343 executor.run_evaluator() 1344 1345 def test_errors_out_if_evaluate_returns_non_dict(self): 1346 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1347 train_spec = training.TrainSpec(input_fn=lambda: 1) 1348 eval_spec = training.EvalSpec(input_fn=(lambda: 1), 1349 start_delay_secs=0, throttle_secs=0) 1350 mock_est.evaluate.return_value = 123 1351 1352 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1353 with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR): 1354 executor.run_evaluator() 1355 1356 def test_errors_out_if_evaluate_returns_dict_without_global_step(self): 1357 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1358 train_spec = training.TrainSpec(input_fn=lambda: 1) 1359 eval_spec = training.EvalSpec(input_fn=(lambda: 1), 1360 start_delay_secs=0, throttle_secs=0) 1361 mock_est.evaluate.return_value = {'loss': 123} 1362 1363 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1364 with self.assertRaisesRegexp(ValueError, 1365 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR): 1366 executor.run_evaluator() 1367 1368 1369class TrainingExecutorRunPsTest(test.TestCase): 1370 """Tests run_ps of _TrainingExecutor.""" 1371 1372 @test.mock.patch.object(server_lib, 'Server') 1373 def test_std_server(self, mock_server): 1374 mock_server_instance = test.mock.Mock() 1375 mock_server.return_value = mock_server_instance 1376 1377 mock_est = 
test.mock.Mock(spec=estimator_lib.Estimator) 1378 mock_est.config = _create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS) 1379 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1380 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1381 1382 executor = training._TrainingExecutor(mock_est, mock_train_spec, 1383 mock_eval_spec) 1384 executor.run_ps() 1385 1386 mock_server.assert_called_with( 1387 mock_est.config.cluster_spec, 1388 job_name=mock_est.config.task_type, 1389 task_index=mock_est.config.task_id, 1390 config=test.mock.ANY, 1391 start=False) 1392 1393 self.assertTrue(mock_server_instance.start.called) 1394 self.assertTrue(mock_server_instance.join.called) 1395 1396 def test_fail_with_empty_cluster_spec(self): 1397 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1398 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1399 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1400 1401 mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) 1402 mock_est.config.cluster_spec = None 1403 mock_est.config.master = 'grpc://...' 
1404 mock_est.config.task_type = 'ps' 1405 mock_est.config.task_id = 2 1406 1407 with self.assertRaisesRegexp(RuntimeError, 1408 _INVALID_CONFIG_FOR_STD_SERVER_MSG): 1409 training._TrainingExecutor(mock_est, mock_train_spec, 1410 mock_eval_spec).run_ps() 1411 1412 def test_fail_with_empty_master(self): 1413 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1414 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1415 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1416 1417 mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) 1418 mock_est.config.cluster_spec = server_lib.ClusterSpec({'ps': ['dummy']}) 1419 mock_est.config.master = '' 1420 mock_est.config.task_type = 'ps' 1421 mock_est.config.task_id = 2 1422 1423 with self.assertRaisesRegexp(RuntimeError, 1424 _INVALID_CONFIG_FOR_STD_SERVER_MSG): 1425 training._TrainingExecutor(mock_est, mock_train_spec, 1426 mock_eval_spec).run_ps() 1427 1428 def test_fail_with_empty_task_type(self): 1429 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1430 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1431 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1432 1433 mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) 1434 mock_est.config.cluster_spec = server_lib.ClusterSpec({'ps': ['dummy']}) 1435 mock_est.config.master = 'grpc://...' 
1436 mock_est.config.task_type = '' 1437 mock_est.config.task_id = 2 1438 1439 with self.assertRaisesRegexp(RuntimeError, 1440 _INVALID_CONFIG_FOR_STD_SERVER_MSG): 1441 training._TrainingExecutor(mock_est, mock_train_spec, 1442 mock_eval_spec).run_ps() 1443 1444 def test_fail_with_none_task_id(self): 1445 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1446 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1447 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1448 1449 mock_est.config = test.mock.PropertyMock(spec=run_config_lib.RunConfig) 1450 mock_est.config.cluster_spec = server_lib.ClusterSpec({'ps': ['dummy']}) 1451 mock_est.config.master = 'grpc://...' 1452 mock_est.config.task_type = 'ps' 1453 mock_est.config.task_id = None 1454 1455 with self.assertRaisesRegexp(RuntimeError, 1456 _INVALID_CONFIG_FOR_STD_SERVER_MSG): 1457 training._TrainingExecutor(mock_est, mock_train_spec, 1458 mock_eval_spec).run_ps() 1459 1460 1461class StopAtSecsHookTest(test.TestCase): 1462 """Tests StopAtSecsHook.""" 1463 1464 @test.mock.patch.object(time, 'time') 1465 def test_stops_after_time(self, mock_time): 1466 mock_time.return_value = 1484695987.209386 1467 hook = training._StopAtSecsHook(1000) 1468 with ops.Graph().as_default(): 1469 no_op = control_flow_ops.no_op() 1470 # some time passed before training starts 1471 mock_time.return_value += 250 1472 with monitored_session.MonitoredSession(hooks=[hook]) as sess: 1473 self.assertFalse(sess.should_stop()) 1474 sess.run(no_op) 1475 self.assertFalse(sess.should_stop()) 1476 mock_time.return_value += 500 1477 sess.run(no_op) 1478 self.assertFalse(sess.should_stop()) 1479 mock_time.return_value += 400 1480 sess.run(no_op) 1481 self.assertFalse(sess.should_stop()) 1482 mock_time.return_value += 200 1483 sess.run(no_op) 1484 self.assertTrue(sess.should_stop()) 1485 1486 1487class TrainingExecutorRunLocalTest(test.TestCase): 1488 """Tests run_local of _TrainingExecutor.""" 1489 1490 def 
unique_checkpoint_every_time_fn(self): 1491 return 'checkpoint_path_%s/' % random.random() 1492 1493 def test_send_stop_at_secs_to_train(self): 1494 mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') 1495 mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn 1496 train_spec = training.TrainSpec( 1497 input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()]) 1498 eval_spec = training.EvalSpec( 1499 input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100) 1500 mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} 1501 1502 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1503 executor.run_local() 1504 1505 stop_hook = mock_est.train.call_args[1]['hooks'][-1] 1506 self.assertIsInstance(stop_hook, training._StopAtSecsHook) 1507 self.assertEqual(eval_spec.throttle_secs, stop_hook._stop_after_secs) 1508 1509 def test_runs_in_a_loop_until_max_steps(self): 1510 mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') 1511 mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn 1512 1513 mock_est.times_export_was_called = 0 1514 mock_est.times_final_export_was_called = 0 1515 def export(estimator, export_path, checkpoint_path, eval_result, 1516 is_the_final_export): 1517 del export_path, checkpoint_path, eval_result 1518 estimator.times_export_was_called += 1 1519 # final_export is happened at the end. 1520 self.assertEqual(0, estimator.times_final_export_was_called) 1521 if is_the_final_export: 1522 estimator.times_final_export_was_called += 1 1523 1524 exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) 1525 exporter.name = 'see_how_many_times_export_is_called' 1526 exporter.export = export 1527 1528 train_spec = training.TrainSpec( 1529 input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) 1530 eval_spec = training.EvalSpec( 1531 input_fn=lambda: 1, 1532 hooks=[_FakeHook()], 1533 throttle_secs=100, 1534 exporters=exporter) 1535 # should be called 3 times. 
1536 mock_est.evaluate.side_effect = [{ 1537 _GLOBAL_STEP_KEY: train_spec.max_steps - 100 1538 }, { 1539 _GLOBAL_STEP_KEY: train_spec.max_steps - 50 1540 }, { 1541 _GLOBAL_STEP_KEY: train_spec.max_steps 1542 }] 1543 1544 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1545 executor.run_local() 1546 1547 self.assertEqual(3, mock_est.train.call_count) 1548 self.assertEqual(3, mock_est.evaluate.call_count) 1549 self.assertEqual(3, mock_est.times_export_was_called) 1550 self.assertEqual(1, mock_est.times_final_export_was_called) 1551 1552 def test_handles_no_new_checkpoint_found(self): 1553 mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') 1554 mock_est.latest_checkpoint.return_value = ( 1555 'no_new_checkpoints_after_the_first_train_step') 1556 train_spec = training.TrainSpec( 1557 input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) 1558 eval_spec = training.EvalSpec( 1559 input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100) 1560 # It was going to be called 3 times. 
1561 mock_est.evaluate.side_effect = [{ 1562 _GLOBAL_STEP_KEY: train_spec.max_steps - 100 1563 }, { 1564 _GLOBAL_STEP_KEY: train_spec.max_steps - 50 1565 }, { 1566 _GLOBAL_STEP_KEY: train_spec.max_steps 1567 }] 1568 1569 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1570 with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG): 1571 executor.run_local() 1572 1573 def test_final_export_is_true_in_the_end(self): 1574 mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') 1575 mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn 1576 1577 mock_est.times_export_fn_was_called = 0 1578 mock_est.times_the_final_export_was_true = 0 1579 def export(estimator, export_path, checkpoint_path, eval_result, 1580 is_the_final_export): 1581 del export_path, checkpoint_path, eval_result 1582 estimator.times_export_fn_was_called += 1 1583 if is_the_final_export: 1584 estimator.times_the_final_export_was_true += 1 1585 1586 exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) 1587 exporter.name = 'see_how_many_times_export_is_called' 1588 exporter.export = export 1589 1590 train_spec = training.TrainSpec( 1591 input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) 1592 eval_spec = training.EvalSpec( 1593 input_fn=lambda: 1, 1594 hooks=[_FakeHook()], 1595 throttle_secs=100, 1596 exporters=exporter) 1597 # should be called 3 times. 
1598 mock_est.evaluate.side_effect = [{ 1599 _GLOBAL_STEP_KEY: train_spec.max_steps - 100 1600 }, { 1601 _GLOBAL_STEP_KEY: train_spec.max_steps - 50 1602 }, { 1603 _GLOBAL_STEP_KEY: train_spec.max_steps 1604 }] 1605 1606 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1607 executor.run_local() 1608 1609 self.assertEqual(3, mock_est.train.call_count) 1610 self.assertEqual(3, mock_est.evaluate.call_count) 1611 self.assertEqual(3, mock_est.times_export_fn_was_called) 1612 self.assertEqual(1, mock_est.times_the_final_export_was_true) 1613 1614 def test_train_and_evaluate_args(self): 1615 mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') 1616 mock_est.latest_checkpoint.return_value = 'checkpoint_path/' 1617 train_spec = training.TrainSpec( 1618 input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) 1619 eval_spec = training.EvalSpec( 1620 input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval') 1621 mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} 1622 1623 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1624 executor.run_local() 1625 1626 mock_est.evaluate.assert_called_with( 1627 name=eval_spec.name, 1628 input_fn=eval_spec.input_fn, 1629 steps=eval_spec.steps, 1630 checkpoint_path='checkpoint_path/', 1631 hooks=eval_spec.hooks) 1632 1633 train_args = mock_est.train.call_args[1] 1634 self.assertEqual(list(train_spec.hooks), list(train_args['hooks'][:-1])) 1635 self.assertEqual(train_spec.input_fn, train_args['input_fn']) 1636 self.assertEqual(train_spec.max_steps, train_args['max_steps']) 1637 1638 def test_train_hooks(self): 1639 mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') 1640 mock_est.latest_checkpoint.return_value = 'checkpoint_path/' 1641 train_spec = training.TrainSpec( 1642 input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) 1643 eval_spec = training.EvalSpec(input_fn=lambda: 1, steps=2) 1644 
mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps} 1645 extra_hooks = [_FakeHook()] 1646 1647 executor = training._TrainingExecutor( 1648 mock_est, train_spec, eval_spec, train_hooks=extra_hooks) 1649 executor.run_local() 1650 1651 train_args = mock_est.train.call_args[1] 1652 self.assertEqual( 1653 list(train_spec.hooks) + extra_hooks, [ 1654 h for h in train_args['hooks'] 1655 if not isinstance(h, training._StopAtSecsHook) 1656 ]) 1657 1658 def test_errors_out_if_throttle_secs_is_zero(self): 1659 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1660 train_spec = training.TrainSpec(input_fn=lambda: 1) 1661 eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=0) 1662 1663 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1664 with self.assertRaisesRegexp(ValueError, 'throttle_secs'): 1665 executor.run_local() 1666 1667 def test_that_export_is_called_with_run_local(self): 1668 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1669 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1670 mock_train_spec.max_steps = 200 1671 mock_est.evaluate.return_value = { 1672 _GLOBAL_STEP_KEY: mock_train_spec.max_steps 1673 } 1674 # _validate_hooks would have made sure that train_spec.hooks is [], when 1675 # None were passed. 
1676 mock_train_spec.hooks = [] 1677 1678 def export(estimator, *args, **kwargs): 1679 del args, kwargs 1680 estimator.export_was_called = True 1681 1682 exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) 1683 exporter.name = 'see_whether_export_is_called' 1684 exporter.export = export 1685 1686 eval_spec = training.EvalSpec( 1687 input_fn=lambda: 1, 1688 steps=2, 1689 start_delay_secs=0, 1690 throttle_secs=213, 1691 exporters=exporter) 1692 1693 executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) 1694 executor.run_local() 1695 1696 self.assertTrue(mock_est.export_was_called) 1697 1698 def test_errors_out_if_evaluate_returns_empty_dict(self): 1699 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1700 train_spec = training.TrainSpec(input_fn=lambda: 1) 1701 eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123) 1702 mock_est.evaluate.return_value = {} 1703 1704 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1705 with self.assertRaisesRegexp(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR): 1706 executor.run_local() 1707 1708 def test_errors_out_if_evaluate_returns_non_dict(self): 1709 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1710 train_spec = training.TrainSpec(input_fn=lambda: 1) 1711 eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123) 1712 mock_est.evaluate.return_value = 123 1713 1714 executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) 1715 with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR): 1716 executor.run_local() 1717 1718 def test_errors_out_if_evaluate_returns_dict_without_global_step(self): 1719 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1720 train_spec = training.TrainSpec(input_fn=lambda: 1) 1721 eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123) 1722 mock_est.evaluate.return_value = {'loss': 123} 1723 1724 executor = training._TrainingExecutor(mock_est, train_spec, 
eval_spec) 1725 with self.assertRaisesRegexp(ValueError, 1726 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR): 1727 executor.run_local() 1728 1729 1730class TrainAndEvaluateRunTest(test.TestCase): 1731 1732 def _test_run_task_and_executor(self, run_config): 1733 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1734 mock_est.config = run_config 1735 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1736 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1737 1738 executor = training._TrainingExecutor(mock_est, mock_train_spec, 1739 mock_eval_spec) 1740 1741 executor.call_task = {} 1742 1743 def task_fn(name): 1744 1745 def _fn(): 1746 executor.call_task[name] = 1 1747 1748 return _fn 1749 1750 executor.run_chief = task_fn('chief') 1751 executor.run_master = task_fn('master') 1752 executor.run_ps = task_fn('ps') 1753 executor.run_evaluator = task_fn('evaluator') 1754 executor.run_worker = task_fn('worker') 1755 executor.run_local = task_fn('local') 1756 return executor 1757 1758 def test_run_chief(self): 1759 executor = self._test_run_task_and_executor( 1760 run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF)) 1761 executor.run() 1762 self.assertEqual(1, executor.call_task['chief']) 1763 1764 def test_run_worker(self): 1765 executor = self._test_run_task_and_executor( 1766 run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER)) 1767 executor.run() 1768 self.assertEqual(1, executor.call_task['worker']) 1769 1770 def test_run_ps(self): 1771 executor = self._test_run_task_and_executor( 1772 run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS)) 1773 executor.run() 1774 self.assertEqual(1, executor.call_task['ps']) 1775 1776 def test_run_evaluator(self): 1777 executor = self._test_run_task_and_executor( 1778 run_config=_create_run_config_with_cluster_spec( 1779 _TF_CONFIG_FOR_EVALUATOR)) 1780 executor.run() 1781 self.assertEqual(1, executor.call_task['evaluator']) 1782 1783 def test_run_local(self): 1784 
executor = self._test_run_task_and_executor( 1785 run_config=run_config_lib.RunConfig()) 1786 executor.run() 1787 self.assertEqual(1, executor.call_task['local']) 1788 1789 def test_invalid_local_task(self): 1790 tf_config = { 1791 'cluster': { 1792 run_config_lib.TaskType.CHIEF: ['host0:0'], 1793 'local': ['hos1:1'], 1794 }, 1795 'task': { 1796 'type': 'local', # invalid task type. 1797 'index': 0 1798 } 1799 } 1800 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1801 mock_est.config = _create_run_config_with_cluster_spec(tf_config) 1802 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1803 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1804 1805 executor = training._TrainingExecutor(mock_est, mock_train_spec, 1806 mock_eval_spec) 1807 with self.assertRaisesRegexp(ValueError, _INVALID_LOCAL_TASK_WITH_CLUSTER): 1808 executor.run() 1809 1810 def test_unsupported_task_due_to_missing_run_task(self): 1811 unsupported_task = 'alloc' 1812 tf_config = { 1813 'cluster': { 1814 run_config_lib.TaskType.CHIEF: ['host0:0'], 1815 unsupported_task: ['hos1:1'], 1816 }, 1817 'task': { 1818 'type': unsupported_task, 1819 'index': 0 1820 } 1821 } 1822 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1823 mock_est.config = _create_run_config_with_cluster_spec(tf_config) 1824 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1825 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1826 1827 executor = training._TrainingExecutor(mock_est, mock_train_spec, 1828 mock_eval_spec) 1829 with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN): 1830 executor.run() 1831 1832 def test_unsupported_task_due_to_not_callable(self): 1833 unsupported_task = 'alloc' 1834 tf_config = { 1835 'cluster': { 1836 run_config_lib.TaskType.CHIEF: ['host0:0'], 1837 unsupported_task: ['hos1:1'], 1838 }, 1839 'task': { 1840 'type': unsupported_task, 1841 'index': 0 1842 } 1843 } 1844 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1845 mock_est.config = 
_create_run_config_with_cluster_spec(tf_config) 1846 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1847 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1848 1849 executor = training._TrainingExecutor(mock_est, mock_train_spec, 1850 mock_eval_spec) 1851 executor.run_alloc = 123 # not callable 1852 with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN): 1853 executor.run() 1854 1855 def test_invalid_task_type(self): 1856 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) 1857 mock_est.config = test.mock.Mock() 1858 mock_train_spec = test.mock.Mock(spec=training.TrainSpec) 1859 mock_eval_spec = test.mock.Mock(spec=training.EvalSpec) 1860 1861 mock_est.config = test.mock.Mock() 1862 mock_est.config.cluster_spec = server_lib.ClusterSpec({'1': ['dummy']}) 1863 mock_est.config.task_type = '' 1864 1865 executor = training._TrainingExecutor(mock_est, mock_train_spec, 1866 mock_eval_spec) 1867 with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE): 1868 executor.run() 1869 1870 1871class TrainAndEvaluateIntegrationTest(test.TestCase): 1872 1873 def setUp(self): 1874 self._model_dir = tempfile.mkdtemp() 1875 1876 def tearDown(self): 1877 if self._model_dir: 1878 shutil.rmtree(self._model_dir) 1879 1880 def _as_label(self, data_in_float): 1881 return np.rint(data_in_float).astype(np.int64) 1882 1883 def _get_exporter(self, name, fc): 1884 feature_spec = feature_column.make_parse_example_spec(fc) 1885 serving_input_receiver_fn = ( 1886 export_lib.build_parsing_serving_input_receiver_fn(feature_spec)) 1887 return exporter_lib.LatestExporter( 1888 name, serving_input_receiver_fn=serving_input_receiver_fn) 1889 1890 def _extract_loss_and_global_step(self, event_folder): 1891 """Returns the loss and global step in last event.""" 1892 event_paths = glob.glob(os.path.join(event_folder, 'events*')) 1893 1894 loss = None 1895 global_step_count = None 1896 1897 for e in summary_iterator.summary_iterator(event_paths[-1]): 1898 current_loss = None 
1899 for v in e.summary.value: 1900 if v.tag == 'loss': 1901 current_loss = v.simple_value 1902 1903 # If loss is not found, global step is meaningless. 1904 if current_loss is None: 1905 continue 1906 1907 current_global_step = e.step 1908 if global_step_count is None or current_global_step > global_step_count: 1909 global_step_count = current_global_step 1910 loss = current_loss 1911 1912 return (loss, global_step_count) 1913 1914 def test_complete_flow_with_non_distributed_configuration(self): 1915 n_classes = 3 1916 input_dimension = 2 1917 batch_size = 10 1918 1919 eval_name = 'foo' 1920 exporter_name = 'saved_model_exporter' 1921 1922 # max_steps should be larger than save_summary_steps 1923 max_steps = 10 1924 save_summary_steps = 2 1925 1926 data = np.linspace( 1927 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) 1928 x_data = data.reshape(batch_size, input_dimension) 1929 y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1)) 1930 1931 # learn y = x 1932 train_input_fn = numpy_io.numpy_input_fn( 1933 x={'x': x_data}, 1934 y=y_data, 1935 batch_size=batch_size, 1936 num_epochs=None, 1937 shuffle=True) 1938 1939 eval_input_fn = numpy_io.numpy_input_fn( 1940 x={'x': x_data}, 1941 y=y_data, 1942 batch_size=batch_size, 1943 num_epochs=1, 1944 shuffle=False) 1945 1946 predict_input_fn = numpy_io.numpy_input_fn( 1947 x={'x': x_data}, 1948 batch_size=batch_size, 1949 shuffle=False) 1950 1951 feature_columns = [ 1952 feature_column.numeric_column('x', shape=(input_dimension,))] 1953 1954 est = dnn.DNNClassifier( 1955 hidden_units=(2, 2), 1956 feature_columns=feature_columns, 1957 n_classes=n_classes, 1958 config=run_config_lib.RunConfig(save_summary_steps=save_summary_steps), 1959 model_dir=self._model_dir) 1960 1961 train_spec = training.TrainSpec(input_fn=train_input_fn, 1962 max_steps=max_steps) 1963 1964 eval_spec = training.EvalSpec( 1965 name=eval_name, input_fn=eval_input_fn, steps=None, 1966 
exporters=self._get_exporter(exporter_name, feature_columns), 1967 throttle_secs=2) 1968 1969 training.train_and_evaluate(est, train_spec, eval_spec) 1970 1971 # Make sure nothing is stuck in limbo. 1972 writer_cache.FileWriterCache.clear() 1973 1974 # Examine the training events. Use a range to check global step to avoid 1975 # flakyness due to global step race condition. 1976 training_loss, training_global_step = self._extract_loss_and_global_step( 1977 est.model_dir) 1978 self.assertIsNotNone(training_loss) 1979 self.assertTrue( 1980 max_steps - save_summary_steps < training_global_step <= max_steps) 1981 1982 # Examine the eval events. The global step should be accurate. 1983 eval_loss, eval_global_step = self._extract_loss_and_global_step( 1984 event_folder=os.path.join(est.model_dir, 'eval_' + eval_name)) 1985 self.assertIsNotNone(eval_loss) 1986 self.assertEqual(max_steps, eval_global_step) 1987 1988 # Examine the export folder. 1989 export_dir = os.path.join(os.path.join(est.model_dir, 'export'), 1990 exporter_name) 1991 self.assertTrue(gfile.Exists(export_dir)) 1992 1993 # Examine the ckpt for predict. 1994 predicted_proba = np.array([ 1995 x[prediction_keys.PredictionKeys.PROBABILITIES] 1996 for x in est.predict(predict_input_fn) 1997 ]) 1998 self.assertAllEqual((batch_size, n_classes), predicted_proba.shape) 1999 2000 2001if __name__ == '__main__': 2002 test.main() 2003