1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains functions for evaluation and summarization of metrics.
16
17The evaluation.py module contains helper functions for evaluating TensorFlow
18modules using a variety of metrics and summarizing the results.
19
20****************************************
21* Evaluating a Checkpointed Model Once *
22****************************************
23
24Once we've trained a model, we'll want to evaluate it. The simplest way to do
25this is to evaluate the performance of a saved model a single time. In order
26to do this, we can specify a number of metrics we'll want to evaluate as well
27as specify the summaries we want to save to disk. Furthermore, we can print
28out the metrics values to stdout:
29
30  # Specify where the checkpoint is stored:
31  checkpoint_path = ...
32
33  # Create model and obtain the predictions:
34  images, labels = LoadData(...)
35  predictions = MyModel(images)
36
37  # Choose the metrics to compute:
38  names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map({
39      "accuracy": tf.metrics.accuracy(labels, predictions),
40      "mse": tf.metrics.mean_squared_error(labels, predictions),
41  })
42
43  # Define the summaries to write:
44  for metric_name, metric_value in metrics_to_values.iteritems():
45    tf.summary.scalar(metric_name, metric_value)
46
47  checkpoint_dir = '/tmp/my_model_dir/'
48  log_dir = '/tmp/my_model_eval/'
49
50  # We'll evaluate 1000 batches:
51  num_evals = 1000
52
53  names_to_values = evaluate_once(
54      checkpoint_path=checkpoint_path,
55      eval_ops=names_to_updates.values(),
56      final_ops=names_to_values,
57      hooks=[
58            tf.contrib.training.StopAfterNEvalsHook(num_evals),
59            tf.contrib.training.SummaryAtEndHook(logdir),
60      ],
61      config=None)
62
63  for name in names_to_values:
64    print('Metric %s has value %f.' % (name, names_to_values[name]))
65
66
67************************************************
68* Evaluating a Checkpointed Model with Metrics *
69************************************************
70
71Often, one wants to evaluate a model checkpoint saved on disk. This can be
72performed once or repeatedly on a set schedule.
73
74To evaluate a particular model, users define zero or more metrics and zero or
75more summaries and call the evaluate_repeatedly method:
76
77  # Create model and obtain the predictions:
78  images, labels = LoadData(...)
79  predictions = MyModel(images)
80
81  # Choose the metrics to compute:
82  names_to_values, names_to_updates = tf.contrib.metrics.aggregate_metric_map({
83      "accuracy": tf.metrics.accuracy(labels, predictions),
84      "mse": tf.metrics.mean_squared_error(labels, predictions),
85  })
86
87  # Define the summaries to write:
88  for metric_name, metric_value in metrics_to_values.iteritems():
89    tf.summary.scalar(metric_name, metric_value)
90
91  checkpoint_dir = '/tmp/my_model_dir/'
92  log_dir = '/tmp/my_model_eval/'
93
94  # We'll evaluate 1000 batches:
95  num_evals = 1000
96
97  # Evaluate every 10 minutes:
98  tf.contrib.training.evaluate_repeatedly(
99      checkpoint_dir,
100      eval_ops=names_to_updates.values(),
101      hooks=[
102            tf.contrib.training.StopAfterNEvalsHook(num_evals),
103            tf.contrib.training.SummaryAtEndHook(logdir),
104      ],
105      eval_interval_secs=600)
106
107*******************************************************
108* Evaluating a Checkpointed Model with Summaries Only *
109*******************************************************
110
111At times, an evaluation can be performed without metrics at all but rather
112with only summaries. The user need only leave out the 'eval_ops' argument:
113
114  # Create model and obtain the predictions:
115  images, labels = LoadData(...)
116  predictions = MyModel(images)
117
118  # Define the summaries to write:
119  tf.summary.scalar(...)
120  tf.summary.histogram(...)
121
122  checkpoint_dir = '/tmp/my_model_dir/'
123  log_dir = '/tmp/my_model_eval/'
124
125  # Evaluate once every 10 minutes.
126  tf.contrib.training.evaluate_repeatedly(
127      checkpoint_dir,
128      hooks=[
129          tf.contrib.training.SummaryAtEndHook(logdir),
130      ],
131      eval_interval_secs=600)
132
133"""
134
135from __future__ import absolute_import
136from __future__ import division
137from __future__ import print_function
138
139import time
140
141from tensorflow.python.ops import state_ops
142from tensorflow.python.platform import tf_logging as logging
143from tensorflow.python.summary import summary
144from tensorflow.python.training import basic_session_run_hooks
145from tensorflow.python.training import checkpoint_management
146from tensorflow.python.training import evaluation
147from tensorflow.python.training import monitored_session
148from tensorflow.python.training import session_run_hook
149from tensorflow.python.training import training_util
150
151__all__ = [
152    'StopAfterNEvalsHook',
153    'SummaryAtEndHook',
154    'checkpoints_iterator',
155    'evaluate_once',
156    'evaluate_repeatedly',
157    'get_or_create_eval_step',
158    'wait_for_new_checkpoint',
159]
160
161# pylint: disable=protected-access
162# pylint: disable=invalid-name
163StopAfterNEvalsHook = evaluation._StopAfterNEvalsHook
164evaluate_once = evaluation._evaluate_once
165get_or_create_eval_step = evaluation._get_or_create_eval_step
166
167# pylint: enable=invalid-name
168# pylint: enable=protected-access
169
170
171def wait_for_new_checkpoint(checkpoint_dir,
172                            last_checkpoint=None,
173                            seconds_to_sleep=1,
174                            timeout=None):
175  """Waits until a new checkpoint file is found.
176
177  Args:
178    checkpoint_dir: The directory in which checkpoints are saved.
179    last_checkpoint: The last checkpoint path used or `None` if we're expecting
180      a checkpoint for the first time.
181    seconds_to_sleep: The number of seconds to sleep for before looking for a
182      new checkpoint.
183    timeout: The maximum amount of time to wait. If left as `None`, then the
184      process will wait indefinitely.
185
186  Returns:
187    a new checkpoint path, or None if the timeout was reached.
188  """
189  logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
190  stop_time = time.time() + timeout if timeout is not None else None
191  while True:
192    checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
193    if checkpoint_path is None or checkpoint_path == last_checkpoint:
194      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
195        return None
196      time.sleep(seconds_to_sleep)
197    else:
198      logging.info('Found new checkpoint at %s', checkpoint_path)
199      return checkpoint_path
200
201
202def checkpoints_iterator(checkpoint_dir,
203                         min_interval_secs=0,
204                         timeout=None,
205                         timeout_fn=None):
206  """Continuously yield new checkpoint files as they appear.
207
208  The iterator only checks for new checkpoints when control flow has been
209  reverted to it. This means it can miss checkpoints if your code takes longer
210  to run between iterations than `min_interval_secs` or the interval at which
211  new checkpoints are written.
212
213  The `timeout` argument is the maximum number of seconds to block waiting for
214  a new checkpoint.  It is used in combination with the `timeout_fn` as
215  follows:
216
217  * If the timeout expires and no `timeout_fn` was specified, the iterator
218    stops yielding.
219  * If a `timeout_fn` was specified, that function is called and if it returns
220    a true boolean value the iterator stops yielding.
221  * If the function returns a false boolean value then the iterator resumes the
222    wait for new checkpoints.  At this point the timeout logic applies again.
223
224  This behavior gives control to callers on what to do if checkpoints do not
225  come fast enough or stop being generated.  For example, if callers have a way
226  to detect that the training has stopped and know that no new checkpoints
227  will be generated, they can provide a `timeout_fn` that returns `True` when
228  the training has stopped.  If they know that the training is still going on
229  they return `False` instead.
230
231  Args:
232    checkpoint_dir: The directory in which checkpoints are saved.
233    min_interval_secs: The minimum number of seconds between yielding
234      checkpoints.
235    timeout: The maximum amount of time to wait between checkpoints. If left as
236      `None`, then the process will wait indefinitely.
237    timeout_fn: Optional function to call after a timeout.  If the function
238      returns True, then it means that no new checkpoints will be generated and
239      the iterator will exit.  The function is called with no arguments.
240
241  Yields:
242    String paths to latest checkpoint files as they arrive.
243  """
244  checkpoint_path = None
245  while True:
246    new_checkpoint_path = wait_for_new_checkpoint(
247        checkpoint_dir, checkpoint_path, timeout=timeout)
248    if new_checkpoint_path is None:
249      if not timeout_fn:
250        # timed out
251        logging.info('Timed-out waiting for a checkpoint.')
252        return
253      if timeout_fn():
254        # The timeout_fn indicated that we are truly done.
255        return
256      else:
257        # The timeout_fn indicated that more checkpoints may come.
258        continue
259    start = time.time()
260    checkpoint_path = new_checkpoint_path
261    yield checkpoint_path
262    time_to_next_eval = start + min_interval_secs - time.time()
263    if time_to_next_eval > 0:
264      time.sleep(time_to_next_eval)
265
266
267class SummaryAtEndHook(session_run_hook.SessionRunHook):
268  """A run hook that saves a summary with the results of evaluation."""
269
270  def __init__(self,
271               log_dir=None,
272               summary_writer=None,
273               summary_op=None,
274               feed_dict=None):
275    """Constructs the Summary Hook.
276
277    Args:
278      log_dir: The directory where the summary events are saved to.  Used only
279        when `summary_writer` is not specified.
280      summary_writer: A `tf.summary.FileWriter` to write summary events with.
281      summary_op: The summary op to run. If left as `None`, then all summaries
282        in the tf.GraphKeys.SUMMARIES collection are used.
283      feed_dict: An optional feed dictionary to use when evaluating the
284        summaries.
285
286    Raises:
287      ValueError: If both `log_dir` and `summary_writer` are `None`.
288    """
289    self._summary_op = summary_op
290    self._replace_summary_op = summary_op is None
291    self._feed_dict = feed_dict
292    self._summary_writer = summary_writer
293    self._log_dir = log_dir
294    if self._log_dir is None and self._summary_writer is None:
295      raise ValueError('One of log_dir or summary_writer should be used.')
296
297  def begin(self):
298    if self._replace_summary_op:
299      # This can still remain None if there are no summaries.
300      self._summary_op = summary.merge_all()
301    self._global_step = training_util.get_or_create_global_step()
302
303  def after_create_session(self, session, coord):
304    if self._summary_writer is None and self._log_dir:
305      self._summary_writer = summary.FileWriterCache.get(self._log_dir)
306
307  def end(self, session):
308    if self._summary_op is not None:
309      global_step = training_util.global_step(session, self._global_step)
310      summary_str = session.run(self._summary_op, self._feed_dict)
311      if self._summary_writer:
312        self._summary_writer.add_summary(summary_str, global_step)
313    if self._summary_writer:
314      self._summary_writer.flush()
315
316
317def _scaffold_with_init(scaffold, saver, checkpoint_path):
318  """Creates a scaffold that loads the given checkpoint using an init_fn.
319
320  Args:
321    scaffold: The scaffold to copy.
322    saver: The saver to use when restoring the checkpoint.
323    checkpoint_path: An absolute path to a checkpoint.
324
325  Returns:
326    A scaffold with an init_fn that loads the given checkpoint. If the scaffold
327    provided already has an init_fn, the scaffold is returned unchanged.
328  """
329
330  def restore_checkpoint(_, session):
331    saver.restore(session, checkpoint_path)
332
333  if not scaffold.init_fn:
334    scaffold = monitored_session.Scaffold(
335        init_op=scaffold.init_op,
336        init_feed_dict=scaffold.init_feed_dict,
337        init_fn=restore_checkpoint,
338        ready_op=scaffold.ready_op,
339        local_init_op=scaffold.local_init_op,
340        summary_op=scaffold.summary_op,
341        saver=scaffold.saver)
342  return scaffold
343
344
345def evaluate_repeatedly(checkpoint_dir,
346                        master='',
347                        scaffold=None,
348                        eval_ops=None,
349                        feed_dict=None,
350                        final_ops=None,
351                        final_ops_feed_dict=None,
352                        eval_interval_secs=60,
353                        hooks=None,
354                        config=None,
355                        max_number_of_evaluations=None,
356                        timeout=None,
357                        timeout_fn=None):
358  """Repeatedly searches for a checkpoint in `checkpoint_dir` and evaluates it.
359
360  During a single evaluation, the `eval_ops` is run until the session is
361  interrupted or requested to finish. This is typically requested via a
362  `tf.contrib.training.StopAfterNEvalsHook` which results in `eval_ops` running
363  the requested number of times.
364
365  Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of
366  `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is
367  evaluated a single time after `eval_ops` has finished running and the fetched
368  values of `final_ops` are returned. If `final_ops` is left as `None`, then
369  `None` is returned.
370
371  One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record
372  summaries after the `eval_ops` have run. If `eval_ops` is `None`, the
373  summaries run immediately after the model checkpoint has been restored.
374
375  Note that `evaluate_once` creates a local variable used to track the number of
376  evaluations run via `tf.contrib.training.get_or_create_eval_step`.
377  Consequently, if a custom local init op is provided via a `scaffold`, the
378  caller should ensure that the local init op also initializes the eval step.
379
380  Args:
381    checkpoint_dir: The directory where checkpoints are stored.
382    master: The address of the TensorFlow master.
383    scaffold: An tf.train.Scaffold instance for initializing variables and
384      restoring variables. Note that `scaffold.init_fn` is used by the function
385      to restore the checkpoint. If you supply a custom init_fn, then it must
386      also take care of restoring the model from its checkpoint.
387    eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
388      to `Tensors`, which is run until the session is requested to stop,
389      commonly done by a `tf.contrib.training.StopAfterNEvalsHook`.
390    feed_dict: The feed dictionary to use when executing the `eval_ops`.
391    final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
392      to `Tensors`.
393    final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
394    eval_interval_secs: The minimum number of seconds between evaluations.
395    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
396      evaluation loop.
397    config: An instance of `tf.ConfigProto` that will be used to
398      configure the `Session`. If left as `None`, the default will be used.
399    max_number_of_evaluations: The maximum times to run the evaluation. If left
400      as `None`, then evaluation runs indefinitely.
401    timeout: The maximum amount of time to wait between checkpoints. If left as
402      `None`, then the process will wait indefinitely.
403    timeout_fn: Optional function to call after a timeout.  If the function
404      returns True, then it means that no new checkpoints will be generated and
405      the iterator will exit.  The function is called with no arguments.
406
407  Returns:
408    The fetched values of `final_ops` or `None` if `final_ops` is `None`.
409  """
410  eval_step = get_or_create_eval_step()
411
412  # Prepare the run hooks.
413  hooks = hooks or []
414
415  if eval_ops is not None:
416    update_eval_step = state_ops.assign_add(eval_step, 1)
417
418    for h in hooks:
419      if isinstance(h, StopAfterNEvalsHook):
420        h._set_evals_completed_tensor(update_eval_step)  # pylint: disable=protected-access
421
422    if isinstance(eval_ops, dict):
423      eval_ops['update_eval_step'] = update_eval_step
424    elif isinstance(eval_ops, (tuple, list)):
425      eval_ops = list(eval_ops) + [update_eval_step]
426    else:
427      eval_ops = [eval_ops, update_eval_step]
428
429  final_ops_hook = basic_session_run_hooks.FinalOpsHook(final_ops,
430                                                        final_ops_feed_dict)
431  hooks.append(final_ops_hook)
432
433  num_evaluations = 0
434  for checkpoint_path in checkpoints_iterator(
435      checkpoint_dir,
436      min_interval_secs=eval_interval_secs,
437      timeout=timeout,
438      timeout_fn=timeout_fn):
439
440    session_creator = monitored_session.ChiefSessionCreator(
441        scaffold=scaffold,
442        checkpoint_filename_with_path=checkpoint_path,
443        master=master,
444        config=config)
445
446    with monitored_session.MonitoredSession(
447        session_creator=session_creator, hooks=hooks) as session:
448      logging.info('Starting evaluation at ' + time.strftime(
449          '%Y-%m-%d-%H:%M:%S', time.gmtime()))
450      if eval_ops is not None:
451        while not session.should_stop():
452          session.run(eval_ops, feed_dict)
453
454      logging.info('Finished evaluation at ' + time.strftime(
455          '%Y-%m-%d-%H:%M:%S', time.gmtime()))
456    num_evaluations += 1
457
458    if (max_number_of_evaluations is not None and
459        num_evaluations >= max_number_of_evaluations):
460      return final_ops_hook.final_ops_values
461
462  return final_ops_hook.final_ops_values
463