1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Version 2 of class Optimizer."""
16# pylint: disable=g-bad-name
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import abc
23import contextlib
24import functools
25
26import six
27
28from tensorflow.python.distribute import central_storage_strategy
29from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
30from tensorflow.python.distribute import parameter_server_strategy
31from tensorflow.python.distribute import parameter_server_strategy_v2
32from tensorflow.python.distribute import values as ds_values
33from tensorflow.python.eager import backprop
34from tensorflow.python.eager import context
35from tensorflow.python.eager import monitoring
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_util
39from tensorflow.python.keras import backend
40from tensorflow.python.keras import initializers
41from tensorflow.python.keras.engine import base_layer_utils
42from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
43from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
44from tensorflow.python.keras.utils import generic_utils
45from tensorflow.python.keras.utils import layer_utils
46from tensorflow.python.keras.utils import tf_inspect
47from tensorflow.python.keras.utils import tf_utils
48from tensorflow.python.ops import array_ops
49from tensorflow.python.ops import control_flow_ops
50from tensorflow.python.ops import gen_resource_variable_ops
51from tensorflow.python.ops import gradients
52from tensorflow.python.ops import math_ops
53from tensorflow.python.ops import variables as tf_variables
54from tensorflow.python.saved_model import revived_types
55from tensorflow.python.training.tracking import base as trackable
56from tensorflow.python.util import nest
57from tensorflow.python.util.tf_export import keras_export
58
59
# Usage gauge; `OptimizerV2.__init__` sets the cell keyed by the concrete
# optimizer class name to True whenever an optimizer is constructed.
keras_optimizers_gauge = monitoring.BoolGauge(
    "/tensorflow/api/keras/optimizers",
    "keras optimizer usage",
    "method",
)

# Default set of variable dtypes treated as valid (real floating point and
# complex types). NOTE(review): presumably consulted by the dtype-validation
# helpers defined outside this view.
_DEFAULT_VALID_DTYPES = frozenset((
    dtypes.float16,
    dtypes.bfloat16,
    dtypes.float32,
    dtypes.float64,
    dtypes.complex64,
    dtypes.complex128,
))
67
68
def _deduplicate_indexed_slices(values, indices):
  """Collapses repeated indices of a sparse update by summing their slices.

  Args:
    values: A `Tensor` with rank >= 1.
    indices: A one-dimensional integer `Tensor`, indexing into the first
      dimension of `values` (as in an IndexedSlices object).

  Returns:
    A pair `(summed_values, unique_indices)`: `unique_indices` holds each
    distinct entry of `indices` once, and row `i` of `summed_values` is the
    sum of every `values` slice whose index equals `unique_indices[i]`.
  """
  uniq, positions = array_ops.unique(indices)
  # One output segment per distinct index.
  num_segments = array_ops.shape(uniq)[0]
  summed = math_ops.unsorted_segment_sum(values, positions, num_segments)
  return summed, uniq
87
88
class NullContextmanager(object):
  """A context manager that does nothing.

  Any constructor arguments are accepted and ignored, entering the `with`
  block yields `None`, and exceptions raised inside the block are never
  suppressed.
  """

  def __init__(self, *args, **kwargs):
    del args, kwargs  # Unused; accepted so it can stand in for real scopes.

  def __enter__(self):
    return None

  def __exit__(self, type_arg, value_arg, traceback_arg):
    # A falsy return value lets any in-flight exception propagate.
    suppress = False
    return suppress
99
100
def name_scope_only_in_function_or_graph(name):
  """Internal-only entry point for `name_scope*`.

  Returns a no-op context manager when executing fully eagerly; otherwise
  (inside a function or graph) returns a `compat.v1`-style name scope.

  Args:
    name: The name argument that is passed to the op function.

  Returns:
    `name_scope*` context manager, or a `NullContextmanager` in eager mode.
  """
  if context.executing_eagerly():
    return NullContextmanager()
  return ops.name_scope_v1(name)
117
118
119@six.add_metaclass(abc.ABCMeta)
120@keras_export("keras.optimizers.Optimizer")
121class OptimizerV2(trackable.Trackable):
122  """Base class for Keras optimizers.
123
124  You should not use this class directly, but instead instantiate one of its
125  subclasses such as `tf.keras.optimizers.SGD`, `tf.keras.optimizers.Adam`, etc.
126
127  ### Usage
128
129  ```python
130  # Create an optimizer with the desired parameters.
131  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
132  # `loss` is a callable that takes no argument and returns the value
133  # to minimize.
134  loss = lambda: 3 * var1 * var1 + 2 * var2 * var2
135  # In graph mode, returns op that minimizes the loss by updating the listed
136  # variables.
137  opt_op = opt.minimize(loss, var_list=[var1, var2])
138  opt_op.run()
139  # In eager mode, simply call minimize to update the list of variables.
140  opt.minimize(loss, var_list=[var1, var2])
141  ```
142
143  ### Usage in custom training loops
144
  In Keras models, sometimes variables are created when the model is first
  called, instead of at construction time. Examples include 1) sequential
  models without a pre-defined input shape, or 2) subclassed models. Pass
  `var_list` as a callable in these cases.
