1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains functions for evaluation and summarization of metrics."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22import time
23
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import init_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import state_ops
30from tensorflow.python.ops import variable_scope
31from tensorflow.python.platform import tf_logging as logging
32from tensorflow.python.training import basic_session_run_hooks
33from tensorflow.python.training import monitored_session
34from tensorflow.python.training import session_run_hook
35
36
def _get_or_create_eval_step():
  """Returns the graph's eval-step counter, creating it on first use.

  The counter is looked up in the default graph's `tf.GraphKeys.EVAL_STEP`
  collection. When the collection is empty, a new scalar int64 variable named
  'eval_step' is created with a zero initializer, marked non-trainable, and
  registered in both `LOCAL_VARIABLES` and `EVAL_STEP`.

  Returns:
    A `Tensor` representing a counter for the evaluation step.

  Raises:
    ValueError: If multiple `Tensors` have been added to the
      `tf.GraphKeys.EVAL_STEP` collection.
  """
  registered = ops.get_default_graph().get_collection(ops.GraphKeys.EVAL_STEP)

  # More than one entry is ambiguous; refuse to guess which one is the counter.
  if len(registered) > 1:
    raise ValueError('Multiple tensors added to tf.GraphKeys.EVAL_STEP')
  if registered:
    return registered[0]

  # Nothing registered yet: create the counter and register it ourselves so
  # later calls find it through the collection.
  return variable_scope.get_variable(
      'eval_step',
      shape=[],
      dtype=dtypes.int64,
      initializer=init_ops.zeros_initializer(),
      trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.EVAL_STEP])
62
63
def _get_latest_eval_step_value(update_ops):
  """Reads the eval step only after `update_ops` have run.

  The read is placed under a control dependency on `update_ops`, so the
  returned tensor observes the counter after those ops execute in a run.

  Args:
    update_ops: A list of `Tensors` or a dictionary of names to `Tensors`,
        which are run before reading the eval step value. For a dictionary,
        only its values are used as dependencies.

  Returns:
    A `Tensor` representing the value for the evaluation step.
  """
  dependencies = (list(update_ops.values())
                  if isinstance(update_ops, dict) else update_ops)

  with ops.control_dependencies(dependencies):
    # Created/fetched inside the dependency scope, matching the original order.
    counter = _get_or_create_eval_step()
    return array_ops.identity(counter.read_value())
79
80
class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook):
  """Run hook used by the evaluation routines to run the `eval_ops` N times.

  Unlike a one-step-per-run hook, every `session.run` may execute several
  evaluation steps. Before each run, the hook writes the number of steps to
  execute into the shared steps-per-run variable. It never schedules more
  steps than remain until `num_evals`.
  """

  def __init__(self, num_evals, steps_per_run=1):
    """Constructs the run hook.

    Args:
      num_evals: The number of evaluations to run for. If None, this hook
        never requests a stop and always schedules `steps_per_run` steps.
      steps_per_run: Number of steps executed per run call.
    """
    self._num_evals = num_evals
    # Set later via `_set_evals_completed_tensor`; fetched in `before_run`.
    self._evals_completed = None
    self._steps_per_run_initial_value = steps_per_run

  def _set_evals_completed_tensor(self, updated_eval_step):
    """Installs the tensor whose fetched value is the completed-eval count."""
    self._evals_completed = updated_eval_step

  def begin(self):
    """Looks up (or creates) the graph's steps-per-run variable."""
    self._steps_per_run_variable = (
        basic_session_run_hooks.get_or_create_steps_per_run_variable())

  def after_create_session(self, session, coord):
    """Sets how many steps the first run call should execute."""
    per_run = self._steps_per_run_initial_value
    # Cap the first run so it cannot overshoot a small `num_evals`.
    first_run_steps = (per_run if self._num_evals is None
                       else min(per_run, self._num_evals))
    self._steps_per_run_variable.load(first_run_steps, session=session)

  def before_run(self, run_context):
    """Requests the completed-eval count alongside each run."""
    fetches = {'evals_completed': self._evals_completed}
    return session_run_hook.SessionRunArgs(fetches)

  def after_run(self, run_context, run_values):
    """Schedules the next run's step count, logs, and stops when done."""
    completed = run_values.results['evals_completed']
    target = self._num_evals
    per_run = self._steps_per_run_initial_value

    # Only schedule as many steps as remain so the last run ends exactly on
    # `target`.
    next_run_steps = (per_run if target is None
                      else min(target - completed, per_run))
    self._steps_per_run_variable.load(next_run_steps,
                                      session=run_context.session)

    if target is None:
      # Unbounded evaluation: report progress but never request a stop.
      logging.info('Evaluation [%d]', completed)
      return

    logging.info('Evaluation [%d/%d]', completed, target)
    if completed >= target:
      run_context.request_stop()
