1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Maintain moving averages of parameters."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.distribute import distribute_lib
21from tensorflow.python.distribute import distribution_strategy_context
22from tensorflow.python.distribute import reduce_util as ds_reduce_util
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import init_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import state_ops
29from tensorflow.python.ops import variable_scope
30from tensorflow.python.ops import variables
31from tensorflow.python.training import slot_creator
32from tensorflow.python.util.tf_export import tf_export
33
34
35# TODO(touts): switch to variables.Variable.
def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
  """Update `variable` in place so that it tracks a moving average of `value`.

  Conceptually the new content of `variable` is

    variable * decay + value * (1 - decay)

  which is realised as the in-place subtraction

    variable -= (1 - decay) * (variable - value)

  A variable that starts at `0` is biased towards `0`. When `zero_debias` is
  true the update is delegated to `_zero_debias`, which rescales by the
  debiasing factor

    1 - decay ** num_updates

  as described in Section 3 of (Kingma et al., 2015).

  By default the names of the debias shadow variables contain both the scope
  in which they are created and the scope of the variable being debiased, plus
  a uniquifying suffix. For example:

  ```
    with tf.compat.v1.variable_scope('scope1'):
      with tf.compat.v1.variable_scope('scope2'):
        var = tf.compat.v1.get_variable('foo')
        update_1 = tf.assign_moving_average(var, 0.0, 1.0)
        update_2 = tf.assign_moving_average(var, 0.0, 0.9)

    # var.name: 'scope1/scope2/foo'
    # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
    #                   'scope1/scope2/scope1/scope2/foo/biased_1'
  ```

  When called in a replica context, `value` is first reduced with
  `ReduceOp.MEAN` across replicas and the update is then performed inside a
  `merge_call`; otherwise the update runs directly in the cross-replica
  context.

  Args:
    variable: A Variable.
    value: A tensor with the same shape as 'variable'.
    decay: A float Tensor or float value.  The moving average decay.
    zero_debias: A python bool. If true, assume the variable is 0-initialized
      and unbias it, as in (Kingma et al., 2015). See the docstring of
      `_zero_debias` for more details.
    name: Optional name of the returned operation.

  Returns:
    A tensor which if evaluated will compute and return the new moving average.

  References:
    Adam - A Method for Stochastic Optimization:
      [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
      ([pdf](https://arxiv.org/pdf/1412.6980.pdf))
  """
  with ops.name_scope(name, "AssignMovingAvg",
                      [variable, value, decay]) as scope:
    # NOTE: from here on `decay` holds `1 - decay`, i.e. the step size of the
    # update, cast to the dtype of the variable being averaged.
    decay = ops.convert_to_tensor(1.0 - decay, name="decay")
    target_dtype = variable.dtype.base_dtype
    if decay.dtype != target_dtype:
      decay = math_ops.cast(decay, target_dtype)

    def update_fn(v, value):
      # Plain (biased) moving-average step.
      return state_ops.assign_sub(v, (v - value) * decay, name=scope)

    def update(strategy, v, value):
      # Pick the debiased or the plain update depending on `zero_debias`.
      if not zero_debias:
        return _update(strategy, v, update_fn, args=(value,))
      return _zero_debias(strategy, v, value, decay)

    replica_context = distribution_strategy_context.get_replica_context()
    if not replica_context:
      # Already in a cross-replica context: update directly.
      strategy = distribution_strategy_context.get_cross_replica_context()
      return update(strategy, variable, value)

    def merge_fn(strategy, v, value):
      # In a replica context the variable is updated with the mean of `value`
      # across replicas.
      mean_value = strategy.extended.reduce_to(
          ds_reduce_util.ReduceOp.MEAN, value, v)
      return update(strategy, v, mean_value)

    return replica_context.merge_call(merge_fn, args=(variable, value))