149
150  Example:
151
152  ```python
153  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
154  model = tf.keras.Sequential()
155  model.add(tf.keras.layers.Dense(num_hidden, activation='relu'))
156  model.add(tf.keras.layers.Dense(num_classes, activation='sigmoid'))
157  loss_fn = lambda: tf.keras.losses.mse(model(input), output)
158  var_list_fn = lambda: model.trainable_weights
159  for input, output in data:
160    opt.minimize(loss_fn, var_list_fn)
161  ```
162
163  ### Processing gradients before applying them
164
165  Calling `minimize()` takes care of both computing the gradients and
166  applying them to the variables.  If you want to process the gradients
167  before applying them you can instead use the optimizer in three steps:
168
169  1.  Compute the gradients with `tf.GradientTape`.
170  2.  Process the gradients as you wish.
171  3.  Apply the processed gradients with `apply_gradients()`.
172
173  Example:
174
175  ```python
176  # Create an optimizer.
177  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
178
179  # Compute the gradients for a list of variables.
180  with tf.GradientTape() as tape:
181    loss = <call_loss_function>
182  vars = <list_of_variables>
183  grads = tape.gradient(loss, vars)
184
185  # Process the gradients, for example cap them, etc.
186  # capped_grads = [MyCapper(g) for g in grads]
187  processed_grads = [process_gradient(g) for g in grads]
188
189  # Ask the optimizer to apply the processed gradients.
  opt.apply_gradients(zip(processed_grads, vars))
191  ```
192
193  ### Use with `tf.distribute.Strategy`
194
195  This optimizer class is `tf.distribute.Strategy` aware, which means it
196  automatically sums gradients across all replicas. To average gradients,
197  you divide your loss by the global batch size, which is done
198  automatically if you use `tf.keras` built-in training or evaluation loops.
199  See the `reduction` argument of your loss which should be set to
200  `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
201  `tf.keras.losses.Reduction.SUM` for not.
202
203  To aggregate gradients yourself, call `apply_gradients` with
204  `experimental_aggregate_gradients` set to False. This is useful if you need to
205  process aggregated gradients.
206
207  If you are not using these and you want to average gradients, you should use
208  `tf.math.reduce_sum` to add up your per-example losses and then divide by the
209  global batch size. Note that when using `tf.distribute.Strategy`, the first
210  component of a tensor's shape is the *replica-local* batch size, which is off
211  by a factor equal to the number of replicas being used to compute a single
212  step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
213  resulting in gradients that can be many times too big.
214
215  ### Variable Constraints
216
  All Keras optimizers respect variable constraints. If a constraint function
  is passed to any variable, the constraint will be applied to the variable
  after the gradient has been applied to the variable.
  Important: If the gradient is a sparse tensor (`tf.IndexedSlices`), variable
  constraints are not supported.
221
222  ### Thread Compatibility
223
224  The entire optimizer is currently thread compatible, not thread-safe. The user
225  needs to perform synchronization if necessary.
226
227  ### Slots
228
229  Many optimizer subclasses, such as `Adam` and `Adagrad` allocate and manage
230  additional variables associated with the variables to train.  These are called
231  <i>Slots</i>.  Slots have names and you can ask the optimizer for the names of
232  the slots that it uses.  Once you have a slot name you can ask the optimizer
233  for the variable it created to hold the slot value.
234
  This can be useful if you want to log or debug a training algorithm, report
  stats about the slots, etc.
237
238  ### Hyperparameters
239
240  These are arguments passed to the optimizer subclass constructor
241  (the `__init__` method), and then passed to `self._set_hyper()`.
242  They can be either regular Python values (like 1.0), tensors, or
243  callables. If they are callable, the callable will be called during
244  `apply_gradients()` to get the value for the hyper parameter.
245
246  Hyperparameters can be overwritten through user code:
247
248  Example:
249
250  ```python
251  # Create an optimizer with the desired parameters.
252  opt = tf.keras.optimizers.SGD(learning_rate=0.1)
253  # `loss` is a callable that takes no argument and returns the value
254  # to minimize.
255  loss = lambda: 3 * var1 + 2 * var2
256  # In eager mode, simply call minimize to update the list of variables.
257  opt.minimize(loss, var_list=[var1, var2])
258  # update learning rate
259  opt.learning_rate = 0.05
260  opt.minimize(loss, var_list=[var1, var2])
261  ```
262
263  ### Callable learning rate
264
265  Optimizer accepts a callable learning rate in two ways. The first way is
266  through built-in or customized
267  `tf.keras.optimizers.schedules.LearningRateSchedule`. The schedule will be
268  called on each iteration with `schedule(iteration)`, a `tf.Variable`
269  owned by the optimizer.
270
271  Example:
272
273  >>> var = tf.Variable(np.random.random(size=(1,)))
274  >>> learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
275  ... initial_learning_rate=.01, decay_steps=20, decay_rate=.1)
276  >>> opt = tf.keras.optimizers.SGD(learning_rate=learning_rate)
277  >>> loss = lambda: 3 * var
278  >>> opt.minimize(loss, var_list=[var])
279  <tf.Variable...
280
281  The second way is through a callable function that
282  does not accept any arguments.
283
284  Example:
285
286  >>> var = tf.Variable(np.random.random(size=(1,)))
287  >>> def lr_callable():
288  ...   return .1
289  >>> opt = tf.keras.optimizers.SGD(learning_rate=lr_callable)
290  >>> loss = lambda: 3 * var
291  >>> opt.minimize(loss, var_list=[var])
292  <tf.Variable...
293
294  ### Creating a custom optimizer
295
296  If you intend to create your own optimization algorithm, simply inherit from
297  this class and override the following methods:
298
299    - `_resource_apply_dense` (update variable given gradient tensor is a dense
300      `tf.Tensor`)
301    - `_resource_apply_sparse` (update variable given gradient tensor is a
302      sparse `tf.IndexedSlices`. The most common way for this to happen
303      is if you are taking the gradient through a `tf.gather`.)
304    - `_create_slots`
305      (if your optimizer algorithm requires additional variables)
306    - `get_config`
307      (serialization of the optimizer, include all hyper parameters)
308  """
309
  # Subclasses should set this to True unless they override `apply_gradients`
  # with a version that does not have the `experimental_aggregate_gradients`
  # argument.  Older versions of Keras did not have this argument so custom
  # optimizers may have overridden `apply_gradients` without the
  # `experimental_aggregate_gradients` argument. Keras only passes
  # `experimental_aggregate_gradients` if this attribute is True.
  # The base class defaults to False, so a subclass must opt in explicitly
  # before Keras will pass the argument to it.
  # Note: This attribute will likely be removed in an upcoming release.
  _HAS_AGGREGATE_GRAD = False
318
319  def __init__(self,
320               name,
321               gradient_aggregator=None,
322               gradient_transformers=None,
323               **kwargs):
324    """Create a new Optimizer.
325
326    This must be called by the constructors of subclasses.
327    Note that Optimizer instances should not bind to a single graph,
328    and so shouldn't keep Tensors as member variables. Generally
329    you should be able to use the _set_hyper()/state.get_hyper()
330    facility instead.
331
332    This class is stateful and thread-compatible.
333
334    Example of custom gradient transformations:
335
336    ```python
337    def my_gradient_transformer(grads_and_vars):
338      # Simple example, double the gradients.
339      return [(2. * g, v) for g, v in grads_and_vars]
340
341    optimizer = tf.keras.optimizers.SGD(
342        1e-3, gradient_transformers=[my_gradient_transformer])
343    ```
344
345    Args:
346      name: String. The name to use for momentum accumulator weights created
347        by the optimizer.
348      gradient_aggregator: The function to use to aggregate gradients across
349        devices (when using `tf.distribute.Strategy`). If `None`, defaults to
350        summing the gradients across devices. The function should accept and
351        return a list of `(gradient, variable)` tuples.
352      gradient_transformers: Optional. List of functions to use to transform
353        gradients before applying updates to Variables. The functions are
354        applied after `gradient_aggregator`. The functions should accept and
355        return a list of `(gradient, variable)` tuples.
356      **kwargs: keyword arguments. Allowed arguments are `clipvalue`,
357        `clipnorm`, `global_clipnorm`.
358        If `clipvalue` (float) is set, the gradient of each weight
359        is clipped to be no higher than this value.
360        If `clipnorm` (float) is set, the gradient of each weight
361        is individually clipped so that its norm is no higher than this value.
362        If `global_clipnorm` (float) is set the gradient of all weights is
363        clipped so that their global norm is no higher than this value.
364
365    Raises:
366      ValueError: in case of any invalid argument.
367    """
368    # Instrument optimizer usages
369    keras_optimizers_gauge.get_cell(self.__class__.__name__).set(True)
370
371    allowed_kwargs = {"clipnorm", "clipvalue", "lr", "decay", "global_clipnorm"}
372    for k in kwargs:
373      if k not in allowed_kwargs:
374        raise TypeError("Unexpected keyword argument "
375                        "passed to optimizer: " + str(k))
376      # checks that all keyword arguments are non-negative.
377      if kwargs[k] is not None and kwargs[k] < 0:
378        raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))
379
380    self._use_locking = True
381    self._init_set_name(name)
382    self._hyper = {}
383    # dict: {variable name : {slot name : variable}}
384    self._slots = {}
385    self._slot_names = []
386    self._weights = []
387    self._iterations = None
388
389    # For implementing Trackable. Stores information about how to restore
390    # slot variables which have not yet been created
391    # (trackable._CheckpointPosition objects).
392    #  {slot_name :
393    #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
394    #   ... }
395    self._deferred_slot_restorations = {}
396
397    decay = kwargs.pop("decay", 0.0)
398    if decay < 0.:
399      raise ValueError("decay cannot be less than 0: {}".format(decay))
400    self._initial_decay = decay
401
402    self._hypers_created = False
403    # Store the distribution strategy object if the optimizer is created inside
404    # strategy scope, so it could be used to create variables later.
405    if distribute_ctx.has_strategy():
406      self._distribution_strategy = distribute_ctx.get_strategy()
407    else:
408      self._distribution_strategy = None
409
410    # Configure gradient transformations.
411    if gradient_aggregator is None:
412      gradient_aggregator = optimizer_utils.all_reduce_sum_gradients
413    self.gradient_aggregator = gradient_aggregator
414    if gradient_transformers is None:
415      gradient_transformers = []
416    self.gradient_transformers = gradient_transformers
417    self.clipnorm = kwargs.pop("clipnorm", None)
418    self.global_clipnorm = kwargs.pop("global_clipnorm", None)
419    if self.clipnorm is not None and self.global_clipnorm is not None:
420      raise ValueError("Cannot accept both `clipnorm` and `global_clipnorm`, "
421                       "passed `clipnorm` {}, `global_clipnorm` {}".format(
422                           self.clipnorm, self.global_clipnorm))
423    self.clipvalue = kwargs.pop("clipvalue", None)
424
425  @property
426  def clipnorm(self):
427    """`float` or `None`. If set, clips gradients to a maximum norm."""
428    return self._clipnorm
429
430  @property
431  def global_clipnorm(self):
432    """`float` or `None`. If set, clips gradients to a maximum norm."""
433    return self._global_clipnorm
434
435  @clipnorm.setter
436  def clipnorm(self, val):
437    if val is not None and self.gradient_transformers:
438      raise ValueError("`clipnorm` cannot be set when `gradient_transformers` "
439                       "is set. Instead, use the `gradient_transformers` to "
440                       "specify clipping and other transformations.")
441    self._clipnorm = val
442    self._clipnorm_fn = optimizer_utils.make_gradient_clipnorm_fn(
443        self._clipnorm)
444
445  @global_clipnorm.setter
446  def global_clipnorm(self, val):
447    if val is not None and self.gradient_transformers:
448      raise ValueError("`clipnorm` cannot be set when `gradient_transformers` "
449                       "is set. Instead, use the `gradient_transformers` to "
450                       "specify clipping and other transformations.")
451    self._global_clipnorm = val
452    self._global_clipnorm_fn = optimizer_utils.make_global_gradient_clipnorm_fn(
453        self._global_clipnorm)
454
455  @property
456  def clipvalue(self):
457    """`float` or `None`. If set, clips gradients to a maximum value."""
458    return self._clipvalue
459
460  @clipvalue.setter
461  def clipvalue(self, val):
462    if val is not None and self.gradient_transformers:
463      raise ValueError("`clipvalue` cannot be set when `gradient_transformers` "
464                       "is set. Instead, use the `gradient_transformers` to "
465                       "specify clipping and other transformations.")
466    self._clipvalue = val
467    self._clipvalue_fn = optimizer_utils.make_gradient_clipvalue_fn(
468        self._clipvalue)
469
470  def _transform_loss(self, loss):
471    """Called in `.minimize` to transform loss before computing gradients."""
472    return loss
473
474  def _get_gradients(self, tape, loss, var_list, grad_loss=None):
475    """Called in `minimize` to compute gradients from loss."""
476    grads = tape.gradient(loss, var_list, grad_loss)
477    return list(zip(grads, var_list))
478
479  def _transform_unaggregated_gradients(self, grads_and_vars):
480    """Called in `apply_gradients` before gradient aggregation."""
481    return grads_and_vars
482
483  def _aggregate_gradients(self, grads_and_vars):
484    """Called in `apply_gradients` to aggregate gradients across devices."""
485    return self.gradient_aggregator(grads_and_vars)
486
487  def _transform_gradients(self, grads_and_vars):
488    """Called in `apply_gradients` after aggregation."""
489    if self._clipvalue is not None:
490      grads_and_vars = self._clipvalue_fn(grads_and_vars)
491    if self._clipnorm is not None:
492      grads_and_vars = self._clipnorm_fn(grads_and_vars)
493    if self._global_clipnorm is not None:
494      grads_and_vars = self._global_clipnorm_fn(grads_and_vars)
495
496    for fn in self.gradient_transformers:
497      grads_and_vars = fn(grads_and_vars)
498    return grads_and_vars
499
500  def minimize(self, loss, var_list, grad_loss=None, name=None, tape=None):
501    """Minimize `loss` by updating `var_list`.
502
503    This method simply computes gradient using `tf.GradientTape` and calls
504    `apply_gradients()`. If you want to process the gradient before applying
505    then call `tf.GradientTape` and `apply_gradients()` explicitly instead
506    of using this function.
507
508    Args:
509      loss: `Tensor` or callable. If a callable, `loss` should take no arguments
510        and return the value to minimize. If a `Tensor`, the `tape` argument
511        must be passed.
512      var_list: list or tuple of `Variable` objects to update to minimize
513        `loss`, or a callable returning the list or tuple of `Variable` objects.
514        Use callable when the variable list would otherwise be incomplete before
515        `minimize` since the variables are created at the first time `loss` is
516        called.
517      grad_loss: (Optional). A `Tensor` holding the gradient computed for
518        `loss`.
519      name: (Optional) str. Name for the returned operation.
520      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
521        the tape that computed the `loss` must be provided.
522
523    Returns:
524      An `Operation` that updates the variables in `var_list`. The `iterations`
525      will be automatically increased by 1.
526
527    Raises:
528      ValueError: If some of the variables are not `Variable` objects.
529
530    """
531    grads_and_vars = self._compute_gradients(
532        loss, var_list=var_list, grad_loss=grad_loss, tape=tape)
533    return self.apply_gradients(grads_and_vars, name=name)
534
535  def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
536    """Compute gradients of `loss` for the variables in `var_list`.
537
538    This is the first part of `minimize()`.  It returns a list
539    of (gradient, variable) pairs where "gradient" is the gradient
540    for "variable".  Note that "gradient" can be a `Tensor`, an
541    `IndexedSlices`, or `None` if there is no gradient for the
542    given variable.
543
544    Args:
545      loss: `Tensor` or callable. If a callable, `loss` should take no
546        arguments and return the value to minimize. If a `Tensor`, the `tape`
547        argument must be passed.
548      var_list: list or tuple of `Variable` objects to update to minimize
549        `loss`, or a callable returning the list or tuple of `Variable` objects.
550        Use callable when the variable list would otherwise be incomplete before
551        `minimize` and the variables are created at the first time when `loss`
552        is called.
553      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
554      tape: (Optional) `tf.GradientTape`. If `loss` is provided as a `Tensor`,
555        the tape that computed the `loss` must be provided.
556
557    Returns:
558      A list of (gradient, variable) pairs. Variable is always present, but
559      gradient can be `None`.
560
561    Raises:
562      TypeError: If `var_list` contains anything else than `Variable` objects.
563      ValueError: If some arguments are invalid, or var_list is None.
564    """
565    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
566    if not callable(loss) and tape is None:
567      raise ValueError("`tape` is required when a `Tensor` loss is passed.")
568    tape = tape if tape is not None else backprop.GradientTape()
569
570    if callable(loss):
571      with tape:
572        if not callable(var_list):
573          tape.watch(var_list)
574        loss = loss()
575        if callable(var_list):
576          var_list = var_list()
577
578    with tape:
579      loss = self._transform_loss(loss)
580
581    var_list = nest.flatten(var_list)
582    with ops.name_scope_v2(self._name + "/gradients"):
583      grads_and_vars = self._get_gradients(tape, loss, var_list, grad_loss)
584
585    self._assert_valid_dtypes([
586        v for g, v in grads_and_vars
587        if g is not None and v.dtype != dtypes.resource
588    ])
589
590    return grads_and_vars
591
  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    The method sums gradients from all replicas in the presence of
    `tf.distribute.Strategy` by default. You can aggregate gradients yourself by
    passing `experimental_aggregate_gradients=False`.

    Example:

    ```python
    grads = tape.gradient(loss, vars)
    grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
    # Processing aggregated gradients.
    optimizer.apply_gradients(zip(grads, vars),
        experimental_aggregate_gradients=False)

    ```

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation. Default to the name passed
        to the `Optimizer` constructor.
      experimental_aggregate_gradients: Whether to sum gradients from different
        replicas in the presence of `tf.distribute.Strategy`. If False, it's
        the user's responsibility to aggregate the gradients. Default to True.

    Returns:
      An `Operation` that applies the specified gradients. The `iterations`
      will be automatically increased by 1. If `grads_and_vars` is empty after
      `filter_empty_gradients`, a `no_op` is returned and `iterations` is not
      incremented.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If called in a cross-replica context.
      NotImplementedError: If `experimental_aggregate_gradients=False` is used
        with a parameter-server or central-storage strategy.
    """
    grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    with ops.name_scope_v2(self._name):
      # Create iteration if necessary.
      # init_scope lifts weight creation out of any surrounding function or
      # graph-building context; it runs even when there is nothing to apply.
      with ops.init_scope():
        self._create_all_weights(var_list)

      if not grads_and_vars:
        # Distribution strategy does not support reducing an empty list of
        # gradients
        return control_flow_ops.no_op()

      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError(
            "`apply_gradients() cannot be called in cross-replica context. "
            "Use `tf.distribute.Strategy.run` to enter replica "
            "context.")

      # User-side aggregation is rejected for strategies that keep variables
      # on parameter servers / central storage.
      strategy = distribute_ctx.get_strategy()
      if (not experimental_aggregate_gradients and strategy and
          isinstance(strategy,
                     (parameter_server_strategy.ParameterServerStrategyV1,
                      parameter_server_strategy_v2.ParameterServerStrategyV2,
                      central_storage_strategy.CentralStorageStrategy,
                      central_storage_strategy.CentralStorageStrategyV1))):
        raise NotImplementedError(
            "`experimental_aggregate_gradients=False is not supported for "
            "ParameterServerStrategy and CentralStorageStrategy")

      # `apply_state` is forwarded to `_resource_apply_*` methods that accept
      # an `apply_state` argument (see `_distributed_apply`).
      apply_state = self._prepare(var_list)
      if experimental_aggregate_gradients:
        grads_and_vars = self._transform_unaggregated_gradients(grads_and_vars)
        grads_and_vars = self._aggregate_gradients(grads_and_vars)
      # Clipping and `gradient_transformers` always run, aggregated or not.
      grads_and_vars = self._transform_gradients(grads_and_vars)

      # Switch to cross-replica context to perform the variable updates.
      return distribute_ctx.get_replica_context().merge_call(
          functools.partial(self._distributed_apply, apply_state=apply_state),
          args=(grads_and_vars,),
          kwargs={
              "name": name,
          })