132
133
class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
  """Run hook used by the evaluation routines to run the `eval_ops` N times.

  After each run, the hook reads the completed-evaluation count and may log
  progress. It requests a stop once the count reaches `num_evals`.
  """

  def __init__(self, num_evals, log_progress=True):
    """Constructs the run hook.

    Args:
      num_evals: The number of evaluations to run for. If None, this hook
        never requests a stop; evaluation continues until something else
        ends the session (e.g. the input being exhausted).
      log_progress: Whether to log evaluation progress, defaults to True.
    """
    self._num_evals = num_evals
    # Set later via `_set_evals_completed_tensor`; fetched in `before_run`.
    self._evals_completed = None
    self._log_progress = log_progress
    # Unknown or short runs (< 20 evals) log every evaluation. Longer runs
    # log only on multiples of roughly a tenth of `num_evals`.
    if num_evals is None or num_evals < 20:
      self._log_frequency = 1
    else:
      self._log_frequency = math.floor(num_evals / 10.)

  def _set_evals_completed_tensor(self, updated_eval_step):
    """Installs the tensor whose fetched value is the completed-eval count."""
    self._evals_completed = updated_eval_step

  def before_run(self, run_context):
    """Requests the completed-eval count alongside each run."""
    fetches = {'evals_completed': self._evals_completed}
    return session_run_hook.SessionRunArgs(fetches)

  def after_run(self, run_context, run_values):
    """Optionally logs progress and stops once `num_evals` is reached."""
    completed = run_values.results['evals_completed']
    target = self._num_evals

    if self._log_progress:
      if target is None:
        logging.info('Evaluation [%d]', completed)
      elif completed % self._log_frequency == 0 or target == completed:
        # The final evaluation is always logged, even off the frequency grid.
        logging.info('Evaluation [%d/%d]', completed, target)

    if target is not None and completed >= target:
      run_context.request_stop()
172
173
def _evaluate_once(checkpoint_path,
                   master='',
                   scaffold=None,
                   eval_ops=None,
                   feed_dict=None,
                   final_ops=None,
                   final_ops_feed_dict=None,
                   hooks=None,
                   config=None):
  """Evaluates the model at the given checkpoint path.

  During a single evaluation, the `eval_ops` is run until the session is
  interrupted or requested to finish. This is typically requested via a
  `tf.contrib.training.StopAfterNEvalsHook` which results in `eval_ops` running
  the requested number of times.

  Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of
  `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is
  evaluated a single time after `eval_ops` has finished running and the fetched
  values of `final_ops` are returned. If `final_ops` is left as `None`, then
  `None` is returned.

  One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record
  summaries after the `eval_ops` have run. If `eval_ops` is `None`, the
  summaries run immediately after the model checkpoint has been restored.

  Note that `evaluate_once` creates a local variable used to track the number of
  evaluations run via `tf.contrib.training.get_or_create_eval_step`.
  Consequently, if a custom local init op is provided via a `scaffold`, the
  caller should ensure that the local init op also initializes the eval step.

  Neither `eval_ops` nor `hooks` is modified in place; the function works on
  its own copies.

  Args:
    checkpoint_path: The path to a checkpoint to use for evaluation.
    master: The BNS address of the TensorFlow master.
    scaffold: A `tf.train.Scaffold` instance for initializing variables and
      restoring variables. Note that `scaffold.init_fn` is used by the function
      to restore the checkpoint. If you supply a custom init_fn, then it must
      also take care of restoring the model from its checkpoint.
    eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
      to `Tensors`, which is run until the session is requested to stop,
      commonly done by a `tf.contrib.training.StopAfterNEvalsHook`.
    feed_dict: The feed dictionary to use when executing the `eval_ops`.
    final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
      to `Tensors`.
    final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      evaluation loop.
    config: An instance of `tf.ConfigProto` that will be used to
      configure the `Session`. If left as `None`, the default will be used.

  Returns:
    The fetched values of `final_ops` or `None` if `final_ops` is `None`.
  """
  eval_step = _get_or_create_eval_step()

  # Work on a copy so appending the FinalOpsHook below never touches the
  # caller's list.
  hooks = list(hooks or [])

  if eval_ops is not None:
    if any(isinstance(h, _MultiStepStopAfterNEvalsHook) for h in hooks):
      # Each run may execute several steps, so advance the counter by the
      # current steps-per-run value rather than by one.
      steps_per_run_variable = (
          basic_session_run_hooks.get_or_create_steps_per_run_variable())
      update_eval_step = state_ops.assign_add(
          eval_step,
          math_ops.cast(steps_per_run_variable, dtype=eval_step.dtype),
          use_locking=True)
    else:
      update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)

    # Always build a fresh container. Previously a dict argument was mutated
    # in place, leaking 'update_eval_step' into the caller's object. The
    # list/tuple branch already avoided that.
    if isinstance(eval_ops, dict):
      eval_ops = dict(eval_ops)
      eval_ops['update_eval_step'] = update_eval_step
    elif isinstance(eval_ops, (tuple, list)):
      eval_ops = list(eval_ops) + [update_eval_step]
    else:
      eval_ops = [eval_ops, update_eval_step]

    # Counter value read after all eval ops (including the increment) run.
    eval_step_value = _get_latest_eval_step_value(eval_ops)

    for h in hooks:
      if isinstance(h, (_StopAfterNEvalsHook, _MultiStepStopAfterNEvalsHook)):
        h._set_evals_completed_tensor(eval_step_value)  # pylint: disable=protected-access

  # The timestamp comes from `time.localtime()`, so it must not carry a 'Z'
  # (UTC) designator.
  logging.info('Starting evaluation at %s',
               time.strftime('%Y-%m-%dT%H:%M:%S', time.localtime()))

  # Prepare the session creator.
  session_creator = monitored_session.ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_filename_with_path=checkpoint_path,
      master=master,
      config=config)

  final_ops_hook = basic_session_run_hooks.FinalOpsHook(
      final_ops, final_ops_feed_dict)
  hooks.append(final_ops_hook)

  with monitored_session.MonitoredSession(
      session_creator=session_creator, hooks=hooks) as session:
    # With no eval ops the session is only opened (restoring the checkpoint
    # and running hooks), then closed; FinalOpsHook fetches on close.
    if eval_ops is not None:
      while not session.should_stop():
        session.run(eval_ops, feed_dict)

  logging.info('Finished evaluation at %s',
               time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime()))
  return final_ops_hook.final_ops_values
279