# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Maintain moving averages of parameters."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.util.tf_export import tf_export


# TODO(touts): switch to variables.Variable.
def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
  """Compute the moving average of a variable.

  The moving average of 'variable' updated with 'value' is:
    variable * decay + value * (1 - decay)

  The returned Operation sets 'variable' to the newly computed moving average,
  by performing this subtraction:
    variable -= (1 - decay) * (variable - value)

  Since variables that are initialized to a `0` value will be `0` biased,
  `zero_debias` optionally enables scaling by the mathematically correct
  debiasing factor of
    1 - decay ** num_updates
  See `ADAM: A Method for Stochastic Optimization` Section 3 for more details
  (https://arxiv.org/abs/1412.6980).

  The names of the debias shadow variables, by default, include both the scope
  they were created in and the scope of the variables they debias. They are
  also given a uniquifying suffix.

  E.g.:

  ```
    with tf.variable_scope('scope1'):
      with tf.variable_scope('scope2'):
        var = tf.get_variable('foo')
        update_1 = tf.assign_moving_average(var, 0.0, 1.0)
        update_2 = tf.assign_moving_average(var, 0.0, 0.9)

    # var.name: 'scope1/scope2/foo'
    # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
    #                   'scope1/scope2/scope1/scope2/foo/biased_1'
  ```

  Args:
    variable: A Variable.
    value: A tensor with the same shape as 'variable'.
    decay: A float Tensor or float value. The moving average decay.
    zero_debias: A python bool. If true, assume the variable is 0-initialized
      and unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in
      `_zero_debias` for more details.
    name: Optional name of the returned operation.

  Returns:
    A tensor which if evaluated will compute and return the new moving average.
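
  As a rough numerical illustration of the update rule (the numbers are made
  up for this sketch): with `decay = 0.9`, a stored `variable` of `1.0`, and a
  new `value` of `3.0`, the op assigns

  ```
    variable -= (1 - 0.9) * (1.0 - 3.0)       # variable becomes 1.2
    # equivalently: 0.9 * 1.0 + 0.1 * 3.0 = 1.2
  ```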
  """
  def update_fn(v, value, decay=decay):
    decay = ops.convert_to_tensor(1.0 - decay, name="decay")
    if decay.dtype != v.dtype.base_dtype:
      decay = math_ops.cast(decay, v.dtype.base_dtype)
    if zero_debias:
      update_delta = _zero_debias(v, value, decay)
    else:
      update_delta = (v - value) * decay
    return state_ops.assign_sub(v, update_delta, name=scope)

  with ops.name_scope(name, "AssignMovingAvg",
                      [variable, value, decay]) as scope:
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context:
      # In a replica context, we update variable using the mean of value across
      # replicas.
      def merge_fn(strategy, v, value):
        value = strategy.extended.reduce_to(
            ds_reduce_util.ReduceOp.MEAN, value, v)
        return strategy.extended.update(v, update_fn, args=(value,))

      return replica_context.merge_call(merge_fn, args=(variable, value))
    else:
      strategy = distribution_strategy_context.get_cross_replica_context()
      return strategy.extended.update(variable, update_fn, args=(value,))


def weighted_moving_average(value,
                            decay,
                            weight,
                            truediv=True,
                            collections=None,
                            name=None):
  """Compute the weighted moving average of `value`.

  Conceptually, the weighted moving average is:
    `moving_average(value * weight) / moving_average(weight)`,
  where a moving average updates by the rule
    `new_value = decay * old_value + (1 - decay) * update`
  Internally, this Op keeps moving average variables of both `value * weight`
  and `weight`.

  Args:
    value: A numeric `Tensor`.
    decay: A float `Tensor` or float value. The moving average decay.
    weight: `Tensor` that keeps the current value of a weight. Its shape must
      be broadcastable against `value`, since the two are multiplied
      element-wise.
    truediv: Boolean, if `True`, dividing by `moving_average(weight)` is
      floating point division. If `False`, use division implied by dtypes.
    collections: List of graph collections keys to add the internal variables
      `value * weight` and `weight` to.
      Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name of the returned operation.
      Defaults to "WeightedMovingAvg".

  Returns:
    An Operation that updates and returns the weighted moving average.
  """
  # Unlike assign_moving_average, the weighted moving average doesn't modify
  # user-visible variables. It is the ratio of two internal variables, which
  # are moving averages of the updates. Thus, the signature of this function
  # is quite different from that of assign_moving_average.
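  # As a rough illustration of the ratio above (numbers invented for this
  # sketch): with decay=0.5 and both internal variables starting at 0, a first
  # update with value=4.0 and weight=2.0 leaves
  #   moving_average(value * weight) = 0.5 * 0.0 + 0.5 * 8.0 = 4.0
  #   moving_average(weight)         = 0.5 * 0.0 + 0.5 * 2.0 = 1.0
  # so the returned ratio is 4.0 / 1.0 = 4.0; the common (1 - decay) scaling
  # of numerator and denominator cancels, which is why both internal variables
  # are maintained with zero_debias=False below.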
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  with variable_scope.variable_scope(name, "WeightedMovingAvg",
                                     [value, weight, decay]) as scope:
    value_x_weight_var = variable_scope.get_variable(
        "value_x_weight",
        shape=value.get_shape(),
        dtype=value.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    weight_var = variable_scope.get_variable(
        "weight",
        shape=weight.get_shape(),
        dtype=weight.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    numerator = assign_moving_average(
        value_x_weight_var, value * weight, decay, zero_debias=False)
    denominator = assign_moving_average(
        weight_var, weight, decay, zero_debias=False)

    if truediv:
      return math_ops.truediv(numerator, denominator, name=scope.name)
    else:
      return math_ops.div(numerator, denominator, name=scope.name)


def _zero_debias(unbiased_var, value, decay):
  """Compute the delta required for a debiased Variable.

  All exponential moving averages initialized with Tensors are initialized to
  0, and therefore are biased to 0. Variables initialized to 0 and used as
  EMAs are similarly biased. This function creates the debias update amount
  according to a scale factor, as in https://arxiv.org/abs/1412.6980.

  To demonstrate the bias that results from 0-initialization, take an EMA that
  was initialized to `0` with decay `b`. After `t` timesteps of seeing the
  constant `c`, the variable will have the following value:

  ```
    EMA = 0*b^(t) + c*(1 - b)*b^(t-1) + c*(1 - b)*b^(t-2) + ...
        = c*(1 - b^t)
  ```

  To have the true value `c`, we would divide by the scale factor `1 - b^t`.

  In order to perform debiasing, we use two shadow variables. One keeps track
  of the biased estimate, and the other keeps track of the number of updates
  that have occurred.

  Args:
    unbiased_var: A Variable representing the current value of the unbiased
      EMA.
    value: A Tensor representing the most recent value.
    decay: A Tensor representing `1-decay` for the EMA.

  Returns:
    The amount by which the unbiased variable should be updated. Computing
    this tensor will also update the shadow variables appropriately.
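
  As a small worked example (numbers chosen only for illustration), take
  `b = 0.9`, `c = 1.0`, and `t = 2` updates of the biased EMA:

  ```
    biased EMA   = c*(1 - b^t) = 1.0*(1 - 0.81) = 0.19
    scale factor = 1 - b^t     = 1 - 0.81       = 0.19
    debiased EMA = 0.19 / 0.19 = 1.0            (the true value `c`)
  ```

  The `local_step` shadow variable below tracks `t`, so the scale factor can
  be recomputed after every update.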
  """
  with variable_scope.variable_scope(
      unbiased_var.name[:-len(":0")], values=[unbiased_var,
                                              value, decay]) as scope:
    with ops.colocate_with(unbiased_var):
      with ops.init_scope():
        biased_initializer = init_ops.zeros_initializer(
            dtype=unbiased_var.dtype)(unbiased_var.get_shape())
        local_step_initializer = init_ops.zeros_initializer()
      def _maybe_get_unique(name):
        """Get name for a unique variable, if not `reuse=True`."""
        if variable_scope.get_variable_scope().reuse:
          return name
        vs_vars = [x.op.name for x in
                   variable_scope.get_variable_scope().global_variables()]
        full_name = variable_scope.get_variable_scope().name + "/" + name
        if full_name not in vs_vars: return name
        idx = 1
        while full_name + ("_%d" % idx) in vs_vars:
          idx += 1
        return name + ("_%d" % idx)
      biased_var = variable_scope.get_variable(
          _maybe_get_unique("biased"), initializer=biased_initializer,
          trainable=False)
      local_step = variable_scope.get_variable(
          _maybe_get_unique("local_step"),
          shape=[],
          dtype=unbiased_var.dtype,
          initializer=local_step_initializer,
          trainable=False)

      # Get update ops for both shadow variables.
      update_biased = state_ops.assign_sub(biased_var,
                                           (biased_var - value) * decay,
                                           name=scope.name)
      update_local_step = local_step.assign_add(1)

      # Compute the value of the delta to update the unbiased EMA. Make sure
      # to use the new values of the biased variable and the local step.
      with ops.control_dependencies([update_biased, update_local_step]):
        # This function gets `1 - decay`, so use `1.0 - decay` in the exponent.
        unbiased_ema_delta = (unbiased_var - biased_var.read_value() /
                              (1 - math_ops.pow(
                                  1.0 - decay, local_step.read_value())))

      return unbiased_ema_delta


@tf_export("train.ExponentialMovingAverage")
class ExponentialMovingAverage(object):
  """Maintains moving averages of variables by employing an exponential decay.

  When training a model, it is often beneficial to maintain moving averages of
  the trained parameters. Evaluations that use averaged parameters sometimes
  produce significantly better results than the final trained values.

  The `apply()` method adds shadow copies of trained variables and adds ops
  that maintain a moving average of the trained variables in their shadow
  copies. It is used when building the training model. The ops that maintain
  moving averages are typically run after each training step.
  The `average()` and `average_name()` methods give access to the shadow
  variables and their names. They are useful when building an evaluation
  model, or when restoring a model from a checkpoint file. They help use the
  moving averages in place of the last trained values for evaluations.

  The moving averages are computed using exponential decay. You specify the
  decay value when creating the `ExponentialMovingAverage` object. The shadow
  variables are initialized with the same initial values as the trained
  variables.
  When you run the ops to maintain the moving averages, each shadow variable
  is updated with the formula:

    `shadow_variable -= (1 - decay) * (shadow_variable - variable)`

  This is mathematically equivalent to the classic formula below, but the use
  of an `assign_sub` op (the `"-="` in the formula) allows concurrent lockless
  updates to the variables:

    `shadow_variable = decay * shadow_variable + (1 - decay) * variable`

  Reasonable values for `decay` are close to 1.0, typically in the
  multiple-nines range: 0.999, 0.9999, etc.

  Example usage when creating a training model:

  ```python
  # Create variables.
  var0 = tf.Variable(...)
  var1 = tf.Variable(...)
  # ... use the variables to build a training model...
  ...
  # Create an op that applies the optimizer. This is what we usually
  # would use as a training op.
  opt_op = opt.minimize(my_loss, [var0, var1])

  # Create an ExponentialMovingAverage object
  ema = tf.train.ExponentialMovingAverage(decay=0.9999)

  with tf.control_dependencies([opt_op]):
      # Create the shadow variables, and add ops to maintain moving averages
      # of var0 and var1. This also creates an op that will update the moving
      # averages after each training step. This is what we will use in place
      # of the usual training op.
      training_op = ema.apply([var0, var1])

  ...train the model by running training_op...
  ```

  There are two ways to use the moving averages for evaluations:

  * Build a model that uses the shadow variables instead of the variables.
    For this, use the `average()` method which returns the shadow variable
    for a given variable.
  * Build a model normally but load the checkpoint files to evaluate by using
    the shadow variable names. For this use the `average_name()` method. See
    `tf.train.Saver` for more information on restoring saved variables.

  Example of restoring the shadow variable values:

  ```python
  # Create a Saver that loads variables from their saved shadow values.
  shadow_var0_name = ema.average_name(var0)
  shadow_var1_name = ema.average_name(var1)
  saver = tf.train.Saver({shadow_var0_name: var0, shadow_var1_name: var1})
  saver.restore(...checkpoint filename...)
  # var0 and var1 now hold the moving average values
  ```
  """

  def __init__(self, decay, num_updates=None, zero_debias=False,
               name="ExponentialMovingAverage"):
    """Creates a new ExponentialMovingAverage object.

    The `apply()` method has to be called to create shadow variables and add
    ops to maintain moving averages.

    The optional `num_updates` parameter allows one to tweak the decay rate
    dynamically. It is typical to pass the count of training steps, usually
    kept in a variable that is incremented at each step, in which case the
    decay rate is lower at the start of training. This makes moving averages
    move faster. If passed, the actual decay rate used is:

      `min(decay, (1 + num_updates) / (10 + num_updates))`

    Args:
      decay: Float. The decay to use.
      num_updates: Optional count of number of updates applied to variables.
      zero_debias: If `True`, zero debias moving-averages that are initialized
        with tensors.
      name: String. Optional prefix name to use for the name of ops added in
        `apply()`.
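
    For illustration (assuming `decay=0.9999`), the effective decay rate is
    `min(0.9999, 1/10) = 0.1` at step 0, `min(0.9999, 101/110) ~= 0.918` at
    step 100, and `min(0.9999, 1001/1010) ~= 0.991` at step 1000; the cap of
    `0.9999` takes over only after roughly 90,000 updates.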
    """
    self._decay = decay
    self._num_updates = num_updates
    self._zero_debias = zero_debias
    self._name = name
    self._averages = {}

  @property
  def name(self):
    """The name of this ExponentialMovingAverage object."""
    return self._name

  def apply(self, var_list=None):
    """Maintains moving averages of variables.

    `var_list` must be a list of `Variable` or `Tensor` objects. This method
    creates shadow variables for all elements of `var_list`. Shadow variables
    for `Variable` objects are initialized to the variable's initial value.
    They will be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
    For `Tensor` objects, the shadow variables are initialized to 0 and zero
    debiased (see docstring in `assign_moving_average` for more details).

    Shadow variables are created with `trainable=False` and added to the
    `GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
    `tf.global_variables()`.

    Returns an op that updates all shadow variables from the current value of
    their associated variables.

    Note that `apply()` can be called multiple times. When eager execution is
    enabled, each call to apply will update the variables once, so this needs
    to be called in a loop.

    Args:
      var_list: A list of Variable or Tensor objects. The variables and
        Tensors must be of types bfloat16, float16, float32, or float64.

    Returns:
      An Operation that updates the moving averages.

    Raises:
      TypeError: If the arguments are not an allowed type.
    """
    # TODO(touts): op_scope
    if var_list is None:
      var_list = variables.trainable_variables()
    zero_debias_true = set()  # set of vars to set `zero_debias=True`
    for var in var_list:
      if var.dtype.base_dtype not in [
          dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
      ]:
        raise TypeError("The variables must be half, float, or double: %s" %
                        var.name)

      if var not in self._averages:
        # For variables: to lower communication bandwidth across devices we
        # keep the moving averages on the same device as the variables. For
        # other tensors, we rely on the existing device allocation mechanism.
        with ops.init_scope():
          if isinstance(var, variables.Variable):
            avg = slot_creator.create_slot(var,
                                           var.initialized_value(),
                                           self.name,
                                           colocate_with_primary=True)
            # NOTE(mrry): We only add `tf.Variable` objects to the
            # `MOVING_AVERAGE_VARIABLES` collection.
            ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
          else:
            avg = slot_creator.create_zeros_slot(
                var,
                self.name,
                colocate_with_primary=(var.op.type in ["Variable",
                                                       "VariableV2",
                                                       "VarHandleOp"]))
          if self._zero_debias:
            zero_debias_true.add(avg)
        self._averages[var] = avg

    with ops.name_scope(self.name) as scope:
      decay = ops.convert_to_tensor(self._decay, name="decay")
      if self._num_updates is not None:
        num_updates = math_ops.cast(self._num_updates,
                                    dtypes.float32,
                                    name="num_updates")
        decay = math_ops.minimum(decay,
                                 (1.0 + num_updates) / (10.0 + num_updates))
      updates = []
      for var in var_list:
        zero_debias = self._averages[var] in zero_debias_true
        updates.append(assign_moving_average(
            self._averages[var], var, decay, zero_debias=zero_debias))
      return control_flow_ops.group(*updates, name=scope)

  def average(self, var):
    """Returns the `Variable` holding the average of `var`.

    Args:
      var: A `Variable` object.

    Returns:
      A `Variable` object or `None` if the moving average of `var`
      is not maintained.
    """
    return self._averages.get(var, None)

  def average_name(self, var):
    """Returns the name of the `Variable` holding the average for `var`.

    The typical scenario for `ExponentialMovingAverage` is to compute moving
    averages of variables during training, and restore the variables from the
    computed moving averages during evaluations.

    To restore variables, you have to know the names of the shadow variables.
    That name and the original variable can then be passed to a `Saver()`
    object to restore the variable from the moving average value with:
      `saver = tf.train.Saver({ema.average_name(var): var})`

    `average_name()` can be called whether or not `apply()` has been called.

    Args:
      var: A `Variable` object.

    Returns:
      A string: The name of the variable that will be used or was used
      by the `ExponentialMovingAverage` class to hold the moving average of
      `var`.
    """
    if var in self._averages:
      return self._averages[var].op.name
    return ops.get_default_graph().unique_name(
        var.op.name + "/" + self.name, mark_as_used=False)

  def variables_to_restore(self, moving_avg_variables=None):
    """Returns a map of names to `Variables` to restore.

    If a variable has a moving average, use the moving average variable name
    as the restore name; otherwise, use the variable name.

    For example,

    ```python
    variables_to_restore = ema.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)
    ```

    Below is an example of such a mapping:

    ```
    conv/batchnorm/gamma/ExponentialMovingAverage: conv/batchnorm/gamma,
    conv_4/conv2d_params/ExponentialMovingAverage: conv_4/conv2d_params,
    global_step: global_step
    ```

    Args:
      moving_avg_variables: A list of variables that should be restored using
        their moving average variable names. If None, it will default to
        variables.moving_average_variables() + variables.trainable_variables().

    Returns:
      A map from restore names to variables. The restore name is either the
      original or the moving average version of the variable name, depending
      on whether the variable name is in `moving_avg_variables`.
    """
    name_map = {}
    if moving_avg_variables is None:
      # Include trainable variables and variables which have been explicitly
      # added to the moving_average_variables collection.
      moving_avg_variables = variables.trainable_variables()
      moving_avg_variables += variables.moving_average_variables()
    # Remove duplicates
    moving_avg_variables = set(moving_avg_variables)
    # Collect all the variables with moving average.
    for v in moving_avg_variables:
      name_map[self.average_name(v)] = v
    # Make sure we restore variables without moving averages as well.
    moving_avg_variable_names = set([v.name for v in moving_avg_variables])
    for v in list(set(variables.global_variables())):
      if v.name not in moving_avg_variable_names and v.op.name not in name_map:
        name_map[v.op.name] = v
    return name_map