675
  def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
    """`apply_gradients` using a `DistributionStrategy`.

    Runs in cross-replica context (invoked through `merge_call` from
    `apply_gradients`). Each variable is updated through
    `distribution.extended.update`, then `iterations` is incremented.

    Args:
      distribution: The `tf.distribute.Strategy` supplied by `merge_call`.
      grads_and_vars: List of (gradient, variable) pairs, already aggregated
        and transformed.
      name: Optional name scope for the updates; defaults to `self._name`.
      apply_state: Value forwarded to `_resource_apply_dense` /
        `_resource_apply_sparse_duplicate_indices` when they accept it.

    Returns:
      The result of `self._iterations.assign_add(1, ...)`; in graph mode (or
      when any update is symbolic) it runs after all variable updates.
    """

    def apply_grad_to_update_var(var, grad):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)

      apply_kwargs = {}
      # Sparse (IndexedSlices) gradients: constraints are unsupported.
      # NOTE(review): the message below says "sparse variable" although the
      # trigger is a sparse gradient.
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        if "apply_state" in self._sparse_apply_args:
          apply_kwargs["apply_state"] = apply_state
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices, **apply_kwargs)

      if "apply_state" in self._dense_apply_args:
        apply_kwargs["apply_state"] = apply_state
      update_op = self._resource_apply_dense(grad, var, **apply_kwargs)
      # The variable constraint is applied only after the dense update.
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op

    eagerly_outside_functions = ops.executing_eagerly_outside_functions()
    update_ops = []
    with name_scope_only_in_function_or_graph(name or self._name):
      for grad, var in grads_and_vars:
        # TODO(crccw): It's not allowed to assign PerReplica value to
        # MirroredVariable.  Remove this after we relax this restriction.
        def _assume_mirrored(grad):
          if isinstance(grad, ds_values.PerReplica):
            return ds_values.Mirrored(grad.values)
          return grad

        grad = nest.map_structure(_assume_mirrored, grad)
        # Colocate the update with variables to avoid unnecessary communication
        # delays. See b/136304694.
        with distribution.extended.colocate_vars_with(var):
          # Scope name is "update" when eager outside functions, otherwise
          # "update_" + var.op.name (the `+` binds tighter than the ternary).
          with name_scope_only_in_function_or_graph(
              "update" if eagerly_outside_functions else "update_" +
              var.op.name):
            update_ops.extend(distribution.extended.update(
                var, apply_grad_to_update_var, args=(grad,), group=False))

      any_symbolic = any(isinstance(i, ops.Operation) or
                         tf_utils.is_symbolic_tensor(i) for i in update_ops)
      if not context.executing_eagerly() or any_symbolic:
        # If the current context is graph mode or any of the update ops are
        # symbolic then the step update should be carried out under a graph
        # context. (eager updates execute immediately)
        with backend._current_graph(update_ops).as_default():  # pylint: disable=protected-access
          with ops.control_dependencies([control_flow_ops.group(update_ops)]):
            return self._iterations.assign_add(1, read_value=False)

      return self._iterations.assign_add(1)