115
116
def weighted_moving_average(value,
                            decay,
                            weight,
                            truediv=True,
                            collections=None,
                            name=None):
  """Compute the weighted moving average of `value`.

  The result is conceptually
    `moving_average(value * weight) / moving_average(weight)`,
  where each moving average follows the rule
    `new_value = decay * old_value + (1 - decay) * update`.
  Two internal, zero-initialized, non-trainable variables hold the moving
  averages of `value * weight` and of `weight`.

  Args:
    value: A numeric `Tensor`.
    decay: A float `Tensor` or float value.  The moving average decay.
    weight:  `Tensor` that keeps the current value of a weight. Shape should be
      able to multiply `value`.
    truediv:  Boolean, if `True`, dividing by `moving_average(weight)` is
      floating point division.  If `False`, use division implied by dtypes.
    collections:  List of graph collections keys to add the internal variables
      `value * weight` and `weight` to. Defaults to
      `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name of the returned operation. Defaults to
      "WeightedMovingAvg".

  Returns:
    An Operation that updates and returns the weighted moving average.
  """
  # In contrast to assign_moving_average, no user-visible variable is modified
  # here: the result is the ratio of two internal moving-average variables,
  # which is why the signature differs from assign_moving_average.
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]

  with variable_scope.variable_scope(name, "WeightedMovingAvg",
                                     [value, weight, decay]) as scope:

    def _zero_accumulator(var_name, like):
      # Internal accumulator with the shape and dtype of `like`, starting at 0.
      return variable_scope.get_variable(
          var_name,
          shape=like.get_shape(),
          dtype=like.dtype,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          collections=collections)

    value_x_weight_var = _zero_accumulator("value_x_weight", value)
    weight_var = _zero_accumulator("weight", weight)

    # Both accumulators use the plain (non-debiased) update.
    numerator = assign_moving_average(
        value_x_weight_var, value * weight, decay, zero_debias=False)
    denominator = assign_moving_average(
        weight_var, weight, decay, zero_debias=False)

    divide_fn = math_ops.truediv if truediv else math_ops.divide
    return divide_fn(numerator, denominator, name=scope.name)
179
180
def _update(strategy, var, update_fn, args):
  """Run `update_fn` on `var`, directly or through `strategy`.

  Must be called in a cross-replica context.

  Args:
    strategy: The strategy whose `extended.update` is used when no update
      replica id is set.
    var: The variable to update.
    update_fn: Callable invoked as `update_fn(var, *args)`.
    args: Sequence of extra positional arguments for `update_fn`.

  Returns:
    Whatever `update_fn` or `strategy.extended.update` returns.
  """
  assert distribution_strategy_context.in_cross_replica_context(), (
      "_update can only be called in cross-replica context")
  if distribute_lib.get_update_replica_id() is None:
    # Not inside an update: let the strategy dispatch the update.
    return strategy.extended.update(var, update_fn, args)
  # Already inside an update for a specific replica: delegate to `var`. We
  # expect `var` to do the right thing here, e.g. a MirroredVariable should
  # pick its component variable based on `update_replica_id` and only update
  # that one.
  return update_fn(var, *args)
193
194
def _zero_debias(strategy, unbiased_var, value, decay):
  """Update `unbiased_var` with a zero-debiased moving average of `value`.

  Exponential moving averages that start at 0 are biased towards 0. This
  function corrects for that with a scale factor, as in (Kingma et al., 2015).

  To see the bias caused by 0-initialization, take an EMA initialized to `0`
  with decay `b`. After `t` timesteps of seeing the constant `c`, the variable
  has the following value:

  ```
    EMA = 0*b^(t) + c*(1 - b)*b^(t-1) + c*(1 - b)*b^(t-2) + ...
        = c*(1 - b^t)
  ```

  To recover the true value `c`, we divide by the scale factor `1 - b^t`.

  Two shadow variables are used for this: `biased` keeps the biased estimate
  and `local_step` counts the number of updates that have occurred. Both are
  created in a variable scope named after `unbiased_var` and colocated with
  it; unless the scope has `reuse` set, a `_<n>` suffix is appended to avoid
  clashing with existing variables.

  Args:
    strategy: `Strategy` used to create and update variables.
    unbiased_var: A Variable representing the current value of the unbiased EMA.
    value: A Tensor representing the most recent value.
    decay: A Tensor representing `1-decay` for the EMA.

  Returns:
    The result of `_update` applied to an update function that advances both
    shadow variables and assigns `biased / (1 - (1 - decay) ** local_step)` to
    `unbiased_var`.

  References:
    Adam - A Method for Stochastic Optimization:
      [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
      ([pdf](https://arxiv.org/pdf/1412.6980.pdf))

  """
  # Strip the trailing ":0" of the tensor name to obtain the scope name.
  scope_name = unbiased_var.name[:-len(":0")]
  with variable_scope.variable_scope(
      scope_name, values=[unbiased_var, value, decay]):
    with ops.init_scope():
      biased_initializer = init_ops.zeros_initializer()
      local_step_initializer = init_ops.zeros_initializer()

    def _maybe_get_unique(name):
      """Get name for a unique variable, if not `reuse=True`."""
      current_scope = variable_scope.get_variable_scope()
      if current_scope.reuse:
        return name
      taken = {x.op.name for x in current_scope.global_variables()}
      qualified = current_scope.name + "/" + name
      if qualified not in taken:
        return name
      # Probe `<name>_1`, `<name>_2`, ... until a free name is found.
      suffix = 1
      while "%s_%d" % (qualified, suffix) in taken:
        suffix += 1
      return "%s_%d" % (name, suffix)

    with strategy.extended.colocate_vars_with(unbiased_var):
      biased_var = variable_scope.get_variable(
          _maybe_get_unique("biased"),
          shape=unbiased_var.get_shape(),
          dtype=unbiased_var.dtype,
          initializer=biased_initializer,
          trainable=False)
      local_step = variable_scope.get_variable(
          _maybe_get_unique("local_step"),
          shape=[],
          dtype=unbiased_var.dtype,
          initializer=local_step_initializer,
          trainable=False)

  def update_fn(v, value, biased_var, local_step):
    # Advance the biased estimate and the step counter.
    biased_after_update = state_ops.assign_sub(
        biased_var, (biased_var - value) * decay)
    step_after_update = local_step.assign_add(1)

    # `decay` is really `1 - decay` here, hence `1.0 - decay` as the base.
    decay_power = math_ops.pow(1.0 - decay, step_after_update)
    bias_factor = 1 - decay_power
    debiased = biased_after_update / bias_factor
    return state_ops.assign(v, debiased, name=ops.get_name_scope() + "/")

  return _update(
      strategy, unbiased_var, update_fn, args=(value, biased_var, local_step))
