1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains functions for evaluation and summarization of metrics.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import math 22import time 23 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import init_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.ops import state_ops 30from tensorflow.python.ops import variable_scope 31from tensorflow.python.platform import tf_logging as logging 32from tensorflow.python.training import basic_session_run_hooks 33from tensorflow.python.training import monitored_session 34from tensorflow.python.training import session_run_hook 35 36 37def _get_or_create_eval_step(): 38 """Gets or creates the eval step `Tensor`. 39 40 Returns: 41 A `Tensor` representing a counter for the evaluation step. 42 43 Raises: 44 ValueError: If multiple `Tensors` have been added to the 45 `tf.GraphKeys.EVAL_STEP` collection. 46 """ 47 graph = ops.get_default_graph() 48 eval_steps = graph.get_collection(ops.GraphKeys.EVAL_STEP) 49 if len(eval_steps) == 1: 50 return eval_steps[0] 51 elif len(eval_steps) > 1: 52 raise ValueError('Multiple tensors added to tf.GraphKeys.EVAL_STEP') 53 else: 54 counter = variable_scope.get_variable( 55 'eval_step', 56 shape=[], 57 dtype=dtypes.int64, 58 initializer=init_ops.zeros_initializer(), 59 trainable=False, 60 collections=[ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.EVAL_STEP]) 61 return counter 62 63 64def _get_latest_eval_step_value(update_ops): 65 """Gets the eval step `Tensor` value after running `update_ops`. 66 67 Args: 68 update_ops: A list of `Tensors` or a dictionary of names to `Tensors`, 69 which are run before reading the eval step value. 70 71 Returns: 72 A `Tensor` representing the value for the evaluation step. 73 """ 74 if isinstance(update_ops, dict): 75 update_ops = list(update_ops.values()) 76 77 with ops.control_dependencies(update_ops): 78 return array_ops.identity(_get_or_create_eval_step().read_value()) 79 80 81class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook): 82 """Run hook used by the evaluation routines to run the `eval_ops` N times.""" 83 84 def __init__(self, num_evals, steps_per_run=1): 85 """Constructs the run hook. 86 87 Args: 88 num_evals: The number of evaluations to run for. if set to None, will 89 iterate the dataset until all inputs are exhausted. 90 steps_per_run: Number of steps executed per run call. 91 """ 92 self._num_evals = num_evals 93 self._evals_completed = None 94 self._steps_per_run_initial_value = steps_per_run 95 96 def _set_evals_completed_tensor(self, updated_eval_step): 97 self._evals_completed = updated_eval_step 98 99 def begin(self): 100 self._steps_per_run_variable = \ 101 basic_session_run_hooks.get_or_create_steps_per_run_variable() 102 103 def after_create_session(self, session, coord): 104 # Update number of steps to run in the first run call 105 if self._num_evals is None: 106 steps = self._steps_per_run_initial_value 107 else: 108 steps = min(self._steps_per_run_initial_value, self._num_evals) 109 self._steps_per_run_variable.load(steps, session=session) 110 111 def before_run(self, run_context): 112 return session_run_hook.SessionRunArgs({ 113 'evals_completed': self._evals_completed 114 }) 115 116 def after_run(self, run_context, run_values): 117 evals_completed = run_values.results['evals_completed'] 118 # Update number of steps to run in the next iteration 119 if self._num_evals is None: 120 steps = self._steps_per_run_initial_value 121 else: 122 steps = min(self._num_evals - evals_completed, 123 self._steps_per_run_initial_value) 124 self._steps_per_run_variable.load(steps, session=run_context.session) 125 126 if self._num_evals is None: 127 logging.info('Evaluation [%d]', evals_completed) 128 else: 129 logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals) 130 if self._num_evals is not None and evals_completed >= self._num_evals: 131 run_context.request_stop() 132 133 134class _StopAfterNEvalsHook(session_run_hook.SessionRunHook): 135 """Run hook used by the evaluation routines to run the `eval_ops` N times.""" 136 137 def __init__(self, num_evals, log_progress=True): 138 """Constructs the run hook. 139 140 Args: 141 num_evals: The number of evaluations to run for. if set to None, will 142 iterate the dataset until all inputs are exhausted. 143 log_progress: Whether to log evaluation progress, defaults to True. 144 """ 145 # The number of evals to run for. 146 self._num_evals = num_evals 147 self._evals_completed = None 148 self._log_progress = log_progress 149 # Reduce logging frequency if there are 20 or more evaluations. 150 self._log_frequency = (1 if (num_evals is None or num_evals < 20) 151 else math.floor(num_evals / 10.)) 152 153 def _set_evals_completed_tensor(self, updated_eval_step): 154 self._evals_completed = updated_eval_step 155 156 def before_run(self, run_context): 157 return session_run_hook.SessionRunArgs({ 158 'evals_completed': self._evals_completed 159 }) 160 161 def after_run(self, run_context, run_values): 162 evals_completed = run_values.results['evals_completed'] 163 if self._log_progress: 164 if self._num_evals is None: 165 logging.info('Evaluation [%d]', evals_completed) 166 else: 167 if ((evals_completed % self._log_frequency) == 0 or 168 (self._num_evals == evals_completed)): 169 logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals) 170 if self._num_evals is not None and evals_completed >= self._num_evals: 171 run_context.request_stop() 172 173 174def _evaluate_once(checkpoint_path, 175 master='', 176 scaffold=None, 177 eval_ops=None, 178 feed_dict=None, 179 final_ops=None, 180 final_ops_feed_dict=None, 181 hooks=None, 182 config=None): 183 """Evaluates the model at the given checkpoint path. 184 185 During a single evaluation, the `eval_ops` is run until the session is 186 interrupted or requested to finish. This is typically requested via a 187 `tf.contrib.training.StopAfterNEvalsHook` which results in `eval_ops` running 188 the requested number of times. 189 190 Optionally, a user can pass in `final_ops`, a single `Tensor`, a list of 191 `Tensors` or a dictionary from names to `Tensors`. The `final_ops` is 192 evaluated a single time after `eval_ops` has finished running and the fetched 193 values of `final_ops` are returned. If `final_ops` is left as `None`, then 194 `None` is returned. 195 196 One may also consider using a `tf.contrib.training.SummaryAtEndHook` to record 197 summaries after the `eval_ops` have run. If `eval_ops` is `None`, the 198 summaries run immediately after the model checkpoint has been restored. 199 200 Note that `evaluate_once` creates a local variable used to track the number of 201 evaluations run via `tf.contrib.training.get_or_create_eval_step`. 202 Consequently, if a custom local init op is provided via a `scaffold`, the 203 caller should ensure that the local init op also initializes the eval step. 204 205 Args: 206 checkpoint_path: The path to a checkpoint to use for evaluation. 207 master: The BNS address of the TensorFlow master. 208 scaffold: An tf.train.Scaffold instance for initializing variables and 209 restoring variables. Note that `scaffold.init_fn` is used by the function 210 to restore the checkpoint. If you supply a custom init_fn, then it must 211 also take care of restoring the model from its checkpoint. 212 eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names 213 to `Tensors`, which is run until the session is requested to stop, 214 commonly done by a `tf.contrib.training.StopAfterNEvalsHook`. 215 feed_dict: The feed dictionary to use when executing the `eval_ops`. 216 final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names 217 to `Tensors`. 218 final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`. 219 hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the 220 evaluation loop. 221 config: An instance of `tf.ConfigProto` that will be used to 222 configure the `Session`. If left as `None`, the default will be used. 223 224 Returns: 225 The fetched values of `final_ops` or `None` if `final_ops` is `None`. 226 """ 227 eval_step = _get_or_create_eval_step() 228 229 # Prepare the run hooks. 230 hooks = list(hooks or []) 231 232 if eval_ops is not None: 233 if any(isinstance(h, _MultiStepStopAfterNEvalsHook) for h in hooks): 234 steps_per_run_variable = \ 235 basic_session_run_hooks.get_or_create_steps_per_run_variable() 236 update_eval_step = state_ops.assign_add( 237 eval_step, 238 math_ops.cast(steps_per_run_variable, dtype=eval_step.dtype), 239 use_locking=True) 240 else: 241 update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True) 242 243 if isinstance(eval_ops, dict): 244 eval_ops['update_eval_step'] = update_eval_step 245 elif isinstance(eval_ops, (tuple, list)): 246 eval_ops = list(eval_ops) + [update_eval_step] 247 else: 248 eval_ops = [eval_ops, update_eval_step] 249 250 eval_step_value = _get_latest_eval_step_value(eval_ops) 251 252 for h in hooks: 253 if isinstance(h, (_StopAfterNEvalsHook, _MultiStepStopAfterNEvalsHook)): 254 h._set_evals_completed_tensor(eval_step_value) # pylint: disable=protected-access 255 256 logging.info('Starting evaluation at ' + 257 time.strftime('%Y-%m-%dT%H:%M:%SZ', time.localtime())) 258 259 # Prepare the session creator. 260 session_creator = monitored_session.ChiefSessionCreator( 261 scaffold=scaffold, 262 checkpoint_filename_with_path=checkpoint_path, 263 master=master, 264 config=config) 265 266 final_ops_hook = basic_session_run_hooks.FinalOpsHook( 267 final_ops, final_ops_feed_dict) 268 hooks.append(final_ops_hook) 269 270 with monitored_session.MonitoredSession( 271 session_creator=session_creator, hooks=hooks) as session: 272 if eval_ops is not None: 273 while not session.should_stop(): 274 session.run(eval_ops, feed_dict) 275 276 logging.info('Finished evaluation at ' + 277 time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime())) 278 return final_ops_hook.final_ops_values 279