735
736  def get_gradients(self, loss, params):
737    """Returns gradients of `loss` with respect to `params`.
738
739    Should be used only in legacy v1 graph mode.
740
741    Args:
742      loss: Loss tensor.
743      params: List of variables.
744
745    Returns:
746      List of gradient tensors.
747
748    Raises:
749      ValueError: In case any gradient cannot be computed (e.g. if gradient
750        function not implemented).
751    """
752    params = nest.flatten(params)
753    with backend.get_graph().as_default(), backend.name_scope(self._name +
754                                                              "/gradients"):
755      grads = gradients.gradients(loss, params)
756      for grad, param in zip(grads, params):
757        if grad is None:
758          raise ValueError("Variable {} has `None` for gradient. "
759                           "Please make sure that all of your ops have a "
760                           "gradient defined (i.e. are differentiable). "
761                           "Common ops without gradient: "
762                           "K.argmax, K.round, K.eval.".format(param))
763    return grads
764
765  def get_updates(self, loss, params):
766    grads = self.get_gradients(loss, params)
767    grads_and_vars = list(zip(grads, params))
768    self._assert_valid_dtypes([
769        v for g, v in grads_and_vars
770        if g is not None and v.dtype != dtypes.resource
771    ])
772    return [self.apply_gradients(grads_and_vars)]
773
774  def _set_hyper(self, name, value):
775    """set hyper `name` to value. value can be callable, tensor, numeric."""
776    if isinstance(value, trackable.Trackable):
777      self._track_trackable(value, name, overwrite=True)
778    if name not in self._hyper:
779      self._hyper[name] = value
780    else:
781      prev_value = self._hyper[name]
782      if (callable(prev_value)
783          or isinstance(prev_value,
784                        (ops.Tensor, int, float,
785                         learning_rate_schedule.LearningRateSchedule))
786          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
787        self._hyper[name] = value
788      else:
789        backend.set_value(self._hyper[name], value)
790
791  def _get_hyper(self, name, dtype=None):
792    if not self._hypers_created:
793      self._create_hypers()
794    value = self._hyper[name]
795    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
796      return value
797    if callable(value):
798      value = value()
799    if dtype:
800      return math_ops.cast(value, dtype)
801    else:
802      return value
803
804  def _create_slots(self, var_list):
805    pass
806
807  def _create_all_weights(self, var_list):
808    """Creates all weights, including iterations, hyperparameters and slot vars.
809
810    This will add newly created variables to `optimizer.weights`.
811
812    New variables are only created when this method is called the first time, or
813    when called with different variables in the var_list.
814
815    Args:
816      var_list: list or tuple of `Variable` objects that will be minimized
817        using this optimizer.
818    """
819
820    _ = self.iterations
821    self._create_hypers()
822    self._create_slots(var_list)
823
824  def __getattribute__(self, name):
825    """Overridden to support hyperparameter access."""
826    try:
827      return super(OptimizerV2, self).__getattribute__(name)
828    except AttributeError as e:
829      # Needed to avoid infinite recursion with __setattr__.
830      if name == "_hyper":
831        raise e
832      # Backwards compatibility with Keras optimizers.
833      if name == "lr":
834        name = "learning_rate"
835      if name in self._hyper:
836        return self._get_hyper(name)
837      raise e
838
839  def __dir__(self):
840    result = set(super(OptimizerV2, self).__dir__())
841    if "_hyper" in result:
842      result |= self._hyper.keys()
843      if "learning_rate" in self._hyper.keys():
844        result.add("lr")
845    return list(result)
846
847  def __setattr__(self, name, value):
848    """Override setattr to support dynamic hyperparameter setting."""
849    # Backwards compatibility with Keras optimizers.
850    if name == "lr":
851      name = "learning_rate"
852    if hasattr(self, "_hyper") and name in self._hyper:
853      self._set_hyper(name, value)
854    else:
855      super(OptimizerV2, self).__setattr__(name, value)
856
857  def get_slot_names(self):
858    """A list of names for this optimizer's slots."""
859    return self._slot_names
860
  def add_slot(self, var, slot_name, initializer="zeros", shape=None):
    """Add a new slot variable for `var`.

    A slot variable is an additional variable associated with `var` to train.
    It is allocated and managed by optimizers, e.g. `Adam`.

    If a slot named `slot_name` already exists for `var`, that existing
    variable is returned and nothing new is created.

    Args:
      var: a `Variable` object.
      slot_name: name of the slot variable.
      initializer: initializer of the slot variable. A string or callable is
        resolved with `initializers.get` and then called with the slot shape
        and `var.dtype`. Any other value is used directly as the initial
        value.
      shape: (Optional) shape of the slot variable. If not set, it will default
        to the shape of `var`. Checkpoint-restoring initializers
        (`trackable.CheckpointInitialValueCallable`) always receive `shape` as
        passed, even when it is None.

    Returns:
      A slot variable.

    Raises:
      ValueError: If `var` was not created under the distribution strategy
        scope that is active while the slot is being created.
    """
    if slot_name not in self._slot_names:
      self._slot_names.append(slot_name)
    var_key = _var_key(var)
    slot_dict = self._slots.setdefault(var_key, {})
    weight = slot_dict.get(slot_name, None)
    if weight is None:
      if isinstance(initializer, six.string_types) or callable(initializer):
        initializer = initializers.get(initializer)
        # A checkpoint-restoring initializer learns the shape from the
        # checkpoint value, so it is not handed `var.shape`.
        if isinstance(
            initializer,
            trackable.CheckpointInitialValueCallable) or (shape is not None):
          slot_shape = shape
        else:
          slot_shape = var.shape
        initial_value = functools.partial(
            initializer, shape=slot_shape, dtype=var.dtype)
      else:
        initial_value = initializer

      # Enters the optimizer's own strategy scope when no strategy is active.
      with self._distribution_strategy_scope():
        strategy = distribute_ctx.get_strategy()
        if not strategy.extended.variable_created_in_scope(var):
          raise ValueError(
              "Trying to create optimizer slot variable under the scope for "
              "tf.distribute.Strategy ({}), which is different from the scope "
              "used for the original variable ({}). Make sure the slot "
              "variables are created under the same strategy scope. This may "
              "happen if you're restoring from a checkpoint outside the scope"
              .format(strategy, var))

        # Colocate the slot with its primary variable; the slot is named
        # "<var shared name>/<slot_name>" and is never trainable.
        with strategy.extended.colocate_vars_with(var):
          weight = tf_variables.Variable(
              name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
              dtype=var.dtype,
              trainable=False,
              initial_value=initial_value)
      backend.track_variable(weight)
      slot_dict[slot_name] = weight
      # Apply any checkpoint restorations that were queued until this slot
      # existed.
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=weight)
      self._weights.append(weight)
    return weight
