1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes and functions related to train_and_evaluate.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import json 23import os 24import time 25 26import six 27 28from tensorflow.core.protobuf import config_pb2 29from tensorflow.python.estimator import estimator as estimator_lib 30from tensorflow.python.estimator import exporter as exporter_lib 31from tensorflow.python.estimator import run_config as run_config_lib 32from tensorflow.python.framework import ops 33from tensorflow.python.platform import tf_logging as logging 34from tensorflow.python.training import basic_session_run_hooks 35from tensorflow.python.training import server_lib 36from tensorflow.python.training import session_run_hook 37from tensorflow.python.util import compat 38from tensorflow.python.util.tf_export import tf_export 39 40_MAX_DELAY_SECS = 60 41_DELAY_SECS_PER_WORKER = 5 42_TF_CONFIG_ENV = 'TF_CONFIG' 43_ENVIRONMENT_KEY = 'environment' 44_ENVIRONMENT_GOOGLE_VALUE = 'google' 45_TRAINER_JOBS = (run_config_lib.TaskType.CHIEF, run_config_lib.TaskType.MASTER, 46 run_config_lib.TaskType.WORKER) 47 48 49def _validate_input_fn(input_fn): 50 """Validates the `input_fn`.""" 51 if not callable(input_fn): 52 
raise TypeError('`input_fn` must be callable, given: {}'.format(input_fn)) 53 54 55def _validate_hooks(hooks): 56 """Validates the `hooks`.""" 57 hooks = tuple(hooks or []) 58 for hook in hooks: 59 if not isinstance(hook, session_run_hook.SessionRunHook): 60 raise TypeError( 61 'All hooks must be `SessionRunHook` instances, given: {}'.format( 62 hook)) 63 return hooks 64 65 66def _validate_exporters(exporters): 67 """Validates `exporters` and returns them as a tuple.""" 68 if not exporters: 69 return () 70 71 if isinstance(exporters, exporter_lib.Exporter): 72 exporters = [exporters] 73 74 unique_names = [] # `Exporter`s should have unique names. 75 try: 76 for exporter in exporters: 77 if not isinstance(exporter, exporter_lib.Exporter): 78 # Error message will be printed out by the outer try/except. 79 raise TypeError 80 81 if not exporter.name: 82 full_list_of_names = [e.name for e in exporters] 83 raise ValueError('An Exporter cannot have a name that is `None` or' 84 ' empty. All exporter names:' 85 ' {}'.format(full_list_of_names)) 86 87 if not isinstance(exporter.name, six.string_types): 88 raise ValueError('An Exporter must have a string name. Given: ' 89 '{}'.format(type(exporter.name))) 90 91 if exporter.name in unique_names: 92 full_list_of_names = [e.name for e in exporters] 93 raise ValueError( 94 '`exporters` must have unique names. Such a name cannot be `None`.' 95 ' All exporter names: {}'.format(full_list_of_names)) 96 unique_names.append(exporter.name) 97 except TypeError: 98 # Two possibilities: 99 # - `exporters` is neither `Exporter` nor iterable. Python has 100 # raised a `TypeError` when iterating over `exporters`. 101 # - an `exporter` was None or not of type `Exporter`, so we raised a 102 # `TypeError`. 103 raise TypeError('`exporters` must be an Exporter,' 104 ' an iterable of Exporter, or `None`,' 105 ' found %s.' 
% exporters) 106 107 return tuple(exporters) 108 109 110def _is_google_env(): 111 """Detects whether current environment is google.""" 112 tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV) or '{}') 113 if not tf_config: 114 logging.warn('TF_CONFIG should not be empty in distributed environment.') 115 return tf_config.get(_ENVIRONMENT_KEY) == _ENVIRONMENT_GOOGLE_VALUE 116 117 118@tf_export('estimator.TrainSpec') 119class TrainSpec( 120 collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])): 121 """Configuration for the "train" part for the `train_and_evaluate` call. 122 123 `TrainSpec` determines the input data for the training, as well as the 124 duration. Optional hooks run at various stages of training. 125 """ 126 127 def __new__(cls, input_fn, max_steps=None, hooks=None): 128 """Creates a validated `TrainSpec` instance. 129 130 Args: 131 input_fn: Training input function returning a tuple of: 132 features - `Tensor` or dictionary of string feature name to `Tensor`. 133 labels - `Tensor` or dictionary of `Tensor` with labels. 134 max_steps: Int. Positive number of total steps for which to train model. 135 If `None`, train forever. The training `input_fn` is not expected to 136 generate `OutOfRangeError` or `StopIteration` exceptions. See the 137 `train_and_evaluate` stop condition section for details. 138 hooks: Iterable of `tf.train.SessionRunHook` objects to run 139 on all workers (including chief) during training. 140 141 Returns: 142 A validated `TrainSpec` object. 143 144 Raises: 145 ValueError: If any of the input arguments is invalid. 146 TypeError: If any of the arguments is not of the expected type. 147 """ 148 # Validate input_fn. 149 _validate_input_fn(input_fn) 150 151 # Validate max_steps. 152 if max_steps is not None and max_steps <= 0: 153 raise ValueError( 154 'Must specify max_steps > 0, given: {}'.format(max_steps)) 155 156 # Validate hooks. 
157 hooks = _validate_hooks(hooks) 158 159 return super(TrainSpec, cls).__new__( 160 cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks) 161 162 163@tf_export('estimator.EvalSpec') 164class EvalSpec( 165 collections.namedtuple('EvalSpec', [ 166 'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs', 167 'throttle_secs' 168 ])): 169 """Configuration for the "eval" part for the `train_and_evaluate` call. 170 171 `EvalSpec` combines details of evaluation of the trained model as well as its 172 export. Evaluation consists of computing metrics to judge the performance of 173 the trained model. Export writes out the trained model on to external 174 storage. 175 """ 176 177 def __new__(cls, 178 input_fn, 179 steps=100, 180 name=None, 181 hooks=None, 182 exporters=None, 183 start_delay_secs=120, 184 throttle_secs=600): 185 """Creates a validated `EvalSpec` instance. 186 187 Args: 188 input_fn: Evaluation input function returning a tuple of: 189 features - `Tensor` or dictionary of string feature name to `Tensor`. 190 labels - `Tensor` or dictionary of `Tensor` with labels. 191 steps: Int. Positive number of steps for which to evaluate model. If 192 `None`, evaluates until `input_fn` raises an end-of-input exception. 193 See `Estimator.evaluate` for details. 194 name: String. Name of the evaluation if user needs to run multiple 195 evaluations on different data sets. Metrics for different evaluations 196 are saved in separate folders, and appear separately in tensorboard. 197 hooks: Iterable of `tf.train.SessionRunHook` objects to run 198 during evaluation. 199 exporters: Iterable of `Exporter`s, or a single one, or `None`. 200 `exporters` will be invoked after each evaluation. 201 start_delay_secs: Int. Start evaluating after waiting for this many 202 seconds. 203 throttle_secs: Int. Do not re-evaluate unless the last evaluation was 204 started at least this many seconds ago. 
Of course, evaluation does not 205 occur if no new checkpoints are available, hence, this is the minimum. 206 207 Returns: 208 A validated `EvalSpec` object. 209 210 Raises: 211 ValueError: If any of the input arguments is invalid. 212 TypeError: If any of the arguments is not of the expected type. 213 """ 214 # Validate input_fn. 215 _validate_input_fn(input_fn) 216 217 # Validate steps. 218 if steps is not None and steps <= 0: 219 raise ValueError('Must specify steps > 0, given: {}'.format(steps)) 220 221 # Validate name. 222 if name is not None and not isinstance(name, six.string_types): 223 raise TypeError('`name` must be string, given: {}'.format(name)) 224 225 # Validate hooks. 226 hooks = _validate_hooks(hooks) 227 228 # Validate exporters. 229 exporters = _validate_exporters(exporters) 230 231 # Validate start_delay_secs. 232 if start_delay_secs < 0: 233 raise ValueError('Must specify start_delay_secs >= 0, given: {}'.format( 234 start_delay_secs)) 235 236 # Validate throttle_secs. 237 if throttle_secs < 0: 238 raise ValueError( 239 'Must specify throttle_secs >= 0, given: {}'.format(throttle_secs)) 240 241 return super(EvalSpec, cls).__new__( 242 cls, 243 input_fn=input_fn, 244 steps=steps, 245 name=name, 246 hooks=hooks, 247 exporters=exporters, 248 start_delay_secs=start_delay_secs, 249 throttle_secs=throttle_secs) 250 251 252@tf_export('estimator.train_and_evaluate') 253def train_and_evaluate(estimator, train_spec, eval_spec): 254 """Train and evaluate the `estimator`. 255 256 This utility function trains, evaluates, and (optionally) exports the model by 257 using the given `estimator`. All training related specification is held in 258 `train_spec`, including training `input_fn` and training max steps, etc. All 259 evaluation and export related specification is held in `eval_spec`, including 260 evaluation `input_fn`, steps, etc. 
261 262 This utility function provides consistent behavior for both local 263 (non-distributed) and distributed configurations. Currently, the only 264 supported distributed training configuration is between-graph replication. 265 266 Overfitting: In order to avoid overfitting, it is recommended to set up the 267 training `input_fn` to shuffle the training data properly. It is also 268 recommended to train the model a little longer, say multiple epochs, before 269 performing evaluation, as the input pipeline starts from scratch for each 270 training. It is particularly important for local training and evaluation. 271 272 Stop condition: In order to support both distributed and non-distributed 273 configuration reliably, the only supported stop condition for model 274 training is `train_spec.max_steps`. If `train_spec.max_steps` is `None`, the 275 model is trained forever. *Use with care* if model stop condition is 276 different. For example, assume that the model is expected to be trained with 277 one epoch of training data, and the training `input_fn` is configured to throw 278 `OutOfRangeError` after going through one epoch, which stops the 279 `Estimator.train`. For a three-training-worker distributed configuration, each 280 training worker is likely to go through the whole epoch independently. So, the 281 model will be trained with three epochs of training data instead of one epoch. 282 283 Example of local (non-distributed) training: 284 ```python 285 # Set up feature columns. 286 categorial_feature_a = categorial_column_with_hash_bucket(...) 287 categorial_feature_a_emb = embedding_column( 288 categorical_column=categorial_feature_a, ...) 289 ... 
# other feature columns 290 291 estimator = DNNClassifier( 292 feature_columns=[categorial_feature_a_emb, ...], 293 hidden_units=[1024, 512, 256]) 294 295 # Or set up the model directory 296 # estimator = DNNClassifier( 297 # config=tf.estimator.RunConfig( 298 # model_dir='/my_model', save_summary_steps=100), 299 # feature_columns=[categorial_feature_a_emb, ...], 300 # hidden_units=[1024, 512, 256]) 301 302 # Input pipeline for train and evaluate. 303 def train_input_fn: # returns x, y 304 # please shuffle the data. 305 pass 306 def eval_input_fn_eval: # returns x, y 307 pass 308 309 train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000) 310 eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) 311 312 tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 313 ``` 314 315 Example of distributed training: 316 317 Regarding the example of distributed training, the code above can be used 318 without a change (Please do make sure that the `RunConfig.model_dir` for all 319 workers is set to the same directory, i.e., a shared file system all workers 320 can read and write). The only extra work to do is setting the environment 321 variable `TF_CONFIG` properly for each worker correspondingly. 322 323 Also see: https://www.tensorflow.org/deploy/distributed 324 325 Setting environment variable depends on the platform. For example, on Linux, 326 it can be done as follows (`$` is the shell prompt): 327 ``` 328 $ TF_CONFIG='<replace_with_real_content>' python train_model.py 329 ``` 330 331 For the content in `TF_CONFIG`, assume that the training cluster spec looks 332 like: 333 ``` 334 cluster = {"chief": ["host0:2222"], 335 "worker": ["host1:2222", "host2:2222", "host3:2222"], 336 "ps": ["host4:2222", "host5:2222"]} 337 ``` 338 339 Example of `TF_CONFIG` for chief training worker (must have one and only one): 340 ``` 341 # This should be a JSON string, which is set as environment variable. Usually 342 # the cluster manager handles that. 
343 TF_CONFIG='{ 344 "cluster": { 345 "chief": ["host0:2222"], 346 "worker": ["host1:2222", "host2:2222", "host3:2222"], 347 "ps": ["host4:2222", "host5:2222"] 348 }, 349 "task": {"type": "chief", "index": 0} 350 }' 351 ``` 352 Note that the chief worker also does the model training job, similar to other 353 non-chief training workers (see next paragraph). In addition to the model 354 training, it manages some extra work, e.g., checkpoint saving and restoring, 355 writing summaries, etc. 356 357 Example of `TF_CONFIG` for non-chief training worker (optional, could be 358 multiple): 359 ``` 360 # This should be a JSON string, which is set as environment variable. Usually 361 # the cluster manager handles that. 362 TF_CONFIG='{ 363 "cluster": { 364 "chief": ["host0:2222"], 365 "worker": ["host1:2222", "host2:2222", "host3:2222"], 366 "ps": ["host4:2222", "host5:2222"] 367 }, 368 "task": {"type": "worker", "index": 0} 369 }' 370 ``` 371 where the `task.index` should be set as 0, 1, 2, in this example, respectively 372 for non-chief training workers. 373 374 Example of `TF_CONFIG` for parameter server, aka ps (could be multiple): 375 ``` 376 # This should be a JSON string, which is set as environment variable. Usually 377 # the cluster manager handles that. 378 TF_CONFIG='{ 379 "cluster": { 380 "chief": ["host0:2222"], 381 "worker": ["host1:2222", "host2:2222", "host3:2222"], 382 "ps": ["host4:2222", "host5:2222"] 383 }, 384 "task": {"type": "ps", "index": 0} 385 }' 386 ``` 387 where the `task.index` should be set as 0 and 1, in this example, respectively 388 for parameter servers. 389 390 Example of `TF_CONFIG` for evaluator task. Evaluator is a special task that is 391 not part of the training cluster. There could be only one. It is used for 392 model evaluation. 393 ``` 394 # This should be a JSON string, which is set as environment variable. Usually 395 # the cluster manager handles that. 
396 TF_CONFIG='{ 397 "cluster": { 398 "chief": ["host0:2222"], 399 "worker": ["host1:2222", "host2:2222", "host3:2222"], 400 "ps": ["host4:2222", "host5:2222"] 401 }, 402 "task": {"type": "evaluator", "index": 0} 403 }' 404 ``` 405 406 Args: 407 estimator: An `Estimator` instance to train and evaluate. 408 train_spec: A `TrainSpec` instance to specify the training specification. 409 eval_spec: A `EvalSpec` instance to specify the evaluation and export 410 specification. 411 412 Raises: 413 ValueError: if environment variable `TF_CONFIG` is incorrectly set. 414 """ 415 executor = _TrainingExecutor( 416 estimator=estimator, train_spec=train_spec, eval_spec=eval_spec) 417 418 config = estimator.config 419 if (config.task_type == run_config_lib.TaskType.EVALUATOR and 420 config.task_id > 0): 421 raise ValueError( 422 'For distributed training, there can only be one `evaluator` task ' 423 '(with task id 0). Given task id {}'.format(config.task_id)) 424 425 executor.run() 426 427 428class _StopAtSecsHook(session_run_hook.SessionRunHook): 429 """Stops given secs after begin is called.""" 430 431 def __init__(self, stop_after_secs): 432 self._stop_after_secs = stop_after_secs 433 self._start_time = None 434 435 def begin(self): 436 self._start_time = time.time() 437 438 def after_run(self, run_context, run_values): 439 del run_values 440 if time.time() - self._start_time >= self._stop_after_secs: 441 run_context.request_stop() 442 443 444class _TrainingExecutor(object): 445 """The executor to run `Estimator` training and evaluation. 446 447 This implementation supports both distributed and non-distributed (aka local) 448 training and evaluation based on the setting in `tf.estimator.RunConfig`. 
449 """ 450 451 def __init__(self, 452 estimator, 453 train_spec, 454 eval_spec, 455 train_hooks=None, 456 continuous_eval_listener=None): 457 if not isinstance(estimator, estimator_lib.Estimator): 458 raise TypeError('`estimator` must have type `tf.estimator.Estimator`.') 459 self._estimator = estimator 460 461 if not isinstance(train_spec, TrainSpec): 462 raise TypeError('`train_spec` must have type `tf.estimator.TrainSpec`.') 463 self._train_spec = train_spec 464 465 if not isinstance(eval_spec, EvalSpec): 466 raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`.') 467 self._eval_spec = eval_spec 468 469 self._train_hooks = _validate_hooks(train_hooks) 470 471 if (continuous_eval_listener and 472 not isinstance(continuous_eval_listener, _ContinuousEvalListener)): 473 raise TypeError('`continuous_eval_listener` must have type ' 474 '`_ContinuousEvalListener`.') 475 self._continuous_eval_listener = ( 476 continuous_eval_listener or _ContinuousEvalListener()) 477 478 @property 479 def estimator(self): 480 return self._estimator 481 482 def run(self): 483 """Executes the run_foo for task type `foo`. 484 485 `_TrainingExecutor` predefines the procedure for task type 'chief', 486 'worker', 'ps', and 'evaluator'. For task type `foo`, the corresponding 487 procedure is `run_foo'. This `run` method invoke the procedure base on the 488 `RunConfig.task_type`. 489 490 Raises: 491 ValueError: if the estimator.config is mis-configured. 492 """ 493 config = self._estimator.config 494 495 if (not config.cluster_spec and 496 config.task_type != run_config_lib.TaskType.EVALUATOR): 497 logging.info('Running training and evaluation locally (non-distributed).') 498 self.run_local() 499 return 500 501 # Distributed case. 502 if not config.task_type: 503 # TODO(xiejw): Improve the error message about how to set the TF_CONFIG 504 # correctly. 505 raise ValueError( 506 '`estimator.config` must have task_type set. 
This usually means ' 507 'TF_CONFIG environment is not set correctly.') 508 509 if config.task_type == 'local': 510 raise ValueError( 511 '`task.type` in TF_CONFIG cannot be `local`. Leaving `cluster` and ' 512 '`task` properties in TF_CONFIG absent triggers train and evaluate ' 513 '`Estimator` locally (non-distributed).') 514 515 # For task type foo, call executor.run_foo. 516 available_tasks = [ 517 x for x in dir(self) 518 if x.startswith('run_') and x != 'run_local' and 519 callable(getattr(self, x)) 520 ] 521 task_to_run = 'run_' + config.task_type 522 if task_to_run not in available_tasks: 523 raise ValueError( 524 'Task type {} is not supported. Supported task types are {}'.format( 525 config.task_type, [x[len('run_'):] for x in available_tasks])) 526 getattr(self, task_to_run)() 527 528 def run_chief(self): 529 """Runs task chief.""" 530 # TODO(xiejw): To allow execution framework to add train hooks. 531 return self._start_distributed_training() 532 533 def run_worker(self): 534 """Runs task (training) worker.""" 535 # TODO(xiejw): To allow execution framework to add train hooks. 536 return self._start_distributed_training() 537 538 def run_master(self): 539 """Runs task master.""" 540 541 class NewCheckpointListener( 542 basic_session_run_hooks.CheckpointSaverListener): 543 544 def __init__(self, evaluator, eval_throttle_secs): 545 self._evaluator = evaluator 546 self._eval_throttle_secs = eval_throttle_secs 547 548 def begin(self): 549 self._timer = basic_session_run_hooks.SecondOrStepTimer( 550 every_secs=self._eval_throttle_secs) 551 552 def after_save(self, session, global_step_value): 553 del session # unused; required by signature. 
554 555 if self._timer.should_trigger_for_step(global_step_value): 556 self._timer.update_last_triggered_step(global_step_value) 557 self._evaluator.evaluate_and_export() 558 else: 559 logging.info('Skip the current checkpoint eval due to throttle secs ' 560 '({} secs).'.format(self._eval_throttle_secs)) 561 562 # Final export signal: For any eval result with global_step >= train 563 # max_steps, the evaluator will send the final export signal. There is a 564 # small chance that the Estimator.train stopping logic sees a different 565 # global_step value (due to global step race condition and the fact the 566 # saver sees a larger value for checkpoing saving), which does not end 567 # the training. When the training ends, a new checkpoint is generated, which 568 # triggers the listener again. So, it could be the case the final export is 569 # triggered twice. 570 # 571 # But here, throttle_secs will skip the next intermediate checkpoint and, 572 # so, the double final export chance is very small. 573 evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec, 574 self._train_spec.max_steps) 575 576 # When the underlying `Estimator` object saves a new checkpoint, we would 577 # like this callback to be called so that evaluation and export can trigger. 578 saving_listeners = [ 579 NewCheckpointListener(evaluator, self._eval_spec.throttle_secs) 580 ] 581 self._start_distributed_training(saving_listeners=saving_listeners) 582 583 if not evaluator.is_final_export_triggered: 584 logging.info('Training has already ended. But the last eval is skipped ' 585 'due to eval throttle_secs. Now evaluating the final ' 586 'checkpoint.') 587 evaluator.evaluate_and_export() 588 589 def run_evaluator(self): 590 """Runs task evaluator.""" 591 # TODO(xiejw): To allow execution framework to add continuous eval listener. 
592 return self._start_continuous_evaluation() 593 594 def run_ps(self): 595 """Runs task parameter server (in training cluster spec).""" 596 config = self._estimator.config 597 server = self._start_std_server(config) 598 server.join() 599 600 def run_local(self): 601 """Runs training and evaluation locally (non-distributed).""" 602 603 def _should_stop_local_train(global_step): 604 if self._train_spec.max_steps is None: 605 return False 606 if global_step >= self._train_spec.max_steps: 607 return True 608 return False 609 610 if self._eval_spec.throttle_secs <= 0: 611 raise ValueError('eval_spec.throttle_secs should be positive, given: {}.' 612 'It is used do determine how long each training ' 613 'iteration should go when train and evaluate ' 614 'locally.'.format(self._eval_spec.throttle_secs)) 615 616 stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs) 617 train_hooks = ( 618 list(self._train_spec.hooks) + [stop_hook] + list(self._train_hooks)) 619 logging.info('Start train and evaluate loop. The evaluate will happen ' 620 'after {} secs (eval_spec.throttle_secs) or training is ' 621 'finished.'.format(self._eval_spec.throttle_secs)) 622 623 evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec, 624 self._train_spec.max_steps) 625 626 while True: 627 self._estimator.train( 628 input_fn=self._train_spec.input_fn, 629 max_steps=self._train_spec.max_steps, 630 hooks=train_hooks) 631 632 # Final export signal: For any eval result with global_step >= train 633 # max_steps, the evaluator will send the final export signal. The 634 # _should_stop_local_train will then end the while True as the stopping 635 # condition is satisfied (both checks use the same global_step value, 636 # i.e., no race condition) 637 eval_result = evaluator.evaluate_and_export() 638 639 if eval_result.status != _EvalStatus.EVALUATED: 640 # This is unexpected; should never happen. 641 # Training should always end with a new checkpoint. 
642 raise RuntimeError('There was no new checkpoint after the training. ' 643 'Eval status: {}'.format(eval_result.status)) 644 645 if _should_stop_local_train( 646 eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]): 647 break 648 649 def _start_std_server(self, config): 650 """Creates, starts, and returns a server_lib.Server.""" 651 if (not config.cluster_spec or not config.task_type or 652 config.task_id is None): 653 raise RuntimeError('Could not start server; be sure to specify ' 654 'cluster_spec, task_type, and task in ' 655 'RunConfig or set the TF_CONFIG environment variable.') 656 657 if not config.master: 658 jobs = config.cluster_spec.jobs 659 if (len(jobs) == 1 and 660 len(config.cluster_spec.job_tasks(jobs[0])) == 1 and 661 config.task_type in _TRAINER_JOBS): 662 # For distributed training, config.master is empty if and only if it has 663 # a single node in the cluster spec. In this case, we should not start 664 # the server. 665 logging.info('Skip starting Tensorflow server as there is only one ' 666 'node in the cluster.') 667 return 668 else: 669 raise RuntimeError( 670 'Could not start server; be sure to specify master in ' 671 'RunConfig or set the TF_CONFIG environment variable.') 672 673 logging.info('Start Tensorflow server.') 674 675 if config.session_config is None: 676 session_config = config_pb2.ConfigProto(log_device_placement=False) 677 else: 678 session_config = config_pb2.ConfigProto( 679 log_device_placement=False, 680 gpu_options=config.session_config.gpu_options) 681 682 server = server_lib.Server( 683 config.cluster_spec, 684 job_name=config.task_type, 685 task_index=config.task_id, 686 config=session_config, 687 start=False) 688 server.start() 689 return server 690 691 def _start_distributed_training(self, saving_listeners=None): 692 """Calls `Estimator` train in a distributed setting.""" 693 config = self._estimator.config 694 695 # Start in-process TensorFlow server if needed. 
It's important to start the 696 # server before we (optionally) sleep. Otherwise, the servers will wait to 697 # connect to each other before starting to train. 698 if not _is_google_env(): 699 self._start_std_server(config) 700 701 # Delay worker to start. For asynchronous training, this usually helps model 702 # to converge faster. Chief starts the training immediately, so, worker 703 # with task id x (0-based) should wait (x+1) * _DELAY_SECS_PER_WORKER. 704 start_delay_secs = 0 705 if config.task_type == run_config_lib.TaskType.WORKER: 706 # TODO(xiejw): Replace the hard code logic (task_id + 1) with unique id in 707 # training cluster. 708 start_delay_secs = min(_MAX_DELAY_SECS, 709 (config.task_id + 1) * _DELAY_SECS_PER_WORKER) 710 if start_delay_secs > 0: 711 logging.info('Waiting %d secs before starting training.', 712 start_delay_secs) 713 time.sleep(start_delay_secs) 714 715 self._estimator.train( 716 input_fn=self._train_spec.input_fn, 717 max_steps=self._train_spec.max_steps, 718 hooks=list(self._train_spec.hooks) + list(self._train_hooks), 719 saving_listeners=saving_listeners) 720 721 def _start_continuous_evaluation(self): 722 """Repeatedly calls `Estimator` evaluate and export until training ends.""" 723 start_delay_secs = self._eval_spec.start_delay_secs 724 if start_delay_secs: 725 logging.info('Waiting %f secs before starting eval.', start_delay_secs) 726 time.sleep(start_delay_secs) 727 728 latest_eval_result = None 729 evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec, 730 self._train_spec.max_steps) 731 732 should_early_stop = False 733 while not should_early_stop: 734 if (latest_eval_result and 735 latest_eval_result.status == _EvalStatus.EVALUATED): 736 global_step = latest_eval_result.metrics.get(ops.GraphKeys.GLOBAL_STEP) 737 if (global_step and self._train_spec.max_steps and 738 global_step >= self._train_spec.max_steps): 739 logging.info( 740 'Exiting evaluation, global_step=%s >= train max_steps=%s', 741 
global_step, self._train_spec.max_steps) 742 return 743 744 latest_eval_result, should_early_stop = self._execute_evaluator_once( 745 evaluator, self._continuous_eval_listener, 746 self._eval_spec.throttle_secs) 747 748 def _execute_evaluator_once(self, evaluator, continuous_eval_listener, 749 throttle_secs): 750 """Executes the `evaluator`.""" 751 start = time.time() 752 753 eval_result = None 754 should_early_stop = False 755 756 if not continuous_eval_listener.before_eval(): 757 logging.info('Exiting evaluation, as requested by ' 758 '_ContinuousEvalListener.before_eval.') 759 should_early_stop = True 760 return (eval_result, should_early_stop) 761 762 # Final export signal: For any eval result with global_step >= train 763 # max_steps, the evaluator will send the final export signal. The next 764 # iteration of while loop will end the continuous eval as the stopping 765 # condition is satisfied (both checks use the same global_step value, 766 # i.e., no race condition) 767 eval_result = evaluator.evaluate_and_export() 768 769 if not self._continuous_eval_listener.after_eval(eval_result): 770 logging.info('Exiting evaluation, as requested by ' 771 '_ContinuousEvalListener.after_eval.') 772 should_early_stop = True 773 return (eval_result, should_early_stop) 774 775 # Throttle if necessary. 
# Throttle: sleep for whatever remains of `throttle_secs` since `start`
# (presumably the timestamp taken when this evaluation pass began) so that
# evaluation runs are not started more often than every `throttle_secs`.
# `start`, `throttle_secs`, `eval_result` and `should_early_stop` are bound
# earlier in the enclosing method.
    elapsed_time = time.time() - start
    difference = throttle_secs - elapsed_time
    if difference > 0:
      logging.info('Waiting %f secs before starting next eval run.', difference)
      time.sleep(difference)

    return (eval_result, should_early_stop)

  class _Evaluator(object):
    """A helper class to call `Estimator.evaluate` and export model.

    Each call to `evaluate_and_export` evaluates the estimator's latest
    checkpoint. It skips the checkpoint when none exists yet or when it is the
    same one evaluated by the previous successful call. It then runs every
    exporter listed in the eval spec on the evaluation result.
    """

    def __init__(self, estimator, eval_spec, max_training_steps):
      """Initializes the evaluator.

      Args:
        estimator: The `Estimator` to evaluate. Its `latest_checkpoint`,
          `evaluate` and `model_dir` are used.
        eval_spec: The `EvalSpec` supplying `input_fn`, `steps`, `name`,
          `hooks` and `exporters`.
        max_training_steps: Global step at or beyond which an evaluation is
          treated as the final export. A falsy value (e.g. `None`) disables
          final-export detection.
      """
      self._estimator = estimator
      self._eval_spec = eval_spec
      # Flips to True once an evaluated global step reaches
      # `max_training_steps`; exposed via `is_final_export_triggered`.
      self._is_final_export_triggered = False
      # Checkpoint evaluated by the last successful `evaluate_and_export`
      # call. Used to avoid re-evaluating an unchanged checkpoint.
      self._previous_ckpt_path = None
      # `time.time()` of the last warning emitted by `_log_err_msg`. 0 means
      # "no recent warning", so the next one is logged immediately.
      self._last_warning_time = 0
      self._max_training_steps = max_training_steps

    @property
    def is_final_export_triggered(self):
      """Whether an evaluation has reached `max_training_steps` (final export)."""
      return self._is_final_export_triggered

    def evaluate_and_export(self):
      """Evaluate and (maybe) export the current model.

      Evaluation is skipped, with a rate-limited warning, in two cases:
        - the estimator has no checkpoint yet (status `MISSING_CHECKPOINT`);
        - the latest checkpoint is the one evaluated by the previous
          successful call (status `NO_NEW_CHECKPOINT`).
      In both cases the returned result carries only a status.

      Otherwise the latest checkpoint is evaluated and every exporter in the
      eval spec is invoked. The export is flagged as final when the evaluated
      global step is at least `max_training_steps`.

      Returns:
        An `EvalResult` instance.

      Raises:
        RuntimeError: for any unexpected internal error.
        TypeError: if evaluation result has wrong type.
        ValueError: if the evaluation metrics are empty or lack the global
          step (raised by `_EvalResult` validation).
      """
      latest_ckpt_path = self._estimator.latest_checkpoint()
      if not latest_ckpt_path:
        self._log_err_msg('Estimator is not trained yet. Will start an '
                          'evaluation when a checkpoint is ready.')
        return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT)

      if latest_ckpt_path == self._previous_ckpt_path:
        self._log_err_msg(
            'No new checkpoint ready for evaluation. Skip the current '
            'evaluation pass as evaluation results are expected to be same '
            'for the same checkpoint.')
        return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT)

      metrics = self._estimator.evaluate(
          input_fn=self._eval_spec.input_fn,
          steps=self._eval_spec.steps,
          name=self._eval_spec.name,
          checkpoint_path=latest_ckpt_path,
          hooks=self._eval_spec.hooks)

      # _EvalResult validates the metrics.
      eval_result = _EvalResult(
          status=_EvalStatus.EVALUATED,
          metrics=metrics,
          checkpoint_path=latest_ckpt_path)

      # Parses as `(step >= max_steps) if max_steps else False`. Without a
      # configured step limit, no export is ever treated as final.
      is_the_final_export = (
          eval_result.metrics[ops.GraphKeys.GLOBAL_STEP] >=
          self._max_training_steps if self._max_training_steps else False)
      self._export_eval_result(eval_result, is_the_final_export)

      if is_the_final_export:
        # NOTE(review): this message is logged after the exporters have
        # already been called above.
        logging.debug('Calling exporter with the `is_the_final_export=True`.')
        self._is_final_export_triggered = True

      # Reset the warning throttle so the next "not ready" warning is logged
      # right away. Remember this checkpoint so it is not evaluated again.
      self._last_warning_time = 0
      self._previous_ckpt_path = latest_ckpt_path
      return eval_result

    def _log_err_msg(self, message):
      """Logs warning `message`, at most once every 10 minutes.

      A message arriving within 600 seconds of the last emitted warning is
      dropped. The throttle is shared by all messages.

      Args:
        message: The warning text to log.
      """
      current_time = time.time()
      if current_time - self._last_warning_time > 600:
        logging.warning(message)
        self._last_warning_time = current_time

    def _export_eval_result(self, eval_result, is_the_final_export):
      """Export `eval_result` according to exporters in `EvalSpec`.

      Each exporter is given the path `<model_dir>/export/<exporter.name>`.

      Args:
        eval_result: An `_EvalResult` with status `EVALUATED`. Its
          `checkpoint_path` and `metrics` are passed to each exporter.
        is_the_final_export: Forwarded unchanged to each exporter's `export`.
      """
      export_dir_base = os.path.join(
          compat.as_str_any(self._estimator.model_dir),
          compat.as_str_any('export'))

      for exporter in self._eval_spec.exporters:
        exporter.export(
            estimator=self._estimator,
            export_path=os.path.join(
                compat.as_str_any(export_dir_base),
                compat.as_str_any(exporter.name)),
            checkpoint_path=eval_result.checkpoint_path,
            eval_result=eval_result.metrics,
            is_the_final_export=is_the_final_export)


class _EvalStatus(object):
  """The status of an evaluation event.

  For local training and evaluation, the status can only be `EVALUATED` as
  `Estimator.train` always generates a new checkpoint.

  For distributed training and evaluation, a separate evaluator keeps looking
  for new checkpoints, so multiple situations might occur:

  - EVALUATED: A new checkpoint has been found since the last evaluation.
      `Estimator.evaluate` will be invoked.
  - MISSING_CHECKPOINT: No checkpoint can be found. Typically, this means
      the trainer has not yet produced any checkpoint.
  - NO_NEW_CHECKPOINT: No new checkpoint can be found since the last
      evaluation. Typically, this means the trainer has not yet produced any
      new checkpoint.
  """

  EVALUATED = 'evaluated'
  MISSING_CHECKPOINT = 'missing checkpoint'
  NO_NEW_CHECKPOINT = 'no new checkpoint'


class _EvalResult(
    collections.namedtuple('EvalResult',
                           ['status', 'metrics', 'checkpoint_path'])):
  """_EvalResult holds the result of an evaluation event.

  A namedtuple with fields `status`, `metrics` and `checkpoint_path`.
  Construction through `__new__` checks that `metrics` and `checkpoint_path`
  are consistent with `status`.
  """

  def __new__(cls, status, metrics=None, checkpoint_path=None):
    """Creates a validated `_EvalResult`.

    Args:
      status: See `_EvalStatus`.
      metrics: The evaluation results returned by `Estimator.evaluate`. Only set
        if status is `EVALUATED`.
      checkpoint_path: The corresponding checkpoint path for the `metrics`. Only
        set if status is `EVALUATED`.
    Returns:
      A validated `_EvalResult` object.

    Raises:
      ValueError: If validation fails.
      TypeError: If any of the arguments is not the expected type.
    """

    # NOTE(review): any status other than `EVALUATED` takes this branch,
    # including values not defined in `_EvalStatus`; the status itself is
    # never checked against the known constants.
    if status != _EvalStatus.EVALUATED:
      # Truthiness checks: empty values (e.g. `{}` or `''`) are accepted here
      # as well as `None`.
      if metrics:
        raise ValueError(
            'metrics must be `None` if status is not {}; got status {},'
            ' metrics {}'.format(_EvalStatus.EVALUATED, status, metrics))
      if checkpoint_path:
        raise ValueError(
            'checkpoint must be `None` if status is not {}; got status {}, '
            'checkpoint_path {}'.format(_EvalStatus.EVALUATED, status,
                                        checkpoint_path))
      return super(_EvalResult, cls).__new__(cls, status, metrics,
                                             checkpoint_path)

    # Now, evaluated case.
    assert status == _EvalStatus.EVALUATED

    # Validates metrics.
    # An empty result raises ValueError before the dict type check runs.
    if not metrics:
      raise ValueError(
          'Internal error: `Estimator.evaluate` should never return empty '
          'metrics.')
    if not isinstance(metrics, dict):
      raise TypeError(
          '`Estimator.evaluate` should return dict. Given {}.'.format(
              type(metrics)))
    # The global step is required: `_Evaluator` reads it to detect the
    # final export.
    if ops.GraphKeys.GLOBAL_STEP not in metrics:
      raise ValueError(
          'Internal error: `Estimator.evaluate` result should have '
          '`global_step` in result. Given {}'.format(metrics))

    # Validates checkpoint_path.
    if not checkpoint_path:
      raise ValueError(
          'Internal error: `checkpoint_path` should never be empty.')

    return super(_EvalResult, cls).__new__(cls, status, metrics,
                                           checkpoint_path)


class _ContinuousEvalListener(object):
  """Interface for listeners that take action before or after evaluation.

  The default implementation never skips an evaluation and never requests an
  early stop. Subclasses override `before_eval` and/or `after_eval`.
  """

  def before_eval(self):
    """Called before evaluation.

    Returns:
      `False` if you want to skip the current evaluation and early stop the
      continuous evaluation; `True` otherwise.
    """
    return True

  def after_eval(self, eval_result):
    """Called after the evaluation is executed.

    Args:
      eval_result: An `_EvalResult` instance.

    Returns:
      `False` if you want to early stop continuous evaluation; `True` otherwise.
    """
    # The default listener ignores the result.
    del eval_result
    return True