1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Base class for optimizers."""
17# pylint: disable=g-bad-name
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import abc
24
25import six
26
27from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
28from tensorflow.python.distribute import reduce_util as ds_reduce_util
29from tensorflow.python.eager import backprop
30from tensorflow.python.eager import context
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import ops
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import gradients
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import resource_variable_ops
38from tensorflow.python.ops import state_ops
39from tensorflow.python.ops import variable_scope
40from tensorflow.python.ops import variables
41from tensorflow.python.training import slot_creator
42from tensorflow.python.training.tracking import base as trackable
43from tensorflow.python.util import nest
44from tensorflow.python.util.tf_export import tf_export
45
46
def get_filtered_grad_fn(grad_fn):
  """Wraps `grad_fn` so that pairs whose gradient is `None` are dropped.

  `distributed_context.join()` requires that its arguments are parallel
  across threads; in particular `grads_and_vars` must hold the same
  variables in the same order on every thread.

  When gradients are computed eagerly from several threads, a thread can see
  extra variables with a `None` gradient because another thread accessed them
  during the gradient computation. Removing those pairs yields a consistent
  set of variables.

  Args:
    grad_fn: A callable returning an iterable of `(gradient, variable)` pairs.

  Returns:
    A callable accepting the same arguments as `grad_fn` and returning a list
    of `(gradient, variable)` tuples, excluding every pair whose gradient is
    `None`.
  """

  def filtered_grad_fn(*args, **kwargs):
    kept = []
    for grad, var in grad_fn(*args, **kwargs):
      if grad is None:
        continue
      kept.append((grad, var))
    return kept

  return filtered_grad_fn
61
62
def _deduplicate_indexed_slices(values, indices):
  """Merges the slices of `values` that share an entry in `indices`.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A `(summed_values, unique_indices)` tuple. `unique_indices` holds each
    distinct entry of `indices` once, and `summed_values` holds, per distinct
    entry, the sum of the `values` slices that referred to it.
  """
  unique_indices, segment_ids = array_ops.unique(indices)
  # One output segment per distinct index; `segment_ids` maps every original
  # position onto the segment of its index.
  num_segments = array_ops.shape(unique_indices)[0]
  summed_values = math_ops.unsorted_segment_sum(
      values, segment_ids, num_segments)
  return summed_values, unique_indices
80
81
def _var_key(var):
  """Returns the key under which `var` is stored in the optimizer's dicts.

  Objects exposing an `op` attribute are keyed by `(op.graph, op.name)`;
  anything else falls back to its `_unique_id` attribute.
  """
  # TODO(ashankar): Consolidate handling for eager and graph
  if not hasattr(var, "op"):
    return var._unique_id  # pylint: disable=protected-access
  return (var.op.graph, var.op.name)
87
88
@six.add_metaclass(abc.ABCMeta)
class _OptimizableVariable(object):
  """Abstract adapter letting the optimizer treat variable kinds uniformly.

  Concrete subclasses report what gradients are taken with respect to
  (`target`) and how a gradient is applied to the variable (`update_op`).
  """

  @abc.abstractmethod
  def update_op(self, optimizer, g):
    """Builds and returns the update op(s) applying gradient `g`."""
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def target(self):
    """Returns the object that gradients are computed with respect to."""
    raise NotImplementedError("Calling an abstract method.")
102
103
class _RefVariableProcessor(_OptimizableVariable):
  """Processor for (non-resource) reference `Variable`s."""

  def __init__(self, v):
    self._v = v

  def __str__(self):
    return "<_RefVariableProcessor(%s)>" % self._v

  def target(self):
    # Gradients are taken with respect to the variable's ref tensor.
    return self._v._ref()  # pylint: disable=protected-access

  def update_op(self, optimizer, g):
    var = self._v
    if not isinstance(g, ops.Tensor):
      # Sparse path: only IndexedSlices are accepted, without constraints.
      assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
                                                "tensor nor IndexedSlices.")
      if var.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      # pylint: disable=protected-access
      return optimizer._apply_sparse_duplicate_indices(g, var)
    # Dense path.
    update_op = optimizer._apply_dense(g, var)  # pylint: disable=protected-access
    if var.constraint is None:
      return update_op
    # Re-apply the constraint only after the dense update has run.
    with ops.control_dependencies([update_op]):
      return var.assign(var.constraint(var))
132
133
class _DenseReadResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables updated through their op's input.

  Only dense gradients are handled. `_resource_apply_dense` receives
  `v.op.inputs[0]` rather than `v` itself.
  """

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    var = self._v
    # NOTE(review): `v.op.inputs[0]` is presumably the resource handle consumed
    # by `v`'s read op -- confirm against whoever constructs this processor.
    # pylint: disable=protected-access
    update_op = optimizer._resource_apply_dense(g, var.op.inputs[0])
    if var.constraint is None:
      return update_op
    with ops.control_dependencies([update_op]):
      return var.assign(var.constraint(var))