920
921  def get_slot(self, var, slot_name):
922    var_key = _var_key(var)
923    slot_dict = self._slots[var_key]
924    return slot_dict[slot_name]
925
926  def _prepare(self, var_list):
927    keys = set()
928    for var in var_list:
929      if isinstance(var, ds_values.DistributedValues):
930        var_devices = var._devices   # pylint: disable=protected-access
931      else:
932        var_devices = [var.device]
933      var_dtype = var.dtype.base_dtype
934      for var_device in var_devices:
935        keys.add((var_device, var_dtype))
936
937    apply_state = {}
938    for var_device, var_dtype in keys:
939      apply_state[(var_device, var_dtype)] = {}
940      with ops.device(var_device):
941        self._prepare_local(var_device, var_dtype, apply_state)
942
943    return apply_state
944
945  def _prepare_local(self, var_device, var_dtype, apply_state):
946    if "learning_rate" in self._hyper:
947      lr_t = array_ops.identity(self._decayed_lr(var_dtype))
948      apply_state[(var_device, var_dtype)]["lr_t"] = lr_t
949
950  def _fallback_apply_state(self, var_device, var_dtype):
951    """Compatibility for subclasses that don't pass apply_state through."""
952    apply_state = {(var_device, var_dtype): {}}
953    self._prepare_local(var_device, var_dtype, apply_state)
954    return apply_state[(var_device, var_dtype)]
955
956  def _create_hypers(self):
957    if self._hypers_created:
958      return
959    with self._distribution_strategy_scope():
960      # Iterate hyper values deterministically.
961      for name, value in sorted(self._hyper.items()):
962        if isinstance(value,
963                      (ops.Tensor, tf_variables.Variable)) or callable(value):
964          # The check for `callable` covers the usage when `value` is a
965          # `LearningRateSchedule`, in which case it does not need to create a
966          # variable.
967          continue
968        else:
969          self._hyper[name] = self.add_weight(
970              name,
971              shape=[],
972              trainable=False,
973              initializer=value,
974              aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
975    self._hypers_created = True
976
977  @property
978  def iterations(self):
979    """Variable. The number of training steps this Optimizer has run."""
980    if self._iterations is None:
981      with self._distribution_strategy_scope():
982        self._iterations = self.add_weight(
983            "iter",
984            shape=[],
985            dtype=dtypes.int64,
986            trainable=False,
987            aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
988      self._weights.append(self._iterations)
989    return self._iterations
990
991  @iterations.setter
992  def iterations(self, variable):
993    if self._iterations is not None:
994      raise RuntimeError("Cannot set `iterations` to a new Variable after "
995                         "the Optimizer weights have been created")
996    self._iterations = variable
997    self._weights.append(self._iterations)
998
999  def _decayed_lr(self, var_dtype):
1000    """Get decayed learning rate as a Tensor with dtype=var_dtype."""
1001    lr_t = self._get_hyper("learning_rate", var_dtype)
1002    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
1003      local_step = math_ops.cast(self.iterations, var_dtype)
1004      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
1005    if self._initial_decay > 0.:
1006      local_step = math_ops.cast(self.iterations, var_dtype)
1007      decay_t = math_ops.cast(self._initial_decay, var_dtype)
1008      lr_t = lr_t / (1. + decay_t * local_step)
1009    return lr_t
1010
1011  @abc.abstractmethod
1012  def get_config(self):
1013    """Returns the config of the optimizer.
1014
1015    An optimizer config is a Python dictionary (serializable)
1016    containing the configuration of an optimizer.
1017    The same optimizer can be reinstantiated later
1018    (without any saved state) from this configuration.
1019
1020    Returns:
1021        Python dictionary.
1022    """
1023    config = {"name": self._name}
1024    if self.clipnorm is not None:
1025      config["clipnorm"] = self.clipnorm
1026    if self.clipvalue is not None:
1027      config["clipvalue"] = self.clipvalue
1028    if self.global_clipnorm is not None:
1029      config["global_clipnorm"] = self.global_clipnorm
1030    return config
1031
1032  @classmethod
1033  def from_config(cls, config, custom_objects=None):
1034    """Creates an optimizer from its config.
1035
1036    This method is the reverse of `get_config`,
1037    capable of instantiating the same optimizer from the config
1038    dictionary.
1039
1040    Args:
1041        config: A Python dictionary, typically the output of get_config.
1042        custom_objects: A Python dictionary mapping names to additional Python
1043          objects used to create this optimizer, such as a function used for a
1044          hyperparameter.
1045
1046    Returns:
1047        An optimizer instance.
1048    """
1049    if "lr" in config:
1050      config["learning_rate"] = config.pop("lr")
1051    if "learning_rate" in config:
1052      if isinstance(config["learning_rate"], dict):
1053        config["learning_rate"] = learning_rate_schedule.deserialize(
1054            config["learning_rate"], custom_objects=custom_objects)
1055    return cls(**config)
1056
1057  def _serialize_hyperparameter(self, hyperparameter_name):
1058    """Serialize a hyperparameter that can be a float, callable, or Tensor."""
1059    value = self._hyper[hyperparameter_name]
1060    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
1061      return learning_rate_schedule.serialize(value)
1062    if callable(value):
1063      return value()
1064    if tensor_util.is_tf_type(value):
1065      return backend.get_value(value)
1066    return value
1067
1068  def variables(self):
1069    """Returns variables of this Optimizer based on the order created."""
1070    return self._weights
1071
1072  @property
1073  def weights(self):
1074    """Returns variables of this Optimizer based on the order created."""
1075    return self._weights
1076
1077  def get_weights(self):
1078    """Returns the current weights of the optimizer.
1079
1080    The weights of an optimizer are its state (ie, variables).
1081    This function returns the weight values associated with this
1082    optimizer as a list of Numpy arrays. The first value is always the
1083    iterations count of the optimizer, followed by the optimizer's state
1084    variables in the order they were created. The returned list can in turn
1085    be used to load state into similarly parameterized optimizers.
1086
1087    For example, the RMSprop optimizer for this simple model returns a list of
1088    three values-- the iteration count, followed by the root-mean-square value
1089    of the kernel and bias of the single Dense layer:
1090
1091    >>> opt = tf.keras.optimizers.RMSprop()
1092    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1093    >>> m.compile(opt, loss='mse')
1094    >>> data = np.arange(100).reshape(5, 20)
1095    >>> labels = np.zeros(5)
1096    >>> print('Training'); results = m.fit(data, labels)
1097    Training ...
1098    >>> len(opt.get_weights())
1099    3
1100
1101    Returns:
1102        Weights values as a list of numpy arrays.
1103    """
1104    params = self.weights
1105    return backend.batch_get_value(params)
1106
1107  # TODO(tanzheny): Maybe share this logic with base_layer.
1108  def set_weights(self, weights):
1109    """Set the weights of the optimizer.
1110
1111    The weights of an optimizer are its state (ie, variables).
1112    This function takes the weight values associated with this
1113    optimizer as a list of Numpy arrays. The first value is always the
1114    iterations count of the optimizer, followed by the optimizer's state
1115    variables in the order they are created. The passed values are used to set
1116    the new state of the optimizer.
1117
1118    For example, the RMSprop optimizer for this simple model takes a list of
1119    three values-- the iteration count, followed by the root-mean-square value
1120    of the kernel and bias of the single Dense layer:
1121
1122    >>> opt = tf.keras.optimizers.RMSprop()
1123    >>> m = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
1124    >>> m.compile(opt, loss='mse')
1125    >>> data = np.arange(100).reshape(5, 20)
1126    >>> labels = np.zeros(5)
1127    >>> print('Training'); results = m.fit(data, labels)
1128    Training ...
1129    >>> new_weights = [np.array(10), np.ones([20, 10]), np.zeros([10])]
1130    >>> opt.set_weights(new_weights)
1131    >>> opt.iterations
1132    <tf.Variable 'RMSprop/iter:0' shape=() dtype=int64, numpy=10>
1133
1134    Args:
1135        weights: weight values as a list of numpy arrays.
1136    """
1137    params = self.weights
1138    if len(params) != len(weights):
1139      raise ValueError(
1140          "You called `set_weights(weights)` on optimizer " + self._name +
1141          " with a  weight list of length " + str(len(weights)) +
1142          ", but the optimizer was expecting " + str(len(params)) +
1143          " weights. Provided weights: " + str(weights)[:50] + "...")
1144    if not params:
1145      return
1146    weight_value_tuples = []
1147    param_values = backend.batch_get_value(params)
1148    for pv, p, w in zip(param_values, params, weights):
1149      if pv.shape != w.shape:
1150        raise ValueError("Optimizer weight shape " + str(pv.shape) +
1151                         " not compatible with "
1152                         "provided weight shape " + str(w.shape))
1153      weight_value_tuples.append((p, w))
1154    backend.batch_set_value(weight_value_tuples)
1155
1156  def add_weight(self,
1157                 name,
1158                 shape,
1159                 dtype=None,
1160                 initializer="zeros",
1161                 trainable=None,
1162                 synchronization=tf_variables.VariableSynchronization.AUTO,
1163                 aggregation=tf_variables.VariableAggregation.NONE):
1164
1165    if dtype is None:
1166      dtype = dtypes.float32
1167    if isinstance(initializer, six.string_types) or callable(initializer):
1168      initializer = initializers.get(initializer)
1169
1170    if synchronization == tf_variables.VariableSynchronization.ON_READ:
1171      if trainable:
1172        raise ValueError(
1173            "Synchronization value can be set to "
1174            "VariableSynchronization.ON_READ only for non-trainable variables. "
1175            "You have specified trainable=True and "
1176            "synchronization=VariableSynchronization.ON_READ.")
1177      else:
1178        # Set trainable to be false when variable is to be synced on read.
1179        trainable = False
1180    elif trainable is None:
1181      trainable = True
1182
1183    variable = self._add_variable_with_custom_getter(
1184        name=name,
1185        shape=shape,
1186        getter=base_layer_utils.make_variable,
1187        overwrite=True,
1188        initializer=initializer,
1189        dtype=dtype,
1190        trainable=trainable,
1191        use_resource=True,
1192        synchronization=synchronization,
1193        aggregation=aggregation)
1194    backend.track_variable(variable)
1195
1196    return variable
1197
1198  def _init_set_name(self, name, zero_based=True):
1199    if not name:
1200      self._name = backend.unique_object_name(
1201          generic_utils.to_snake_case(self.__class__.__name__),
1202          zero_based=zero_based)
1203    else:
1204      self._name = name
1205
1206  def _assert_valid_dtypes(self, tensors):
1207    """Asserts tensors are all valid types (see `_valid_dtypes`).
1208
1209    Args:
1210      tensors: Tensors to check.
1211
1212    Raises:
1213      ValueError: If any tensor is not a valid type.
1214    """
1215    valid_dtypes = self._valid_dtypes()
1216    for t in tensors:
1217      dtype = t.dtype.base_dtype
1218      if dtype not in valid_dtypes:
1219        raise ValueError("Invalid type %r for %s, expected: %s." %
1220                         (dtype, t.name, [v for v in valid_dtypes]))
1221
1222  def _valid_dtypes(self):
1223    """Valid types for loss, variables and gradients.
1224
1225    Subclasses should override to allow other float types.
1226
1227    Returns:
1228      Valid types for loss, variables and gradients.
1229    """
1230    return _DEFAULT_VALID_DTYPES
1231
1232  def _call_if_callable(self, param):
1233    """Call the function if param is callable."""
1234    return param() if callable(param) else param
1235
1236  def _resource_apply_dense(self, grad, handle, apply_state):
1237    """Add ops to apply dense gradients to the variable `handle`.
1238
1239    Args:
1240      grad: a `Tensor` representing the gradient.
1241      handle: a `Tensor` of dtype `resource` which points to the variable to be
1242        updated.
1243      apply_state: A dict which is used across multiple apply calls.
1244
1245    Returns:
1246      An `Operation` which updates the value of the variable.
1247    """
1248    raise NotImplementedError("Must be implemented in subclasses.")
1249
1250  def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices,
1251                                               **kwargs):
1252    """Add ops to apply sparse gradients to `handle`, with repeated indices.
1253
1254    Optimizers which override this method must deal with repeated indices. See
1255    the docstring of `_apply_sparse_duplicate_indices` for details. By default
1256    the correct behavior, to sum non-unique indices and their associated
1257    gradients, is enforced by first pre-processing `grad` and `indices` and
1258    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
1259    with duplicate indices may instead override this method to avoid the
1260    overhead of summing.
1261
1262    Args:
1263      grad: a `Tensor` representing the gradient for the affected indices.
1264      handle: a `Tensor` of dtype `resource` which points to the variable to be
1265        updated.
1266      indices: a `Tensor` of integral type representing the indices for which
1267        the gradient is nonzero. Indices may be repeated.
1268      **kwargs: May optionally contain `apply_state`
1269
1270    Returns:
1271      An `Operation` which updates the value of the variable.
1272    """
1273    summed_grad, unique_indices = _deduplicate_indexed_slices(
1274        values=grad, indices=indices)
1275    return self._resource_apply_sparse(summed_grad, handle, unique_indices,
1276                                       **kwargs)
1277
1278  def _resource_apply_sparse(self, grad, handle, indices, apply_state):
1279    """Add ops to apply sparse gradients to the variable `handle`.
1280
1281    Similar to `_apply_sparse`, the `indices` argument to this method has been
1282    de-duplicated. Optimizers which deal correctly with non-unique indices may
1283    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
1284    overhead.
1285
1286    Args:
1287      grad: a `Tensor` representing the gradient for the affected indices.
1288      handle: a `Tensor` of dtype `resource` which points to the variable to be
1289        updated.
1290      indices: a `Tensor` of integral type representing the indices for which
1291        the gradient is nonzero. Indices are unique.
1292      apply_state: A dict which is used across multiple apply calls.
1293
1294    Returns:
1295      An `Operation` which updates the value of the variable.
1296    """
1297    raise NotImplementedError("Must be implemented in subclasses.")
1298
1299  def _resource_scatter_add(self, x, i, v):
1300    with ops.control_dependencies([
1301        gen_resource_variable_ops.ResourceScatterAdd(
1302            resource=x.handle, indices=i, updates=v)
1303    ]):
1304      return x.value()
1305
1306  def _resource_scatter_update(self, x, i, v):
1307    with ops.control_dependencies(
1308        [gen_resource_variable_ops.ResourceScatterUpdate(
1309            resource=x.handle, indices=i, updates=v)]):
1310      return x.value()
1311
1312  @property
1313  @layer_utils.cached_per_instance
1314  def _dense_apply_args(self):
1315    return tf_inspect.getfullargspec(self._resource_apply_dense).args
1316
1317  @property
1318  @layer_utils.cached_per_instance
1319  def _sparse_apply_args(self):
1320    return tf_inspect.getfullargspec(self._resource_apply_sparse).args
1321
1322  # ---------------
1323  # For implementing the trackable interface
1324  # ---------------
1325
1326  def _restore_slot_variable(self, slot_name, variable, slot_variable):
1327    """Restore a newly created slot variable's value."""
1328    variable_key = _var_key(variable)
1329    deferred_restorations = self._deferred_slot_restorations.get(
1330        slot_name, {}).pop(variable_key, [])
1331    # Iterate over restores, highest restore UID first to minimize the number
1332    # of assignments.
1333    deferred_restorations.sort(key=lambda position: position.restore_uid,
1334                               reverse=True)
1335    for checkpoint_position in deferred_restorations:
1336      checkpoint_position.restore(slot_variable)
1337
  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    There are three outcomes:
      1. The slot already exists: it is restored immediately.
      2. The slot does not exist yet, and all of the following hold: execution
         is eager, the checkpoint entry is a simple variable, and either no
         variable creator scope is active or this optimizer has a distribution
         strategy. The slot is then created via `add_slot` with a
         checkpoint-restoring initializer and restored.
      3. Otherwise, the position is queued in `_deferred_slot_restorations`,
         to be consumed by `_restore_slot_variable` when the slot is created.

    Args:
      slot_variable_position: A `trackable._CheckpointPosition` object
        indicating the slot variable `Trackable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    variable_key = _var_key(variable)
    slot_dict = self._slots.get(variable_key, {})
    slot_variable = slot_dict.get(slot_name, None)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()
        # Defer slot variable creation if there is an active variable creator
        # scope. Generally we'd like to eagerly create/restore slot variables
        # when possible, but this may mean that scopes intended to catch
        # `variable` also catch its eagerly created slot variable
        # unintentionally (specifically make_template would add a dependency on
        # a slot variable if not for this case). Deferring is mostly harmless
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        #
        # One notable case is with distribution strategy, which uses variable
        # creator scope but always desires the `variable` and the slot to use
        # the same scope, thus we can safely eagerly create/restore slot
        # variables.
        and (not ops.get_default_graph()._variable_creator_stack or  # pylint: disable=protected-access
             self._distribution_strategy)):
      initializer = trackable.CheckpointInitialValueCallable(
          checkpoint_position=slot_variable_position)
      # Shape is unknown until we read the checkpoint value.
      slot_variable = self.add_slot(
          var=variable,
          initializer=initializer,
          slot_name=slot_name)
      # Slot variables are not owned by any one object (because we don't want to
      # save the slot variable if the optimizer is saved without the non-slot
      # variable, or if the non-slot variable is saved without the optimizer;
      # it's a dependency hypergraph with edges of the form (optimizer, non-slot
      # variable, variable)). So we don't _track_ slot variables anywhere, and
      # instead special-case this dependency and otherwise pretend it's a normal
      # graph.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)
1405
1406  @contextlib.contextmanager
1407  def _distribution_strategy_scope(self):
1408    """Returns the `tf.distribute.Strategy` this optimizer was created under."""
1409    if self._distribution_strategy and not distribute_ctx.has_strategy():
1410      with self._distribution_strategy.scope():
1411        yield self._distribution_strategy.scope()
1412    else:
1413      yield
1414
1415
1416def _var_key(var):
1417  """Key for representing a primary variable, for looking up slots.
1418
1419  In graph mode the name is derived from the var shared name.
1420  In eager mode the name is derived from the var unique id.
1421  If distribution strategy exists, get the primary variable first.
1422
1423  Args:
1424    var: the variable.
1425
1426  Returns:
1427    the unique name of the variable.
1428  """
1429
1430  # pylint: disable=protected-access
1431  # Get the distributed variable if it exists.
1432  if hasattr(var, "_distributed_container"):
1433    var = var._distributed_container()
1434  if var._in_graph_mode:
1435    return var._shared_name
1436  return var._unique_id
1437
1438
def _get_slot_key_from_var(var, slot_name):
  """Get the slot key for the variable: var_name/slot_name."""
  return "/".join((_var_key(var), slot_name))
1444
1445
class RestoredOptimizer(OptimizerV2):
  """A non-functional Optimizer implementation for checkpoint compatibility.

  Holds slot variables and hyperparameters when an optimizer is restored from a
  SavedModel. These variables may be referenced in functions along with ops
  created by the original optimizer, but currently we do not support using the
  optimizer object itself (e.g. through `apply_gradients`).
  """
  # TODO(allenl): Make the restored optimizer functional by tracing its apply
  # methods.

  def __init__(self):
    super(RestoredOptimizer, self).__init__("RestoredOptimizer")
    # Marking hypers as created makes `_create_hypers` a no-op, so restored
    # hyperparameter values are never turned into fresh weights.
    self._hypers_created = True

  def get_config(self):
    """Always raises; the restored optimizer's config is not available."""
    # TODO(allenl): Save and restore the Optimizer's config
    message = (
        "Restoring functional Optimizers from SavedModels is not currently "
        "supported. Please file a feature request if this limitation bothers "
        "you.")
    raise NotImplementedError(message)
1467
# Registers `OptimizerV2` under the "optimizer" identifier for SavedModel
# revival. Loading produces a placeholder `RestoredOptimizer` (version 1).
# NOTE(review): `setter` presumably assigns restored attributes such as
# hyperparameters through `_set_hyper`; confirm against `revived_types`.
revived_types.register_revived_type(
    "optimizer",
    lambda obj: isinstance(obj, OptimizerV2),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: RestoredOptimizer(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=RestoredOptimizer._set_hyper  # pylint: disable=protected-access
    )])
1478