# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains functions for evaluation and summarization of metrics.

The evaluation.py module contains helper functions for evaluating TensorFlow
models using a variety of metrics and summarizing the results.

****************************************
* Evaluating a Checkpointed Model Once *
****************************************

Once we've trained a model, we'll want to evaluate it. The simplest way to do
this is to evaluate the performance of a saved model a single time. In order
to do this, we can specify a number of metrics we'll want to evaluate as well
as specify the summaries we want to save to disk. Furthermore, we can print
out the metrics values to stdout:

  # Specify where the checkpoint is stored:
  checkpoint_path = ...

  # Create model and obtain the predictions:
  images, labels = LoadData(...)
  predictions = MyModel(images)

  # Choose the metrics to compute:
  names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map({
      "accuracy": tf.metrics.accuracy(labels, predictions),
      "mse": tf.metrics.mean_squared_error(labels, predictions),
  })

  # Define the summaries to write:
  for metric_name, metric_value in names_to_values.items():
    tf.summary.scalar(metric_name, metric_value)

  checkpoint_dir = '/tmp/my_model_dir/'
  log_dir = '/tmp/my_model_eval/'

  # We'll evaluate 1000 batches:
  num_evals = 1000

  names_to_values = evaluate_once(
      checkpoint_path=checkpoint_path,
      eval_ops=names_to_updates.values(),
      final_ops=names_to_values,
      hooks=[
          tf.contrib.training.StopAfterNEvalsHook(num_evals),
          tf.contrib.training.SummaryAtEndHook(log_dir),
      ],
      config=None)

  for name in names_to_values:
    print('Metric %s has value %f.' % (name, names_to_values[name]))


************************************************
* Evaluating a Checkpointed Model with Metrics *
************************************************

Often, one wants to evaluate a model checkpoint saved on disk. This can be
performed once or repeatedly on a set schedule.

To evaluate a particular model, users define zero or more metrics and zero or
more summaries and call the evaluate_repeatedly method:

  # Create model and obtain the predictions:
  images, labels = LoadData(...)
  predictions = MyModel(images)

  # Choose the metrics to compute:
  names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map({
      "accuracy": tf.metrics.accuracy(labels, predictions),
      "mse": tf.metrics.mean_squared_error(labels, predictions),
  })

  # Define the summaries to write:
  for metric_name, metric_value in names_to_values.items():
    tf.summary.scalar(metric_name, metric_value)

  checkpoint_dir = '/tmp/my_model_dir/'
  log_dir = '/tmp/my_model_eval/'

  # We'll evaluate 1000 batches:
  num_evals = 1000

  # Evaluate every 10 minutes:
  tf.contrib.training.evaluate_repeatedly(
      checkpoint_dir,
      eval_ops=names_to_updates.values(),
      hooks=[
          tf.contrib.training.StopAfterNEvalsHook(num_evals),
          tf.contrib.training.SummaryAtEndHook(log_dir),
      ],
      eval_interval_secs=600)

*******************************************************
* Evaluating a Checkpointed Model with Summaries Only *
*******************************************************

At times, an evaluation can be performed without metrics at all but rather
with only summaries. The user need only leave out the `eval_ops` argument:

  # Create model and obtain the predictions:
  images, labels = LoadData(...)
  predictions = MyModel(images)

  # Define the summaries to write:
  tf.summary.scalar(...)
  tf.summary.histogram(...)

  checkpoint_dir = '/tmp/my_model_dir/'
  log_dir = '/tmp/my_model_eval/'

  # Evaluate once every 10 minutes.
  tf.contrib.training.evaluate_repeatedly(
      checkpoint_dir,
      hooks=[
          tf.contrib.training.SummaryAtEndHook(log_dir),
      ],
      eval_interval_secs=600)

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import evaluation
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util

__all__ = [
    'StopAfterNEvalsHook',
    'SummaryAtEndHook',
    'checkpoints_iterator',
    'evaluate_once',
    'evaluate_repeatedly',
    'get_or_create_eval_step',
    'wait_for_new_checkpoint',
]

# pylint: disable=protected-access
# pylint: disable=invalid-name
StopAfterNEvalsHook = evaluation._StopAfterNEvalsHook
evaluate_once = evaluation._evaluate_once
get_or_create_eval_step = evaluation._get_or_create_eval_step

# pylint: enable=invalid-name
# pylint: enable=protected-access


def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Waits until a new checkpoint file is found.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or `None` if we're expecting
      a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new checkpoint.
    timeout: The maximum amount of time to wait. If left as `None`, then the
      process will wait indefinitely.

  Returns:
    A new checkpoint path, or `None` if the timeout was reached.
  """
  logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
  stop_time = time.time() + timeout if timeout is not None else None
  while True:
    checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None or checkpoint_path == last_checkpoint:
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        return None
      time.sleep(seconds_to_sleep)
    else:
      logging.info('Found new checkpoint at %s', checkpoint_path)
      return checkpoint_path


def checkpoints_iterator(checkpoint_dir,
                         min_interval_secs=0,
                         timeout=None,
                         timeout_fn=None):
  """Continuously yield new checkpoint files as they appear.

  The iterator only checks for new checkpoints when control flow has been
  reverted to it. This means it can miss checkpoints if your code takes longer
  to run between iterations than `min_interval_secs` or the interval at which
  new checkpoints are written.

  The `timeout` argument is the maximum number of seconds to block waiting for
  a new checkpoint. It is used in combination with the `timeout_fn` as
  follows:

  * If the timeout expires and no `timeout_fn` was specified, the iterator
    stops yielding.
  * If a `timeout_fn` was specified, that function is called and if it returns
    a true boolean value the iterator stops yielding.
  * If the function returns a false boolean value then the iterator resumes the
    wait for new checkpoints. At this point the timeout logic applies again.

  This behavior gives control to callers on what to do if checkpoints do not
  come fast enough or stop being generated. For example, if callers have a way
  to detect that the training has stopped and know that no new checkpoints
  will be generated, they can provide a `timeout_fn` that returns `True` when
  the training has stopped. If they know that the training is still going on
  they return `False` instead.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    min_interval_secs: The minimum number of seconds between yielding
      checkpoints.
    timeout: The maximum amount of time to wait between checkpoints. If left as
      `None`, then the process will wait indefinitely.
    timeout_fn: Optional function to call after a timeout. If the function
      returns True, then it means that no new checkpoints will be generated and
      the iterator will exit. The function is called with no arguments.

  Yields:
    String paths to latest checkpoint files as they arrive.
  """
  checkpoint_path = None
  while True:
    new_checkpoint_path = wait_for_new_checkpoint(
        checkpoint_dir, checkpoint_path, timeout=timeout)
    if new_checkpoint_path is None:
      if not timeout_fn:
        # timed out
        logging.info('Timed-out waiting for a checkpoint.')
        return
      if timeout_fn():
        # The timeout_fn indicated that we are truly done.
        return
      else:
        # The timeout_fn indicated that more checkpoints may come.
        continue
    start = time.time()
    checkpoint_path = new_checkpoint_path
    yield checkpoint_path
    time_to_next_eval = start + min_interval_secs - time.time()
    if time_to_next_eval > 0:
      time.sleep(time_to_next_eval)


class SummaryAtEndHook(session_run_hook.SessionRunHook):
  """A run hook that saves a summary with the results of evaluation."""

  def __init__(self,
               log_dir=None,
               summary_writer=None,
               summary_op=None,
               feed_dict=None):
    """Constructs the Summary Hook.

    Args:
      log_dir: The directory where the summary events are saved to. Used only
        when `summary_writer` is not specified.
      summary_writer: A `tf.summary.FileWriter` to write summary events with.
      summary_op: The summary op to run. If left as `None`, then all summaries
        in the tf.GraphKeys.SUMMARIES collection are used.
      feed_dict: An optional feed dictionary to use when evaluating the
        summaries.

    Raises:
      ValueError: If both `log_dir` and `summary_writer` are `None`.
    """
    self._summary_op = summary_op
    self._replace_summary_op = summary_op is None
    self._feed_dict = feed_dict
    self._summary_writer = summary_writer
    self._log_dir = log_dir
    if self._log_dir is None and self._summary_writer is None:
      raise ValueError('One of log_dir or summary_writer should be used.')

  def begin(self):
    if self._replace_summary_op:
      # This can still remain None if there are no summaries.
      self._summary_op = summary.merge_all()
    self._global_step = training_util.get_or_create_global_step()

  def after_create_session(self, session, coord):
    if self._summary_writer is None and self._log_dir:
      self._summary_writer = summary.FileWriterCache.get(self._log_dir)

  def end(self, session):
    if self._summary_op is not None:
      global_step = training_util.global_step(session, self._global_step)
      summary_str = session.run(self._summary_op, self._feed_dict)
      if self._summary_writer:
        self._summary_writer.add_summary(summary_str, global_step)
    if self._summary_writer:
      self._summary_writer.flush()


def _scaffold_with_init(scaffold, saver, checkpoint_path):
  """Creates a scaffold that loads the given checkpoint using an init_fn.

  Args:
    scaffold: The scaffold to copy.
    saver: The saver to use when restoring the checkpoint.
    checkpoint_path: An absolute path to a checkpoint.

  Returns:
    A scaffold with an init_fn that loads the given checkpoint. If the scaffold
    provided already has an init_fn, the scaffold is returned unchanged.
  """

  def restore_checkpoint(_, session):
    saver.restore(session, checkpoint_path)

  if not scaffold.init_fn:
    scaffold = monitored_session.Scaffold(
        init_op=scaffold.init_op,
        init_feed_dict=scaffold.init_feed_dict,
        init_fn=restore_checkpoint,
        ready_op=scaffold.ready_op,
        local_init_op=scaffold.local_init_op,
        summary_op=scaffold.summary_op,
        saver=scaffold.saver)
  return scaffold


def evaluate_repeatedly(checkpoint_dir,
                        master='',
                        scaffold=None,
                        eval_ops=None,
                        feed_dict=None,
                        final_ops=None,
                        final_ops_feed_dict=None,
                        eval_interval_secs=60,
                        hooks=None,
                        config=None,
                        max_number_of_evaluations=None,
                        timeout=None,
                        timeout_fn=None):
  """Repeatedly searches for a checkpoint in `checkpoint_dir` and evaluates it.

  During a single evaluation, the `eval_ops` are run until the session is
  interrupted or requested to finish. This is typically requested via a
  `tf.contrib.training.StopAfterNEvalsHook` which results in `eval_ops` running
  the requested number of times.

  Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of
  `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is
  evaluated a single time after `eval_ops` has finished running and the fetched
  values of `final_ops` are returned. If `final_ops` is left as `None`, then
  `None` is returned.

  One may also consider using a `tf.contrib.training.SummaryAtEndHook` to
  record summaries after the `eval_ops` have run. If `eval_ops` is `None`, the
  summaries run immediately after the model checkpoint has been restored.

  Note that `evaluate_repeatedly` creates a local variable used to track the
  number of evaluations run via `tf.contrib.training.get_or_create_eval_step`.
  Consequently, if a custom local init op is provided via a `scaffold`, the
  caller should ensure that the local init op also initializes the eval step.

  Args:
    checkpoint_dir: The directory where checkpoints are stored.
    master: The address of the TensorFlow master.
    scaffold: A `tf.train.Scaffold` instance for initializing variables and
      restoring variables. Note that `scaffold.init_fn` is used by the function
      to restore the checkpoint. If you supply a custom init_fn, then it must
      also take care of restoring the model from its checkpoint.
    eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
      to `Tensors`, which is run until the session is requested to stop,
      commonly done by a `tf.contrib.training.StopAfterNEvalsHook`.
    feed_dict: The feed dictionary to use when executing the `eval_ops`.
    final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
      to `Tensors`.
    final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
    eval_interval_secs: The minimum number of seconds between evaluations.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      evaluation loop.
    config: An instance of `tf.ConfigProto` that will be used to configure the
      `Session`. If left as `None`, the default will be used.
    max_number_of_evaluations: The maximum number of times to run the
      evaluation. If left as `None`, then evaluation runs indefinitely.
    timeout: The maximum amount of time to wait between checkpoints. If left as
      `None`, then the process will wait indefinitely.
    timeout_fn: Optional function to call after a timeout. If the function
      returns True, then it means that no new checkpoints will be generated and
      the iterator will exit. The function is called with no arguments.

  Returns:
    The fetched values of `final_ops` or `None` if `final_ops` is `None`.
  """
  eval_step = get_or_create_eval_step()

  # Prepare the run hooks.
  hooks = hooks or []

  if eval_ops is not None:
    update_eval_step = state_ops.assign_add(eval_step, 1)

    for h in hooks:
      if isinstance(h, StopAfterNEvalsHook):
        h._set_evals_completed_tensor(update_eval_step)  # pylint: disable=protected-access

    if isinstance(eval_ops, dict):
      eval_ops['update_eval_step'] = update_eval_step
    elif isinstance(eval_ops, (tuple, list)):
      eval_ops = list(eval_ops) + [update_eval_step]
    else:
      eval_ops = [eval_ops, update_eval_step]

  final_ops_hook = basic_session_run_hooks.FinalOpsHook(final_ops,
                                                        final_ops_feed_dict)
  hooks.append(final_ops_hook)

  num_evaluations = 0
  for checkpoint_path in checkpoints_iterator(
      checkpoint_dir,
      min_interval_secs=eval_interval_secs,
      timeout=timeout,
      timeout_fn=timeout_fn):

    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_filename_with_path=checkpoint_path,
        master=master,
        config=config)

    with monitored_session.MonitoredSession(
        session_creator=session_creator, hooks=hooks) as session:
      logging.info('Starting evaluation at ' + time.strftime(
          '%Y-%m-%d-%H:%M:%S', time.gmtime()))
      if eval_ops is not None:
        while not session.should_stop():
          session.run(eval_ops, feed_dict)

      logging.info('Finished evaluation at ' + time.strftime(
          '%Y-%m-%d-%H:%M:%S', time.gmtime()))
    num_evaluations += 1

    if (max_number_of_evaluations is not None and
        num_evaluations >= max_number_of_evaluations):
      return final_ops_hook.final_ops_values

  return final_ops_hook.final_ops_values
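
# Illustrative sketch (kept as comments so importing this module has no side
# effects, and not part of the public API): how a caller might combine
# `evaluate_repeatedly` with a `timeout_fn` so that evaluation stops once
# training has finished. `training_has_finished()` is a hypothetical,
# caller-supplied predicate; `names_to_updates` is assumed to come from a
# metric map as in the module docstring above.
#
#   def stop_after_training():
#     # Returning True makes checkpoints_iterator (and therefore
#     # evaluate_repeatedly) exit instead of waiting for more checkpoints;
#     # returning False resumes the wait and the timeout applies again.
#     return training_has_finished()
#
#   tf.contrib.training.evaluate_repeatedly(
#       checkpoint_dir='/tmp/my_model_dir/',
#       eval_ops=names_to_updates.values(),
#       eval_interval_secs=600,
#       timeout=3600,  # seconds to block waiting for a new checkpoint
#       timeout_fn=stop_after_training)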