151
152
class _DenseResourceVariableProcessor(_OptimizableVariable):
  """Processor for ResourceVariables receiving dense or sparse gradients."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    # pylint: disable=protected-access
    var = self._v
    if not isinstance(g, ops.IndexedSlices):
      # Dense gradient, optionally followed by the variable's constraint.
      update_op = optimizer._resource_apply_dense(g, var)
      if var.constraint is None:
        return update_op
      with ops.control_dependencies([update_op]):
        return var.assign(var.constraint(var))
    # Sparse gradient: constraints are not supported here.
    if var.constraint is not None:
      raise RuntimeError(
          "Cannot use a constraint function on a sparse variable.")
    return optimizer._resource_apply_sparse_duplicate_indices(
        g.values, var, g.indices)
176
177
class _TensorProcessor(_OptimizableVariable):
  """Processor for ordinary Tensors: differentiable, but never updatable.

  Computing gradients with respect to a Tensor through the optimizer is
  sometimes useful, so `target` is supported; `update_op` always raises.
  """

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g):
    del optimizer, g  # Unused: a Tensor cannot be updated.
    raise NotImplementedError("Trying to update a Tensor ", self._v)
194
195
def _get_processor(v):
  """Returns the `_OptimizableVariable` adapter appropriate for `v`.

  Args:
    v: A variable or Tensor to optimize.

  Returns:
    An `_OptimizableVariable` wrapping `v`.

  Raises:
    NotImplementedError: If `v` is of an unsupported type.
  """
  if context.executing_eagerly():
    # Eagerly, anything that is not a Tensor is treated as a resource variable.
    return (_TensorProcessor(v) if isinstance(v, ops.Tensor)
            else _DenseResourceVariableProcessor(v))
  # True if and only if `v` is a resource variable initialized eagerly.
  created_eagerly = (resource_variable_ops.is_resource_variable(v) and
                     not v._in_graph_mode)  # pylint: disable=protected-access
  if created_eagerly or v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
    return _RefVariableProcessor(v)
  if isinstance(v, ops.Tensor):
    return _TensorProcessor(v)
  raise NotImplementedError("Trying to optimize unsupported type ", v)
213
214
215@tf_export(v1=["train.Optimizer"])
216class Optimizer(
217    # Optimizers inherit from Trackable rather than AutoTrackable
218    # since they do most of their dependency management themselves (slot
219    # variables are special-cased, and non-slot variables are keyed to graphs).
220    trackable.Trackable):
221  """Base class for optimizers.
222
223  This class defines the API to add Ops to train a model.  You never use this
224  class directly, but instead instantiate one of its subclasses such as
225  `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
226
227  ### Usage
228
229  ```python
230  # Create an optimizer with the desired parameters.
231  opt = GradientDescentOptimizer(learning_rate=0.1)
232  # Add Ops to the graph to minimize a cost by updating a list of variables.
233  # "cost" is a Tensor, and the list of variables contains tf.Variable
234  # objects.
235  opt_op = opt.minimize(cost, var_list=<list of variables>)
236  ```
237
238  In the training program you will just have to run the returned Op.
239
240  ```python
241  # Execute opt_op to do one step of training:
242  opt_op.run()
243  ```
244
245  ### Processing gradients before applying them.
246
247  Calling `minimize()` takes care of both computing the gradients and
248  applying them to the variables.  If you want to process the gradients
249  before applying them you can instead use the optimizer in three steps:
250
251  1.  Compute the gradients with `compute_gradients()`.
252  2.  Process the gradients as you wish.
253  3.  Apply the processed gradients with `apply_gradients()`.
254
255  Example:
256
257  ```python
258  # Create an optimizer.
259  opt = GradientDescentOptimizer(learning_rate=0.1)
260
261  # Compute the gradients for a list of variables.
262  grads_and_vars = opt.compute_gradients(loss, <list of variables>)
263
264  # grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
265  # need to the 'gradient' part, for example cap them, etc.
266  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
267
268  # Ask the optimizer to apply the capped gradients.
269  opt.apply_gradients(capped_grads_and_vars)
270  ```
271
272  ### Gating Gradients
273
274  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
275  argument that controls the degree of parallelism during the application of
276  the gradients.
277
278  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
279
280  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel.  This provides
281  the maximum parallelism in execution, at the cost of some non-reproducibility
282  in the results.  For example the two gradients of `matmul` depend on the input
283  values: With `GATE_NONE` one of the gradients could be applied to one of the
284  inputs _before_ the other gradient is computed resulting in non-reproducible
285  results.
286
287  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
288  they are used.  This prevents race conditions for Ops that generate gradients
289  for multiple inputs where the gradients depend on the inputs.
290
291  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
292  before any one of them is used.  This provides the least parallelism but can
293  be useful if you want to process all gradients before applying any of them.
294
295  ### Slots
296
297  Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
298  allocate and manage additional variables associated with the variables to
299  train.  These are called <i>Slots</i>.  Slots have names and you can ask the
300  optimizer for the names of the slots that it uses.  Once you have a slot name
301  you can ask the optimizer for the variable it created to hold the slot value.
302
  This can be useful if you want to log or debug a training algorithm, report
  stats about the slots, etc.
305  """
306
307  # Values for gate_gradients.
308  GATE_NONE = 0
309  GATE_OP = 1
310  GATE_GRAPH = 2
311
312  def __init__(self, use_locking, name):
313    """Create a new Optimizer.
314
315    This must be called by the constructors of subclasses.
316
317    Args:
318      use_locking: Bool. If True apply use locks to prevent concurrent updates
319        to variables.
320      name: A non-empty string.  The name to use for accumulators created
321        for the optimizer.
322
323    Raises:
324      ValueError: If name is malformed.
325    """
326    if not name:
327      raise ValueError("Must specify the optimizer name")
328    self._use_locking = use_locking
329    self._name = name
330    # Dictionary of slots.
331    #  {slot_name :
332    #      {_var_key(variable_to_train): slot_for_the_variable, ... },
333    #   ... }
334    self._slots = {}
335    self._non_slot_dict = {}
336    # For implementing Trackable. Stores information about how to restore
337    # slot variables which have not yet been created
338    # (trackable._CheckpointPosition objects).
339    #  {slot_name :
340    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
341    #   ... }
342    self._deferred_slot_restorations = {}
343
344    # TODO(isaprykin): When using a DistributionStrategy, and when an
345    # optimizer is created in each replica, it might be dangerous to
346    # rely on some Optimizer methods.  When such methods are called on a
347    # per-replica optimizer, an exception needs to be thrown.  We do
348    # allow creation per-replica optimizers however, because the
349    # compute_gradients()->apply_gradients() sequence is safe.
350
351  def get_name(self):
352    return self._name
353
354  def minimize(self, loss, global_step=None, var_list=None,
355               gate_gradients=GATE_OP, aggregation_method=None,
356               colocate_gradients_with_ops=False, name=None,
357               grad_loss=None):
358    """Add operations to minimize `loss` by updating `var_list`.
359
360    This method simply combines calls `compute_gradients()` and
361    `apply_gradients()`. If you want to process the gradient before applying
362    them call `compute_gradients()` and `apply_gradients()` explicitly instead
363    of using this function.
364
365    Args:
366      loss: A `Tensor` containing the value to minimize.
367      global_step: Optional `Variable` to increment by one after the
368        variables have been updated.
369      var_list: Optional list or tuple of `Variable` objects to update to
370        minimize `loss`.  Defaults to the list of variables collected in
371        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
372      gate_gradients: How to gate the computation of gradients.  Can be
373        `GATE_NONE`, `GATE_OP`, or  `GATE_GRAPH`.
374      aggregation_method: Specifies the method used to combine gradient terms.
375        Valid values are defined in the class `AggregationMethod`.
376      colocate_gradients_with_ops: If True, try colocating gradients with
377        the corresponding op.
378      name: Optional name for the returned operation.
379      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
380
381    Returns:
382      An Operation that updates the variables in `var_list`.  If `global_step`
383      was not `None`, that operation also increments `global_step`.
384
385    Raises:
386      ValueError: If some of the variables are not `Variable` objects.
387
388    @compatibility(eager)
389    When eager execution is enabled, `loss` should be a Python function that
390    takes no arguments and computes the value to be minimized. Minimization (and
391    gradient computation) is done with respect to the elements of `var_list` if
392    not None, else with respect to any trainable variables created during the
393    execution of the `loss` function. `gate_gradients`, `aggregation_method`,
394    `colocate_gradients_with_ops` and `grad_loss` are ignored when eager
395    execution is enabled.
396    @end_compatibility
397    """
398    grads_and_vars = self.compute_gradients(
399        loss, var_list=var_list, gate_gradients=gate_gradients,
400        aggregation_method=aggregation_method,
401        colocate_gradients_with_ops=colocate_gradients_with_ops,
402        grad_loss=grad_loss)
403
404    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
405    if not vars_with_grad:
406      raise ValueError(
407          "No gradients provided for any variable, check your graph for ops"
408          " that do not support gradients, between variables %s and loss %s." %
409          ([str(v) for _, v in grads_and_vars], loss))
410
411    return self.apply_gradients(grads_and_vars, global_step=global_step,
412                                name=name)
413
414  def compute_gradients(self, loss, var_list=None,
415                        gate_gradients=GATE_OP,
416                        aggregation_method=None,
417                        colocate_gradients_with_ops=False,
418                        grad_loss=None):
419    """Compute gradients of `loss` for the variables in `var_list`.
420
421    This is the first part of `minimize()`.  It returns a list
422    of (gradient, variable) pairs where "gradient" is the gradient
423    for "variable".  Note that "gradient" can be a `Tensor`, an
424    `IndexedSlices`, or `None` if there is no gradient for the
425    given variable.
426
427    Args:
428      loss: A Tensor containing the value to minimize or a callable taking
429        no arguments which returns the value to minimize. When eager execution
430        is enabled it must be a callable.
431      var_list: Optional list or tuple of `tf.Variable` to update to minimize
432        `loss`.  Defaults to the list of variables collected in the graph
433        under the key `GraphKeys.TRAINABLE_VARIABLES`.
434      gate_gradients: How to gate the computation of gradients.  Can be
435        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
436      aggregation_method: Specifies the method used to combine gradient terms.
437        Valid values are defined in the class `AggregationMethod`.
438      colocate_gradients_with_ops: If True, try colocating gradients with
439        the corresponding op.
440      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
441
442    Returns:
443      A list of (gradient, variable) pairs. Variable is always present, but
444      gradient can be `None`.
445
446    Raises:
447      TypeError: If `var_list` contains anything else than `Variable` objects.
448      ValueError: If some arguments are invalid.
449      RuntimeError: If called with eager execution enabled and `loss` is
450        not callable.
451
452    @compatibility(eager)
453    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
454    and `colocate_gradients_with_ops` are ignored.
455    @end_compatibility
456    """
457    if callable(loss):
458      with backprop.GradientTape() as tape:
459        if var_list is not None:
460          tape.watch(var_list)
461        loss_value = loss()
462
463      if var_list is None:
464        var_list = tape.watched_variables()
465      # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
466      # to be executed.
467      with ops.control_dependencies([loss_value]):
468        grads = tape.gradient(loss_value, var_list, grad_loss)
469      return list(zip(grads, var_list))
470
471    # Non-callable/Tensor loss case
472    if context.executing_eagerly():
473      raise RuntimeError(
474          "`loss` passed to Optimizer.compute_gradients should "
475          "be a function when eager execution is enabled.")
476
477    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
478                              Optimizer.GATE_GRAPH]:
479      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
480                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
481                       gate_gradients)
482    self._assert_valid_dtypes([loss])
483    if grad_loss is not None:
484      self._assert_valid_dtypes([grad_loss])
485    if var_list is None:
486      var_list = (
487          variables.trainable_variables() +
488          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
489    else:
490      var_list = nest.flatten(var_list)
491    # pylint: disable=protected-access
492    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
493    # pylint: enable=protected-access
494    processors = [_get_processor(v) for v in var_list]
495    if not var_list:
496      raise ValueError("No variables to optimize.")
497    var_refs = [p.target() for p in processors]
498    grads = gradients.gradients(
499        loss, var_refs, grad_ys=grad_loss,
500        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
501        aggregation_method=aggregation_method,
502        colocate_gradients_with_ops=colocate_gradients_with_ops)
503    if gate_gradients == Optimizer.GATE_GRAPH:
504      grads = control_flow_ops.tuple(grads)
505    grads_and_vars = list(zip(grads, var_list))
506    self._assert_valid_dtypes(
507        [v for g, v in grads_and_vars
508         if g is not None and v.dtype != dtypes.resource])
509    return grads_and_vars
510
511  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
512    """Apply gradients to variables.
513
514    This is the second part of `minimize()`. It returns an `Operation` that
515    applies gradients.
516
517    Args:
518      grads_and_vars: List of (gradient, variable) pairs as returned by
519        `compute_gradients()`.
520      global_step: Optional `Variable` to increment by one after the
521        variables have been updated.
522      name: Optional name for the returned operation.  Default to the
523        name passed to the `Optimizer` constructor.
524
525    Returns:
526      An `Operation` that applies the specified gradients. If `global_step`
527      was not None, that operation also increments `global_step`.
528
529    Raises:
530      TypeError: If `grads_and_vars` is malformed.
531      ValueError: If none of the variables have gradients.
532      RuntimeError: If you should use `_distributed_apply()` instead.
533    """
534    # This is a default implementation of apply_gradients() that can be shared
535    # by most optimizers.  It relies on the subclass implementing the following
536    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
537
538    # TODO(isaprykin): Get rid of `has_strategy()` check by
539    # always calling _distributed_apply(), using the default distribution
540    # as needed.
541    if distribute_ctx.has_strategy():
542      # Handle DistributionStrategy case.
543      if distribute_ctx.in_cross_replica_context():
544        raise RuntimeError("Use `_distributed_apply()` instead of "
545                           "`apply_gradients()` in a cross-replica context.")
546
547      grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
548      return distribute_ctx.get_replica_context().merge_call(
549          self._distributed_apply, args=(grads_and_vars, global_step, name))
550
551    # No DistributionStrategy case.
552    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
553    if not grads_and_vars:
554      raise ValueError("No variables provided.")
555    converted_grads_and_vars = []
556    for g, v in grads_and_vars:
557      if g is not None:
558        try:
559          # Convert the grad to Tensor or IndexedSlices if necessary.
560          g = ops.convert_to_tensor_or_indexed_slices(g)
561        except TypeError:
562          raise TypeError(
563              "Gradient must be convertible to a Tensor"
564              " or IndexedSlices, or None: %s" % g)
565        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
566          raise TypeError(
567              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
568      p = _get_processor(v)
569      converted_grads_and_vars.append((g, v, p))
570
571    converted_grads_and_vars = tuple(converted_grads_and_vars)
572    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
573    if not var_list:
574      raise ValueError("No gradients provided for any variable: %s." %
575                       ([str(v) for _, v, _ in converted_grads_and_vars],))
576    with ops.init_scope():
577      self._create_slots(var_list)
578    update_ops = []
579    with ops.name_scope(name, self._name) as name:
580      self._prepare()
581      for grad, var, processor in converted_grads_and_vars:
582        if grad is None:
583          continue
584        # We colocate all ops created in _apply_dense or _apply_sparse
585        # on the same device as the variable.
586        # TODO(apassos): figure out how to get the variable name here.
587        if context.executing_eagerly() or isinstance(
588            var,
589            resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
590          scope_name = ""
591        else:
592          scope_name = var.op.name
593        with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
594          update_ops.append(processor.update_op(self, grad))
595      if global_step is None:
596        apply_updates = self._finish(update_ops, name)
597      else:
598        with ops.control_dependencies([self._finish(update_ops, "update")]):
599          with ops.colocate_with(global_step):
600            if isinstance(global_step, resource_variable_ops.ResourceVariable):
601              # TODO(apassos): the implicit read in assign_add is slow; consider
602              # making it less so.
603              apply_updates = resource_variable_ops.assign_add_variable_op(
604                  global_step.handle,
605                  ops.convert_to_tensor(1, dtype=global_step.dtype),
606                  name=name)
607            else:
608              apply_updates = state_ops.assign_add(global_step, 1, name=name)
609
610      if not context.executing_eagerly():
611        if isinstance(apply_updates, ops.Tensor):
612          apply_updates = apply_updates.op
613        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
614        if apply_updates not in train_op:
615          train_op.append(apply_updates)
616
617      return apply_updates
618
  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """A version of `apply_gradients` for cross-replica context.

    This is a version of `apply_gradients()` for when you are using a
    `DistributionStrategy` and are in a cross-replica context. If in a
    replica context, use `apply_gradients()` as normal.

    The steps are:
      1. Sum the gradients across replicas.
      2. Create slots for every variable.
      3. Update each variable through `distribution.extended.update`.
      4. Run `_finish` on the non-slot devices.
      5. Increment `global_step`, if one was given.

    Args:
      distribution: A `DistributionStrategy` object.
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`, and then aggregated across replicas.
      global_step: Optional (mirrored) `Variable` to increment by one
        after the variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients across all
      replicas. If `global_step` was not None, that operation also
      increments `global_step`.
    """
    # Sum each gradient across replicas, then re-pair the reduced gradients
    # with the variables in their original order.
    reduced_grads = distribution.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    # NOTE: under Python 3 this is a one-shot iterator; it is consumed exactly
    # once, by the `update_ops` comprehension below.
    grads_and_vars = zip(reduced_grads, var_list)

    # Note that this is called in a cross-replica context.
    with ops.init_scope():
      self._create_slots(var_list)

    def update(v, g):
      """Apply gradients to a replica variable."""
      assert v is not None

      try:
        # Convert the grad to Tensor or IndexedSlices if necessary.
        g = ops.convert_to_tensor_or_indexed_slices(g)
      except TypeError:
        # NOTE(review): the message mentions "or None", but `None` is not
        # special-cased here (apply_gradients() filters `None` gradients out
        # before merge_call). Also, `% g` itself fails if `g` is a tuple;
        # `% (g,)` would be safer.
        raise TypeError("Gradient must be convertible to a Tensor"
                        " or IndexedSlices, or None: %s" % g)
      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)

      # Eagerly created variables have no `.op`; derive the scope from the
      # variable name without its ":<output index>" suffix instead.
      if context.executing_eagerly() or (
          resource_variable_ops.is_resource_variable(v) and
          not v._in_graph_mode):  # pylint: disable=protected-access
        scope_name = v.name.split(":")[0]
      else:
        scope_name = v.op.name

      # NOTE(review): an earlier comment here said a device_policy is set
      # because non-mirrored tensors (e.g. `lr_t`, `beta1_t`, `beta2_t` read
      # by `_resource_apply_dense`) are read in `update_op`. No device policy
      # is set below, only a name scope -- confirm whether dropping it was
      # intentional.
      with ops.name_scope("update_" + scope_name):
        return p.update_op(self, g)

    with ops.name_scope(name, self._name) as name:
      self._prepare()

      # One list of per-device update ops across all variables (group=False
      # keeps the individual ops so `_finish` can depend on all of them).
      update_ops = [
          op
          for grad, var in grads_and_vars
          for op in distribution.extended.update(
              var, update, args=(grad,), group=False)
      ]

      # `self` is deliberately shadowed: update_non_slot passes the optimizer
      # explicitly through `args`.
      def finish(self, update_ops):
        return self._finish(update_ops, "update")

      non_slot_devices = distribution.extended.non_slot_devices(var_list)
      finish_updates = distribution.extended.update_non_slot(
          non_slot_devices, finish, args=(self, update_ops), group=False)
      if global_step is None:
        apply_updates = distribution.group(finish_updates, name=name)
      else:
        # Increment global_step only after every `_finish` op has run.
        with ops.control_dependencies(finish_updates):
          apply_updates = distribution.extended.update(
              global_step, state_ops.assign_add, args=(1,),
              kwargs={"name": name})

      if not context.executing_eagerly():
        # Register the resulting op in the TRAIN_OP collection (once).
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates
713
714  def get_slot(self, var, name):
715    """Return a slot named `name` created for `var` by the Optimizer.
716
717    Some `Optimizer` subclasses use additional variables.  For example
718    `Momentum` and `Adagrad` use variables to accumulate updates.  This method
719    gives access to these `Variable` objects if for some reason you need them.
720
721    Use `get_slot_names()` to get the list of slot names created by the
722    `Optimizer`.
723
724    Args:
725      var: A variable passed to `minimize()` or `apply_gradients()`.
726      name: A string.
727
728    Returns:
729      The `Variable` for the slot if it was created, `None` otherwise.
730    """
731    # pylint: disable=protected-access
732    named_slots = self._slots.get(name, None)
733    if not named_slots:
734      return None
735
736    if hasattr(var, "_distributed_container"):
737      # NOTE: If this isn't patched, then there is no `handle` in
738      # `_resource_apply_dense`.
739      distributed_container = var._distributed_container()
740      assert distributed_container is not None
741      if ops.executing_eagerly_outside_functions():
742        key = distributed_container._unique_id
743      else:
744        key = (distributed_container.graph, distributed_container._shared_name)
745      # pylint: enable=protected-access
746      mirrored_slot = named_slots.get(key, None)
747      if mirrored_slot is None: return None
748      return mirrored_slot.get(device=var.device)
749
750    return named_slots.get(_var_key(var), None)
751
752  def get_slot_names(self):
753    """Return a list of the names of slots created by the `Optimizer`.
754
755    See `get_slot()`.
756
757    Returns:
758      A list of strings.
759    """
760    return sorted(self._slots.keys())
761
762  def variables(self):
763    """A list of variables which encode the current state of `Optimizer`.
764
765    Includes slot variables and additional global variables created by the
766    optimizer in the current default graph.
767
768    Returns:
769      A list of variables.
770    """
771    current_graph = ops.get_default_graph()
772
773    def _from_current_graph(variable):
774      if variable._in_graph_mode:  # pylint: disable=protected-access
775        return variable.op.graph is current_graph
776      else:
777        # No variable.op in eager mode. We don't expect lots of eager graphs,
778        # but behavior should be consistent with graph mode.
779        return variable._graph_key == current_graph._graph_key  # pylint: disable=protected-access
780
781    optimizer_variables = [v for v in self._non_slot_variables()
782                           if _from_current_graph(v)]
783    for _, variable_dict in self._slots.items():
784      for _, slot_for_variable in variable_dict.items():
785        if _from_current_graph(slot_for_variable):
786          optimizer_variables.append(slot_for_variable)
787    # Sort variables by name so that the return is deterministic.
788    return sorted(optimizer_variables, key=lambda v: v.name)
789
790  def _create_non_slot_variable(self, initial_value, name, colocate_with):
791    """Add an extra variable, not associated with a slot."""
792    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
793    eager = context.executing_eagerly()
794    graph = None if eager else colocate_with.graph
795
796    key = (name, graph)
797    v = self._non_slot_dict.get(key, None)
798    if v is None:
799      self._maybe_initialize_trackable()
800      distribution_strategy = distribute_ctx.get_strategy()
801      with distribution_strategy.extended.colocate_vars_with(colocate_with):
802        if eager:
803          restored_initial_value = self._preload_simple_restoration(
804              name=name, shape=None)
805          if restored_initial_value is not None:
806            initial_value = restored_initial_value
807        v = variable_scope.variable(
808            initial_value, name=name, trainable=False,
809            use_resource=resource_variable_ops.is_resource_variable(
810                colocate_with))
811      # Restore this variable by name if necessary, but don't add a
812      # Trackable dependency. Optimizers return the current graph's
813      # non-slot variables from _checkpoint_dependencies explicitly rather
814      # than unconditionally adding dependencies (since there may be multiple
815      # non-slot variables with the same name in different graphs, trying to
816      # save all of them would result in errors).
817      self._handle_deferred_dependencies(name=name, trackable=v)
818      self._non_slot_dict[key] = v
819
820    return v
821
822  @property
823  def _checkpoint_dependencies(self):
824    """From Trackable. Gather graph-specific non-slot variables to save."""
825    current_graph_non_slot_variables = []
826    current_graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
827    for (name, _), variable_object in sorted(self._non_slot_dict.items(),
828                                             # Avoid comparing graphs
829                                             key=lambda item: item[0][0]):
830      if variable_object._graph_key == current_graph_key:  # pylint: disable=protected-access
831        current_graph_non_slot_variables.append(
832            trackable.TrackableReference(
833                name=name, ref=variable_object))
834    return (super(Optimizer, self)._checkpoint_dependencies
835            + current_graph_non_slot_variables)
836
837  def _lookup_dependency(self, name):
838    """From Trackable. Find a non-slot variable in the current graph."""
839    unconditional = super(Optimizer, self)._lookup_dependency(name)
840    if unconditional is not None:
841      return unconditional
842    graph = None if context.executing_eagerly() else ops.get_default_graph()
843    return self._get_non_slot_variable(name, graph=graph)
844
845  def _get_non_slot_variable(self, name, graph=None):
846    non_slot = self._non_slot_dict.get((name, graph), None)
847    if hasattr(non_slot, "_distributed_container"):
848      # This is a mirrored non-slot.  In order to enable code like `_finish`
849      # to assign to a non-slot, return the current context replica.
850      return non_slot.get()
851    else:
852      return non_slot
853
854  def _non_slot_variables(self):
855    """Additional variables created by the `Optimizer`.
856
857    Returns:
858      A list or tuple of variables.
859    """
860    return self._non_slot_dict.values()
861
862  def _assert_valid_dtypes(self, tensors):
863    """Asserts tensors are all valid types (see `_valid_dtypes`).
864
865    Args:
866      tensors: Tensors to check.
867
868    Raises:
869      ValueError: If any tensor is not a valid type.
870    """
871    valid_dtypes = self._valid_dtypes()
872    for t in tensors:
873      dtype = t.dtype.base_dtype
874      if dtype not in valid_dtypes:
875        raise ValueError(
876            "Invalid type %r for %s, expected: %s." % (
877                dtype, t.name, [v for v in valid_dtypes]))
878
879  # --------------
880  # Methods to be implemented by subclasses if they want to use the
881  # inherited implementation of apply_gradients() or compute_gradients().
882  # --------------
883  def _valid_dtypes(self):
884    """Valid types for loss, variables and gradients.
885
886    Subclasses should override to allow other float types.
887
888    Returns:
889      Valid types for loss, variables and gradients.
890    """
891    return set(
892        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])
893
894  def _create_slots(self, var_list):
895    """Create all slots needed by the variables.
896
897    Args:
898      var_list: A list of `Variable` objects.
899    """
900    # No slots needed by default
901    pass
902
903  def _prepare(self):
904    """Create all needed tensors before applying gradients.
905
906    This is called with the name_scope using the "name" that
907    users have chosen for the application of gradients.
908    """
909    pass
910
911  def _apply_dense(self, grad, var):
912    """Add ops to apply dense gradients to `var`.
913
914    Args:
915      grad: A `Tensor`.
916      var: A `Variable` object.
917
918    Returns:
919      An `Operation`.
920    """
921    raise NotImplementedError()
922
923  def _resource_apply_dense(self, grad, handle):
924    """Add ops to apply dense gradients to the variable `handle`.
925
926    Args:
927      grad: a `Tensor` representing the gradient.
928      handle: a `Tensor` of dtype `resource` which points to the variable
929       to be updated.
930
931    Returns:
932      An `Operation` which updates the value of the variable.
933    """
934    raise NotImplementedError()
935
936  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
937    """Add ops to apply sparse gradients to `handle`, with repeated indices.
938
939    Optimizers which override this method must deal with repeated indices. See
940    the docstring of `_apply_sparse_duplicate_indices` for details. By default
941    the correct behavior, to sum non-unique indices and their associated
942    gradients, is enforced by first pre-processing `grad` and `indices` and
943    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
944    with duplicate indices may instead override this method to avoid the
945    overhead of summing.
946
947    Args:
948      grad: a `Tensor` representing the gradient for the affected indices.
949      handle: a `Tensor` of dtype `resource` which points to the variable
950       to be updated.
951      indices: a `Tensor` of integral type representing the indices for
952       which the gradient is nonzero. Indices may be repeated.
953
954    Returns:
955      An `Operation` which updates the value of the variable.
956    """
957    summed_grad, unique_indices = _deduplicate_indexed_slices(
958        values=grad, indices=indices)
959    return self._resource_apply_sparse(summed_grad, handle, unique_indices)
960
961  def _resource_apply_sparse(self, grad, handle, indices):
962    """Add ops to apply sparse gradients to the variable `handle`.
963
964    Similar to `_apply_sparse`, the `indices` argument to this method has been
965    de-duplicated. Optimizers which deal correctly with non-unique indices may
966    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
967    overhead.
968
969    Args:
970      grad: a `Tensor` representing the gradient for the affected indices.
971      handle: a `Tensor` of dtype `resource` which points to the variable
972       to be updated.
973      indices: a `Tensor` of integral type representing the indices for
974       which the gradient is nonzero. Indices are unique.
975
976    Returns:
977      An `Operation` which updates the value of the variable.
978    """
979    raise NotImplementedError()
980
981  def _apply_sparse_duplicate_indices(self, grad, var):
982    """Add ops to apply sparse gradients to `var`, with repeated sparse indices.
983
984    Optimizers which override this method must deal with IndexedSlices objects
985    such as the following:
986
987      IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])
988
989    The correct interpretation is:
990
991      IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])
992
993    Many optimizers deal incorrectly with repeated indices when updating based
994    on sparse gradients (e.g. summing squares rather than squaring the sum, or
995    applying momentum terms multiple times). Adding first is always the correct
996    behavior, so this is enforced here by reconstructing the IndexedSlices to
997    have only unique indices, then calling _apply_sparse.
998
999    Optimizers which deal correctly with repeated indices may instead override
1000    this method to avoid the overhead of summing indices.
1001
1002    Args:
1003      grad: `IndexedSlices`.
1004      var: A `Variable` object.
1005
1006    Returns:
1007      An `Operation`.
1008    """
1009    summed_values, unique_indices = _deduplicate_indexed_slices(
1010        values=grad.values, indices=grad.indices)
1011    gradient_no_duplicate_indices = ops.IndexedSlices(
1012        indices=unique_indices,
1013        values=summed_values,
1014        dense_shape=grad.dense_shape)
1015    return self._apply_sparse(gradient_no_duplicate_indices, var)
1016
1017  def _apply_sparse(self, grad, var):
1018    """Add ops to apply sparse gradients to `var`.
1019
1020    The IndexedSlices object passed to `grad` in this function is by default
1021    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
1022    indices (see its docstring for details). Optimizers which can tolerate or
1023    have correct special cases for duplicate sparse indices may override
1024    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
1025    overhead.
1026
1027    Args:
1028      grad: `IndexedSlices`, with no repeated indices.
1029      var: A `Variable` object.
1030
1031    Returns:
1032      An `Operation`.
1033    """
1034    raise NotImplementedError()
1035
1036  def _finish(self, update_ops, name_scope):
1037    """Do what is needed to finish the update.
1038
1039    This is called with the `name_scope` using the "name" that
1040    users have chosen for the application of gradients.
1041
1042    Args:
1043      update_ops: List of `Operation` objects to update variables.  This list
1044        contains the values returned by the `_apply_dense()` and
1045        `_apply_sparse()` calls.
1046      name_scope: String.  Name to use for the returned operation.
1047
1048    Returns:
1049      The operation to apply updates.
1050    """
1051    return control_flow_ops.group(*update_ops, name=name_scope)
1052
1053  # --------------
1054  # Utility methods for subclasses.
1055  # --------------
1056
1057  def _slot_dict(self, slot_name):
1058    """Returns a dict for caching slots created under the given name.
1059
1060    Args:
1061      slot_name: Name for the slot.
1062
1063    Returns:
1064      A dict that maps primary `Variable` objects to the slot created
1065      for that variable, under the given slot name.
1066    """
1067    named_slots = self._slots.get(slot_name, None)
1068    if named_slots is None:
1069      named_slots = {}
1070      self._slots[slot_name] = named_slots
1071    return named_slots
1072
1073  def _get_or_make_slot(self, var, val, slot_name, op_name):
1074    """Find or create a slot for a variable.
1075
1076    Args:
1077      var: A `Variable` object.
1078      val: A `Tensor`.  The initial value of the slot.
1079      slot_name: Name for the slot.
1080      op_name: Name to use when scoping the Variable that
1081        needs to be created for the slot.
1082
1083    Returns:
1084      A `Variable` object.
1085    """
1086    named_slots = self._slot_dict(slot_name)
1087    if _var_key(var) not in named_slots:
1088      new_slot_variable = slot_creator.create_slot(var, val, op_name)
1089      self._restore_slot_variable(
1090          slot_name=slot_name, variable=var,
1091          slot_variable=new_slot_variable)
1092      named_slots[_var_key(var)] = new_slot_variable
1093    return named_slots[_var_key(var)]
1094
1095  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
1096                                         slot_name, op_name):
1097    """Find or create a slot for a variable, using an Initializer.
1098
1099    Args:
1100      var: A `Variable` object.
1101      initializer: An `Initializer`.  The initial value of the slot.
1102      shape: Shape of the initial value of the slot.
1103      dtype: Type of the value of the slot.
1104      slot_name: Name for the slot.
1105      op_name: Name to use when scoping the Variable that
1106        needs to be created for the slot.
1107
1108    Returns:
1109      A `Variable` object.
1110    """
1111    named_slots = self._slot_dict(slot_name)
1112    if _var_key(var) not in named_slots:
1113      new_slot_variable = slot_creator.create_slot_with_initializer(
1114          var, initializer, shape, dtype, op_name)
1115      self._restore_slot_variable(
1116          slot_name=slot_name, variable=var,
1117          slot_variable=new_slot_variable)
1118      named_slots[_var_key(var)] = new_slot_variable
1119    return named_slots[_var_key(var)]
1120
1121  def _zeros_slot(self, var, slot_name, op_name):
1122    """Find or create a slot initialized with 0.0.
1123
1124    Args:
1125      var: A `Variable` object.
1126      slot_name: Name for the slot.
1127      op_name: Name to use when scoping the Variable that
1128        needs to be created for the slot.
1129
1130    Returns:
1131      A `Variable` object.
1132    """
1133    named_slots = self._slot_dict(slot_name)
1134    if _var_key(var) not in named_slots:
1135      new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
1136      self._restore_slot_variable(
1137          slot_name=slot_name, variable=var,
1138          slot_variable=new_slot_variable)
1139      named_slots[_var_key(var)] = new_slot_variable
1140    return named_slots[_var_key(var)]
1141
1142  # --------------
1143  # For implementing the Trackable interface.
1144  # --------------
1145
1146  def _restore_slot_variable(self, slot_name, variable, slot_variable):
1147    """Restore a newly created slot variable's value."""
1148    variable_key = _var_key(variable)
1149    deferred_restorations = self._deferred_slot_restorations.get(
1150        slot_name, {}).pop(variable_key, [])
1151    # Iterate over restores, highest restore UID first to minimize the number
1152    # of assignments.
1153    deferred_restorations.sort(key=lambda position: position.restore_uid,
1154                               reverse=True)
1155    for checkpoint_position in deferred_restorations:
1156      checkpoint_position.restore(slot_variable)
1157
1158  def _create_or_restore_slot_variable(
1159      self, slot_variable_position, slot_name, variable):
1160    """Restore a slot variable's value, possibly creating it.
1161
1162    Called when a variable which has an associated slot variable is created or
1163    restored. When executing eagerly, we create the slot variable with a
1164    restoring initializer.
1165
1166    No new variables are created when graph building. Instead,
1167    _restore_slot_variable catches these after normal creation and adds restore
1168    ops to the graph. This method is nonetheless important when graph building
1169    for the case when a slot variable has already been created but `variable`
1170    has just been added to a dependency graph (causing us to realize that the
1171    slot variable needs to be restored).
1172
1173    Args:
1174      slot_variable_position: A `trackable._CheckpointPosition` object
1175        indicating the slot variable `Trackable` object to be restored.
1176      slot_name: The name of this `Optimizer`'s slot to restore into.
1177      variable: The variable object this slot is being created for.
1178    """
1179    named_slots = self._slot_dict(slot_name)
1180    variable_key = _var_key(variable)
1181    slot_variable = named_slots.get(variable_key, None)
1182    if (slot_variable is None and context.executing_eagerly() and
1183        slot_variable_position.is_simple_variable()
1184        # Defer slot variable creation if there is an active variable creator
1185        # scope. Generally we'd like to eagerly create/restore slot variables
1186        # when possible, but this may mean that scopes intended to catch
1187        # `variable` also catch its eagerly created slot variable
1188        # unintentionally (specifically make_template would add a dependency on
1189        # a slot variable if not for this case). Deferring is mostly harmless
1190        # (aside from double initialization), and makes variable creator scopes
1191        # behave the same way they do when graph building.
1192        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
1193      initializer = trackable.CheckpointInitialValue(
1194          checkpoint_position=slot_variable_position)
1195      slot_variable = self._get_or_make_slot(
1196          var=variable,
1197          val=initializer,
1198          slot_name=slot_name,
1199          op_name=self._name)
1200      # Slot variables are not owned by any one object (because we don't want to
1201      # save the slot variable if the optimizer is saved without the non-slot
1202      # variable, or if the non-slot variable is saved without the optimizer;
1203      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
1204      # variable, variable)). So we don't _track_ slot variables anywhere, and
1205      # instead special-case this dependency and otherwise pretend it's a normal
1206      # graph.
1207    if slot_variable is not None:
1208      # If we've either made this slot variable, or if we've pulled out an
1209      # existing slot variable, we should restore it.
1210      slot_variable_position.restore(slot_variable)
1211    else:
1212      # We didn't make the slot variable. Defer restoring until it gets created
1213      # normally. We keep a list rather than the one with the highest restore
1214      # UID in case slot variables have their own dependencies, in which case
1215      # those could differ between restores.
1216      self._deferred_slot_restorations.setdefault(
1217          slot_name, {}).setdefault(variable_key, []).append(
1218              slot_variable_position)
1219
1220  def _call_if_callable(self, param):
1221    """Call the function if param is callable."""
1222    return param() if callable(param) else param
1223