282
283
@tf_export("train.ExponentialMovingAverage")
class ExponentialMovingAverage(object):
  """Maintains moving averages of variables by employing an exponential decay.

  When training a model, it is often beneficial to maintain moving averages of
  the trained parameters.  Evaluations that use averaged parameters sometimes
  produce significantly better results than the final trained values.

  The `apply()` method adds shadow copies of trained variables and adds ops
  that maintain a moving average of the trained variables in their shadow
  copies.  It is used when building the training model.  The ops that maintain
  moving averages are typically run after each training step.
  The `average()` and `average_name()` methods give access to the shadow
  variables and their names.  They are useful when building an evaluation
  model, or when restoring a model from a checkpoint file.  They help use the
  moving averages in place of the last trained values for evaluations.

  The moving averages are computed using exponential decay.  You specify the
  decay value when creating the `ExponentialMovingAverage` object.  The shadow
  variables are initialized with the same initial values as the trained
  variables.  When you run the ops to maintain the moving averages, each
  shadow variable is updated with the formula:

    `shadow_variable -= (1 - decay) * (shadow_variable - variable)`

  This is mathematically equivalent to the classic formula below, but the use
  of an `assign_sub` op (the `"-="` in the formula) allows concurrent lockless
  updates to the variables:

    `shadow_variable = decay * shadow_variable + (1 - decay) * variable`

  Reasonable values for `decay` are close to 1.0, typically in the
  multiple-nines range: 0.999, 0.9999, etc.

  Example usage when creating a training model:

  ```python
  # Create variables.
  var0 = tf.Variable(...)
  var1 = tf.Variable(...)
  # ... use the variables to build a training model...
  ...
  # Create an op that applies the optimizer.  This is what we usually
  # would use as a training op.
  opt_op = opt.minimize(my_loss, [var0, var1])

  # Create an ExponentialMovingAverage object
  ema = tf.train.ExponentialMovingAverage(decay=0.9999)

  with tf.control_dependencies([opt_op]):
      # Create the shadow variables, and add ops to maintain moving averages
      # of var0 and var1. This also creates an op that will update the moving
      # averages after each training step.  This is what we will use in place
      # of the usual training op.
      training_op = ema.apply([var0, var1])

  ...train the model by running training_op...
  ```

  There are two ways to use the moving averages for evaluations:

  *  Build a model that uses the shadow variables instead of the variables.
     For this, use the `average()` method which returns the shadow variable
     for a given variable.
  *  Build a model normally but load the checkpoint files to evaluate by using
     the shadow variable names.  For this use the `average_name()` method.  See
     the `tf.compat.v1.train.Saver` for more
     information on restoring saved variables.

  Example of restoring the shadow variable values:

  ```python
  # Create a Saver that loads variables from their saved shadow values.
  shadow_var0_name = ema.average_name(var0)
  shadow_var1_name = ema.average_name(var1)
  saver = tf.compat.v1.train.Saver(
      {shadow_var0_name: var0, shadow_var1_name: var1})
  saver.restore(...checkpoint filename...)
  # var0 and var1 now hold the moving average values
  ```
  """

  def __init__(self,
               decay,
               num_updates=None,
               zero_debias=False,
               name="ExponentialMovingAverage"):
    """Creates a new ExponentialMovingAverage object.

    The `apply()` method has to be called to create shadow variables and add
    ops to maintain moving averages.

    The optional `num_updates` parameter allows one to tweak the decay rate
    dynamically. It is typical to pass the count of training steps, usually
    kept in a variable that is incremented at each step, in which case the
    decay rate is lower at the start of training.  This makes moving averages
    move faster.  If passed, the actual decay rate used is:

      `min(decay, (1 + num_updates) / (10 + num_updates))`

    Args:
      decay: Float.  The decay to use.
      num_updates: Optional count of number of updates applied to variables.
      zero_debias: If `True`, zero debias moving-averages that are initialized
        with tensors.
      name: String. Optional prefix name to use for the name of ops added in
        `apply()`.
    """
    self._decay = decay
    self._num_updates = num_updates
    self._zero_debias = zero_debias
    self._name = name
    # Maps `var.ref()` (a hashable reference to the averaged variable or
    # tensor) to the shadow variable holding its moving average.
    self._averages = {}

  @property
  def name(self):
    """The name of this ExponentialMovingAverage object."""
    return self._name

  def apply(self, var_list=None):
    """Maintains moving averages of variables.

    `var_list` must be an iterable of `Variable` or `Tensor` objects.  This
    method creates shadow variables for all elements of `var_list` that do not
    have one yet.  Shadow variables for `Variable` objects are initialized to
    the variable's initial value, and those `Variable` objects are added to
    the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
    For `Tensor` objects, the shadow variables are initialized to 0; they are
    zero debiased only if this object was created with `zero_debias=True` (see
    docstring in `assign_moving_average` for more details).

    Shadow variables are created with `trainable=False` and added to the
    `GraphKeys.GLOBAL_VARIABLES` collection.  They will be returned by calls to
    `tf.compat.v1.global_variables()`.

    Returns an op that updates all shadow variables from the current value of
    their associated variables.

    Note that `apply()` can be called multiple times. When eager execution is
    enabled each call to apply will update the variables once, so this needs to
    be called in a loop.

    Args:
      var_list: An iterable of Variable or Tensor objects. The variables and
        Tensors must be of types bfloat16, float16, float32, or float64. If
        `None`, defaults to `variables.trainable_variables()`.

    Returns:
      An Operation that updates the moving averages.

    Raises:
      TypeError: If the arguments are not an allowed type.
    """
    # TODO(touts): op_scope
    if var_list is None:
      var_list = variables.trainable_variables()
    else:
      # `var_list` is traversed several times below. Materialize it so that a
      # one-shot iterable (e.g. a generator) is not silently exhausted after
      # the first pass, which would yield an empty update op.
      var_list = list(var_list)
    for v in var_list:
      if isinstance(v, ops.EagerTensor):
        raise TypeError(
            "tf.train.ExponentialMovingAverage does not support non-Variable"
            " tensors when eager execution is enabled.")
    zero_debias_true = set()  # set of vars to set `zero_debias=True`
    for var in var_list:
      if var.dtype.base_dtype not in [
          dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
      ]:
        raise TypeError("The variables must be half, float, or double: %s" %
                        var.name)

      if var.ref() not in self._averages:
        # For variables: to lower communication bandwidth across devices we keep
        # the moving averages on the same device as the variables. For other
        # tensors, we rely on the existing device allocation mechanism.
        with ops.init_scope():
          if isinstance(var, variables.Variable):
            with ops.device(var.device):
              initialized_value = var.initialized_value()
            avg = slot_creator.create_slot(
                var,
                initialized_value,
                self.name,
                colocate_with_primary=True,
                copy_xla_sharding=True)
            # NOTE(mrry): We only add `tf.Variable` objects to the
            # `MOVING_AVERAGE_VARIABLES` collection.
            ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
          else:
            avg = slot_creator.create_zeros_slot(
                var,
                self.name,
                colocate_with_primary=(var.op.type in [
                    "Variable", "VariableV2", "VarHandleOp"
                ]),
                copy_xla_sharding=True)
            # Only zero-initialized slots created for plain tensors are
            # candidates for zero debiasing.
            if self._zero_debias:
              zero_debias_true.add(avg.ref())
        self._averages[var.ref()] = avg

    with ops.name_scope(self.name) as scope:
      decay = ops.convert_to_tensor(
          self._decay, dtype=dtypes.float32, name="decay")
      if self._num_updates is not None:
        # Use a lower effective decay early in training:
        # min(decay, (1 + num_updates) / (10 + num_updates)).
        num_updates = math_ops.cast(
            self._num_updates, dtypes.float32, name="num_updates")
        decay = math_ops.minimum(decay,
                                 (1.0 + num_updates) / (10.0 + num_updates))
      updates = []
      for var in var_list:
        avg = self._averages[var.ref()]
        zero_debias = avg.ref() in zero_debias_true
        updates.append(assign_moving_average(avg, var, decay, zero_debias))
      return control_flow_ops.group(*updates, name=scope)

  def average(self, var):
    """Returns the `Variable` holding the average of `var`.

    Args:
      var: A `Variable` object.

    Returns:
      A `Variable` object or `None` if the moving average of `var`
      is not maintained.
    """
    return self._averages.get(var.ref(), None)

  def average_name(self, var):
    """Returns the name of the `Variable` holding the average for `var`.

    The typical scenario for `ExponentialMovingAverage` is to compute moving
    averages of variables during training, and restore the variables from the
    computed moving averages during evaluations.

    To restore variables, you have to know the name of the shadow variables.
    That name and the original variable can then be passed to a `Saver()` object
    to restore the variable from the moving average value with:
      `saver = tf.compat.v1.train.Saver({ema.average_name(var): var})`

    `average_name()` can be called whether or not `apply()` has been called.

    Args:
      var: A `Variable` object.

    Returns:
      A string: The name of the variable that will be used or was used
      by the `ExponentialMovingAverage class` to hold the moving average of
      `var`.
    """
    if var.ref() in self._averages:
      return self._averages[var.ref()].op.name
    # No shadow variable yet: compute the name it would get, without
    # reserving that name in the graph.
    return ops.get_default_graph().unique_name(
        var.op.name + "/" + self.name, mark_as_used=False)

  def variables_to_restore(self, moving_avg_variables=None):
    """Returns a map of names to `Variables` to restore.

    If a variable has a moving average, use the moving average variable name as
    the restore name; otherwise, use the variable name.

    For example,

    ```python
      variables_to_restore = ema.variables_to_restore()
      saver = tf.compat.v1.train.Saver(variables_to_restore)
    ```

    Below is an example of such mapping:

    ```
      conv/batchnorm/gamma/ExponentialMovingAverage: conv/batchnorm/gamma,
      conv_4/conv2d_params/ExponentialMovingAverage: conv_4/conv2d_params,
      global_step: global_step
    ```

    Args:
      moving_avg_variables: a list of variables that require the use of the
        moving average variable name to be restored. If None, it will default to
        variables.moving_average_variables() + variables.trainable_variables()

    Returns:
      A map from restore_names to variables. The restore_name is either the
      original or the moving average version of the variable name, depending
      on whether the variable name is in the `moving_avg_variables`.
    """
    name_map = {}
    if moving_avg_variables is None:
      # Include trainable variables and variables which have been explicitly
      # added to the moving_average_variables collection.
      moving_avg_variables = variables.trainable_variables()
      moving_avg_variables += variables.moving_average_variables()
    # Remove duplicates. Deduplicate on `ref()` rather than on the variables
    # themselves, consistent with how `self._averages` is keyed: variables are
    # not hashable when tensor equality is enabled.
    moving_avg_refs = {v.ref() for v in moving_avg_variables}
    # Collect all the variables with moving average, remembering their names.
    moving_avg_variable_names = set()
    for moving_avg_ref in moving_avg_refs:
      v = moving_avg_ref.deref()
      name_map[self.average_name(v)] = v
      moving_avg_variable_names.add(v.name)
    # Make sure we restore variables without moving averages as well.
    for global_ref in {v.ref() for v in variables.global_variables()}:
      v = global_ref.deref()
      if v.name not in moving_avg_variable_names and v.op.name not in name_map:
        name_map[v.op.name] = v
    return name_map
581