1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various learning rate decay functions."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import math
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.keras.utils import generic_utils
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import random_ops
30from tensorflow.python.util.tf_export import keras_export
31
32
@keras_export("keras.optimizers.schedules.LearningRateSchedule")
class LearningRateSchedule(object):
  """A serializable learning rate decay schedule.

  `LearningRateSchedule`s can be passed in as the learning rate of optimizers in
  `tf.keras.optimizers`. They can be serialized and deserialized using
  `tf.keras.optimizers.schedules.serialize` and
  `tf.keras.optimizers.schedules.deserialize`.

  Subclasses are expected to override `__call__` and `get_config`; the
  versions defined here only raise `NotImplementedError`.
  """

  # NOTE(review): this class derives from plain `object` (no `abc.ABCMeta`
  # metaclass is declared here), so `abc.abstractmethod` presumably does not
  # block instantiation by itself; the explicit raises below are what signal a
  # missing override. Confirm before relying on abstract-class enforcement.
  @abc.abstractmethod
  def __call__(self, step):
    """Returns the learning rate for `step`. Must be overridden.

    Args:
        step: The current optimizer step.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError("Learning rate schedule must override __call__")

  @abc.abstractmethod
  def get_config(self):
    """Returns the constructor arguments as a dict. Must be overridden.

    The returned dict is what `from_config` expands into keyword arguments.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError("Learning rate schedule must override get_config")

  @classmethod
  def from_config(cls, config):
    """Instantiates a `LearningRateSchedule` from its config.

    Args:
        config: Output of `get_config()`.

    Returns:
        A `LearningRateSchedule` instance.
    """
    # Every key of `config` is forwarded as a keyword argument to `__init__`.
    return cls(**config)
62
63
@keras_export("keras.optimizers.schedules.ExponentialDecay")
class ExponentialDecay(LearningRateSchedule):
  """A LearningRateSchedule whose rate shrinks exponentially with the step."""

  def __init__(
      self,
      initial_learning_rate,
      decay_steps,
      decay_rate,
      staircase=False,
      name=None):
    """Builds an exponential decay schedule.

    Lowering the learning rate as training progresses is a common
    recommendation. This schedule multiplies a provided initial learning rate
    by an exponentially shrinking factor that depends on the optimizer step.

    The schedule is a 1-arg callable that produces a decayed learning rate
    when passed the current optimizer step, so the rate can change between
    invocations of optimizer functions. It is computed as:

    ```python
    def decayed_learning_rate(step):
      return initial_learning_rate * decay_rate ^ (step / decay_steps)
    ```

    When `staircase` is `True`, `step / decay_steps` is floored, so the
    learning rate changes only at multiples of `decay_steps` and follows a
    staircase function.

    The schedule can be handed to a `tf.keras.optimizers.Optimizer` as its
    learning rate.
    Example: When fitting a Keras model, decay every 100000 steps with a base
    of 0.96:

    ```python
    initial_learning_rate = 0.1
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=100000,
        decay_rate=0.96,
        staircase=True)

    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(data, labels, epochs=5)
    ```

    The schedule can be serialized and deserialized with
    `tf.keras.optimizers.schedules.serialize` and
    `tf.keras.optimizers.schedules.deserialize`.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
        Python number.  The initial learning rate.
      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
        Must be positive.  See the decay computation above.
      decay_rate: A scalar `float32` or `float64` `Tensor` or a
        Python number.  The decay rate.
      staircase: Boolean.  If `True` decay the learning rate at discrete
        intervals.
      name: String.  Optional name of the operation.  Defaults to
        'ExponentialDecay'.

    Returns:
      A 1-arg callable learning rate schedule that takes the current optimizer
      step and outputs the decayed learning rate, a scalar `Tensor` of the same
      type as `initial_learning_rate`.
    """
    super(ExponentialDecay, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.decay_steps, self.decay_rate = decay_steps, decay_rate
    self.staircase = staircase
    self.name = name

  def __call__(self, step):
    """Returns the decayed learning rate tensor for `step`."""
    scope_inputs = [
        self.initial_learning_rate, step, self.decay_steps, self.decay_rate]
    with ops.name_scope(self.name, "ExponentialDecay", scope_inputs) as name:
      base_lr = ops.convert_to_tensor(
          self.initial_learning_rate, name="initial_learning_rate")
      # All arithmetic is carried out in the learning rate's dtype.
      dtype = base_lr.dtype
      steps_per_decay = math_ops.cast(self.decay_steps, dtype)
      rate = math_ops.cast(self.decay_rate, dtype)
      current_step = math_ops.cast(step, dtype)

      exponent = current_step / steps_per_decay
      if self.staircase:
        # Whole decay periods only: the rate is constant within a period.
        exponent = math_ops.floor(exponent)
      decay_factor = math_ops.pow(rate, exponent)
      return math_ops.multiply(base_lr, decay_factor, name=name)

  def get_config(self):
    """Returns the constructor arguments keyed by parameter name."""
    field_names = (
        "initial_learning_rate",
        "decay_steps",
        "decay_rate",
        "staircase",
        "name",
    )
    return {field: getattr(self, field) for field in field_names}
169
170
@keras_export("keras.optimizers.schedules.PiecewiseConstantDecay")
class PiecewiseConstantDecay(LearningRateSchedule):
  """A LearningRateSchedule that is constant on each step interval."""

  def __init__(
      self,
      boundaries,
      values,
      name=None):
    """Builds a piecewise constant schedule from boundaries and values.

    The schedule is a 1-arg callable that yields the piecewise constant value
    for the optimizer step it is given, so the learning rate can change
    between invocations of optimizer functions.

    Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
      for the next 10000 steps, and 0.1 for any additional steps.

    ```python
    step = tf.Variable(0, trainable=False)
    boundaries = [100000, 110000]
    values = [1.0, 0.5, 0.1]
    learning_rate_fn = keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries, values)

    # Later, whenever we perform an optimization step, we pass in the step.
    learning_rate = learning_rate_fn(step)
    ```

    The schedule can be handed to a `tf.keras.optimizers.Optimizer` as its
    learning rate, and can be serialized and deserialized with
    `tf.keras.optimizers.schedules.serialize` and
    `tf.keras.optimizers.schedules.deserialize`.

    Args:
      boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
        increasing entries, and with all elements having the same type as the
        optimizer step.
      values: A list of `Tensor`s or `float`s or `int`s that specifies the
        values for the intervals defined by `boundaries`. It should have one
        more element than `boundaries`, and all elements should have the same
        type.
      name: A string. Optional name of the operation. Defaults to
        'PiecewiseConstant'.

    Returns:
      A 1-arg callable learning rate schedule that takes the current optimizer
      step and outputs the learning rate, which is one of the entries of
      `values` converted to a `Tensor`.

      The output of the 1-arg function that takes the `step`
      is `values[0]` when `step <= boundaries[0]`,
      `values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`, ...,
      and values[-1] when `step > boundaries[-1]`.

    Raises:
      ValueError: if `values` does not have exactly one more element than
          `boundaries`. Dtype mismatches are reported with `ValueError` when
          the schedule is called.
    """
    super(PiecewiseConstantDecay, self).__init__()

    if len(values) - len(boundaries) != 1:
      raise ValueError(
          "The length of boundaries should be 1 less than the length of values")

    self.name = name
    self.boundaries = boundaries
    self.values = values

  def __call__(self, step):
    """Returns the value of the interval that contains `step`.

    Raises:
      ValueError: if a boundary's dtype differs from the step's dtype (other
          than an int32 boundary with an int64 step), or if the values do not
          all share one dtype.
    """
    scope_inputs = [step, self.boundaries, self.values, self.name]
    with ops.name_scope(self.name, "PiecewiseConstant", scope_inputs):
      thresholds = ops.convert_n_to_tensor(self.boundaries)
      rates = ops.convert_n_to_tensor(self.values)
      x = ops.convert_to_tensor(step)
      step_dtype = x.dtype.base_dtype

      # The step is never cast to the boundaries' dtype: doing so could yield
      # faulty comparisons, e.g. when floats would be turned into integers.
      for idx, threshold in enumerate(thresholds):
        threshold_dtype = threshold.dtype.base_dtype
        if threshold_dtype == step_dtype:
          continue
        # Only int32 -> int64 promotion is allowed since it is lossless. It
        # covers the common case of boundaries given as Python integers.
        if not (threshold_dtype == dtypes.int32 and
                step_dtype == dtypes.int64):
          raise ValueError(
              "Boundaries (%s) must have the same dtype as x (%s)." %
              (threshold_dtype, step_dtype))
        thresholds[idx] = math_ops.cast(threshold, step_dtype)

      # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
      reference_dtype = rates[0].dtype.base_dtype
      for rate in rates[1:]:
        if rate.dtype.base_dtype != reference_dtype:
          raise ValueError(
              "Values must have elements all with the same dtype (%s vs %s)." %
              (reference_dtype, rate.dtype.base_dtype))

      # The two open-ended intervals come first, then the bounded ones.
      branches = [
          (x <= thresholds[0], lambda: rates[0]),
          (x > thresholds[-1], lambda: rates[-1]),
      ]
      interior = zip(thresholds[:-1], thresholds[1:], rates[1:-1])
      for lower, upper, rate in interior:
        in_interval = (x > lower) & (x <= upper)
        # The default argument freezes this iteration's `rate` in the lambda.
        branches.append((in_interval, lambda rate=rate: rate))

      # The predicates are mutually exclusive and exhaustive, so the default
      # branch is never needed; it is passed because tf.case requires one.
      return control_flow_ops.case(
          branches, lambda: rates[0], exclusive=True)

  def get_config(self):
    """Returns the constructor arguments keyed by parameter name."""
    field_names = ("boundaries", "values", "name")
    return {field: getattr(self, field) for field in field_names}
286
287
@keras_export("keras.optimizers.schedules.PolynomialDecay")
class PolynomialDecay(LearningRateSchedule):
  """A LearningRateSchedule that uses a polynomial decay schedule."""

  def __init__(
      self,
      initial_learning_rate,
      decay_steps,
      end_learning_rate=0.0001,
      power=1.0,
      cycle=False,
      name=None):
    """Applies a polynomial decay to the learning rate.

    It is commonly observed that a monotonically decreasing learning rate, whose
    degree of change is carefully chosen, results in a better performing model.
    This schedule applies a polynomial decay function to an optimizer step,
    given a provided `initial_learning_rate`, to reach an `end_learning_rate`
    in the given `decay_steps`.

    The schedule is a 1-arg callable that produces a decayed learning rate
    when passed the current optimizer step. This can be useful for changing the
    learning rate value across different invocations of optimizer functions.
    It is computed as:

    ```python
    def decayed_learning_rate(step):
      step = min(step, decay_steps)
      return ((initial_learning_rate - end_learning_rate) *
              (1 - step / decay_steps) ^ (power)
             ) + end_learning_rate
    ```

    If `cycle` is True then a multiple of `decay_steps` is used, the first one
    that is not smaller than `step` (for `step == 0` the multiplier is 1).

    ```python
    def decayed_learning_rate(step):
      decay_steps = decay_steps * ceil(step / decay_steps)
      return ((initial_learning_rate - end_learning_rate) *
              (1 - step / decay_steps) ^ (power)
             ) + end_learning_rate
    ```

    You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
    as the learning rate.
    Example: Fit a model while decaying from 0.1 to 0.01 in 10000 steps using
    sqrt (i.e. power=0.5):

    ```python
    ...
    starter_learning_rate = 0.1
    end_learning_rate = 0.01
    decay_steps = 10000
    learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
        starter_learning_rate,
        decay_steps,
        end_learning_rate,
        power=0.5)

    model.compile(optimizer=tf.keras.optimizers.SGD(
                      learning_rate=learning_rate_fn),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(data, labels, epochs=5)
    ```

    The learning rate schedule is also serializable and deserializable using
    `tf.keras.optimizers.schedules.serialize` and
    `tf.keras.optimizers.schedules.deserialize`.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
        Python number.  The initial learning rate.
      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
        Must be positive.  See the decay computation above.
      end_learning_rate: A scalar `float32` or `float64` `Tensor` or a
        Python number.  The minimal end learning rate.
      power: A scalar `float32` or `float64` `Tensor` or a
        Python number.  The power of the polynomial. Defaults to linear, 1.0.
      cycle: A boolean, whether or not it should cycle beyond decay_steps.
      name: String.  Optional name of the operation. Defaults to
        'PolynomialDecay'.

    Returns:
      A 1-arg callable learning rate schedule that takes the current optimizer
      step and outputs the decayed learning rate, a scalar `Tensor` of the same
      type as `initial_learning_rate`.
    """
    super(PolynomialDecay, self).__init__()

    self.initial_learning_rate = initial_learning_rate
    self.decay_steps = decay_steps
    self.end_learning_rate = end_learning_rate
    self.power = power
    self.cycle = cycle
    self.name = name

  def __call__(self, step):
    """Returns the polynomially decayed learning rate tensor for `step`."""
    with ops.name_scope(
        self.name, "PolynomialDecay",
        [self.initial_learning_rate, step, self.decay_steps,
         self.end_learning_rate, self.power]
    ) as name:
      initial_learning_rate = ops.convert_to_tensor(
          self.initial_learning_rate, name="initial_learning_rate")
      # All arithmetic is carried out in the learning rate's dtype.
      dtype = initial_learning_rate.dtype
      end_learning_rate = math_ops.cast(self.end_learning_rate, dtype)
      power = math_ops.cast(self.power, dtype)

      global_step_recomp = math_ops.cast(step, dtype)
      # Always use this cast value below rather than `self.decay_steps`: the
      # raw attribute may be an integer `Tensor`, which cannot be combined
      # with the floating point step.
      decay_steps = math_ops.cast(self.decay_steps, dtype)
      if self.cycle:
        # Find the first multiple of decay_steps that is not smaller than
        # global_step. If global_step is zero set the multiplier to 1.
        # Both branches must return the same dtype, so the constant is built
        # in `dtype` instead of relying on the float32 default.
        multiplier = control_flow_ops.cond(
            math_ops.equal(global_step_recomp, 0),
            lambda: constant_op.constant(1.0, dtype=dtype.base_dtype),
            lambda: math_ops.ceil(global_step_recomp / decay_steps))
        decay_steps_recomp = math_ops.multiply(decay_steps, multiplier)
      else:
        # Make sure that the global_step used is not bigger than decay_steps.
        global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
        decay_steps_recomp = decay_steps

      p = math_ops.div(global_step_recomp, decay_steps_recomp)
      return math_ops.add(
          math_ops.multiply(initial_learning_rate - end_learning_rate,
                            math_ops.pow(1 - p, power)),
          end_learning_rate,
          name=name)

  def get_config(self):
    """Returns the constructor arguments keyed by parameter name."""
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "decay_steps": self.decay_steps,
        "end_learning_rate": self.end_learning_rate,
        "power": self.power,
        "cycle": self.cycle,
        "name": self.name
    }
433
434
@keras_export("keras.optimizers.schedules.InverseTimeDecay")
class InverseTimeDecay(LearningRateSchedule):
  """A LearningRateSchedule whose rate decays as the inverse of the step."""

  def __init__(
      self,
      initial_learning_rate,
      decay_steps,
      decay_rate,
      staircase=False,
      name=None):
    """Builds an inverse time decay schedule.

    Lowering the learning rate as training progresses is a common
    recommendation. This schedule divides a provided initial learning rate by
    a term that grows linearly with the optimizer step.

    The schedule is a 1-arg callable that produces a decayed learning rate
    when passed the current optimizer step, so the rate can change between
    invocations of optimizer functions. It is computed as:

    ```python
    def decayed_learning_rate(step):
      return initial_learning_rate / (1 + decay_rate * step / decay_steps)
    ```

    or, if `staircase` is `True`, as:

    ```python
    def decayed_learning_rate(step):
      return initial_learning_rate / (1 + decay_rate * floor(step / decay_steps))
    ```

    The schedule can be handed to a `tf.keras.optimizers.Optimizer` as its
    learning rate.
    Example: Fit a Keras model when decaying 1/t with a rate of 0.5:

    ```python
    ...
    initial_learning_rate = 0.1
    decay_steps = 1.0
    decay_rate = 0.5
    learning_rate_fn = keras.optimizers.schedules.InverseTimeDecay(
      initial_learning_rate, decay_steps, decay_rate)

    model.compile(optimizer=tf.keras.optimizers.SGD(
                      learning_rate=learning_rate_fn),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    model.fit(data, labels, epochs=5)
    ```

    Args:
      initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
        Python number.  The initial learning rate.
      decay_steps: How often to apply decay.
      decay_rate: A Python number.  The decay rate.
      staircase: Whether to apply decay in a discrete staircase, as opposed to
        continuous, fashion.
      name: String.  Optional name of the operation.  Defaults to
        'InverseTimeDecay'.

    Returns:
      A 1-arg callable learning rate schedule that takes the current optimizer
      step and outputs the decayed learning rate, a scalar `Tensor` of the same
      type as `initial_learning_rate`.
    """
    super(InverseTimeDecay, self).__init__()

    self.initial_learning_rate = initial_learning_rate
    self.decay_steps, self.decay_rate = decay_steps, decay_rate
    self.staircase = staircase
    self.name = name

  def __call__(self, step):
    """Returns the inverse-time decayed learning rate tensor for `step`."""
    scope_inputs = [self.initial_learning_rate, step, self.decay_rate]
    with ops.name_scope(self.name, "InverseTimeDecay", scope_inputs) as name:
      base_lr = ops.convert_to_tensor(
          self.initial_learning_rate, name="initial_learning_rate")
      # All arithmetic is carried out in the learning rate's dtype.
      dtype = base_lr.dtype
      steps_per_decay = math_ops.cast(self.decay_steps, dtype)
      rate = math_ops.cast(self.decay_rate, dtype)
      current_step = math_ops.cast(step, dtype)

      elapsed = current_step / steps_per_decay
      if self.staircase:
        # Whole decay periods only: the rate is constant within a period.
        elapsed = math_ops.floor(elapsed)
      one = math_ops.cast(constant_op.constant(1), dtype)
      denominator = math_ops.add(one, math_ops.multiply(rate, elapsed))
      return math_ops.div(base_lr, denominator, name=name)

  def get_config(self):
    """Returns the constructor arguments keyed by parameter name."""
    field_names = (
        "initial_learning_rate",
        "decay_steps",
        "decay_rate",
        "staircase",
        "name",
    )
    return {field: getattr(self, field) for field in field_names}
540
541
@keras_export("keras.experimental.CosineDecay")
class CosineDecay(LearningRateSchedule):
  """A LearningRateSchedule that uses a cosine decay schedule."""

  def __init__(
      self,
      initial_learning_rate,
      decay_steps,
      alpha=0.0,
      name=None):
    """Applies cosine decay to the learning rate.

    See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
    with Warm Restarts. https://arxiv.org/abs/1608.03983

    When training a model, it is often recommended to lower the learning rate as
    the training progresses. This schedule applies a cosine decay function
    to an optimizer step, given a provided initial learning rate.

    The schedule is a 1-arg callable that produces a decayed learning
    rate when passed the current optimizer step. This can be useful for changing
    the learning rate value across different invocations of optimizer functions.
    It is computed as:

    ```python
    def decayed_learning_rate(step):
      step = min(step, decay_steps)
      cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps))
      decayed = (1 - alpha) * cosine_decay + alpha
      return initial_learning_rate * decayed
    ```

    Example usage:
    ```python
    decay_steps = 1000
    lr_decayed_fn = tf.keras.experimental.CosineDecay(
        initial_learning_rate, decay_steps)
    ```

    You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
    as the learning rate. The learning rate schedule is also serializable and
    deserializable using `tf.keras.optimizers.schedules.serialize` and
    `tf.keras.optimizers.schedules.deserialize`.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` Tensor or a
        Python number. The initial learning rate.
      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
        Number of steps to decay over.
      alpha: A scalar `float32` or `float64` Tensor or a Python number.
        Minimum learning rate value as a fraction of initial_learning_rate.
      name: String. Optional name of the operation.  Defaults to 'CosineDecay'.
    Returns:
      A 1-arg callable learning rate schedule that takes the current optimizer
      step and outputs the decayed learning rate, a scalar `Tensor` of the same
      type as `initial_learning_rate`.
    """
    super(CosineDecay, self).__init__()

    self.initial_learning_rate = initial_learning_rate
    self.decay_steps = decay_steps
    self.alpha = alpha
    self.name = name

  def __call__(self, step):
    """Returns the cosine decayed learning rate tensor for `step`."""
    with ops.name_scope(self.name, "CosineDecay",
                        [self.initial_learning_rate, step]):
      initial_learning_rate = ops.convert_to_tensor(
          self.initial_learning_rate, name="initial_learning_rate")
      # All arithmetic is carried out in the learning rate's dtype.
      dtype = initial_learning_rate.dtype
      decay_steps = math_ops.cast(self.decay_steps, dtype)
      # Cast alpha as well so a Tensor alpha of another float dtype can be
      # combined with the cosine term.
      alpha = math_ops.cast(self.alpha, dtype)

      global_step_recomp = math_ops.cast(step, dtype)
      # Past `decay_steps` the schedule stays at its final value.
      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
      completed_fraction = global_step_recomp / decay_steps
      # Build pi in `dtype`: a default (float32) constant cannot be multiplied
      # with a float64 fraction.
      pi = constant_op.constant(math.pi, dtype=dtype.base_dtype)
      cosine_decayed = 0.5 * (1.0 + math_ops.cos(pi * completed_fraction))

      decayed = (1 - alpha) * cosine_decayed + alpha
      return math_ops.multiply(initial_learning_rate, decayed)

  def get_config(self):
    """Returns the constructor arguments keyed by parameter name."""
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "decay_steps": self.decay_steps,
        "alpha": self.alpha,
        "name": self.name
    }
632
633
@keras_export("keras.experimental.CosineDecayRestarts")
class CosineDecayRestarts(LearningRateSchedule):
  """A LearningRateSchedule that uses a cosine decay schedule with restarts."""

  def __init__(
      self,
      initial_learning_rate,
      first_decay_steps,
      t_mul=2.0,
      m_mul=1.0,
      alpha=0.0,
      name=None):
    """Applies cosine decay with restarts to the learning rate.

    See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
    with Warm Restarts. https://arxiv.org/abs/1608.03983

    When training a model, it is often recommended to lower the learning rate as
    the training progresses. This schedule applies a cosine decay function with
    restarts to an optimizer step, given a provided initial learning rate.

    The schedule is a 1-arg callable that produces a decayed learning
    rate when passed the current optimizer step. This can be useful for changing
    the learning rate value across different invocations of optimizer functions.

    The learning rate multiplier first decays
    from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
    restart is performed. Each new warm restart runs for `t_mul` times more
    steps than the previous one, and the cosine term of the i-th period
    (counting from 0) is scaled by `m_mul ** i`.

    Example usage:
    ```python
    first_decay_steps = 1000
    lr_decayed_fn = (
      tf.keras.experimental.CosineDecayRestarts(
          initial_learning_rate,
          first_decay_steps))
    ```

    You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
    as the learning rate. The learning rate schedule is also serializable and
    deserializable using `tf.keras.optimizers.schedules.serialize` and
    `tf.keras.optimizers.schedules.deserialize`.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
        number. The initial learning rate.
      first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python
        number. Number of steps to decay over.
      t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
        Used to derive the number of iterations in the i-th period, which
        lasts `first_decay_steps * t_mul ** i` steps.
      m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
        Used to derive the initial learning rate of the i-th period, whose
        cosine term is multiplied by `m_mul ** i`.
      alpha: A scalar `float32` or `float64` Tensor or a Python number.
        Minimum learning rate value as a fraction of the initial_learning_rate.
      name: String. Optional name of the operation.  Defaults to 'SGDRDecay'.
    Returns:
      A 1-arg callable learning rate schedule that takes the current optimizer
      step and outputs the decayed learning rate, a scalar `Tensor` of the same
      type as `initial_learning_rate`.
    """
    super(CosineDecayRestarts, self).__init__()

    self.initial_learning_rate = initial_learning_rate
    self.first_decay_steps = first_decay_steps
    self._t_mul = t_mul
    self._m_mul = m_mul
    self.alpha = alpha
    self.name = name

  def __call__(self, step):
    """Returns the learning rate tensor for `step`, accounting for restarts."""
    with ops.name_scope(self.name, "SGDRDecay",
                        [self.initial_learning_rate, step]
                       ) as name:
      initial_learning_rate = ops.convert_to_tensor(
          self.initial_learning_rate, name="initial_learning_rate")
      # All arithmetic is carried out in the learning rate's dtype.
      dtype = initial_learning_rate.dtype
      first_decay_steps = math_ops.cast(self.first_decay_steps, dtype)
      alpha = math_ops.cast(self.alpha, dtype)
      t_mul = math_ops.cast(self._t_mul, dtype)
      m_mul = math_ops.cast(self._m_mul, dtype)

      global_step_recomp = math_ops.cast(step, dtype)
      # Progress measured in units of the first period's length.
      completed_fraction = global_step_recomp / first_decay_steps

      def compute_step(completed_fraction, geometric=False):
        """Helper for `cond` operation.

        Returns the index of the current restart period and the fraction of
        that period which has been completed.
        """
        if geometric:
          # Period i lasts t_mul**i units, so the first i periods span the
          # geometric sum (1 - t_mul**i) / (1 - t_mul); invert it for i.
          i_restart = math_ops.floor(
              math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
              math_ops.log(t_mul))

          sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
          completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart

        else:
          # All periods have equal length: split into whole and fractional
          # parts.
          i_restart = math_ops.floor(completed_fraction)
          completed_fraction -= i_restart

        return i_restart, completed_fraction

      # The geometric formula divides by log(t_mul) and (1 - t_mul), which are
      # both zero for t_mul == 1, hence the separate branch.
      i_restart, completed_fraction = control_flow_ops.cond(
          math_ops.equal(t_mul, 1.0),
          lambda: compute_step(completed_fraction, geometric=False),
          lambda: compute_step(completed_fraction, geometric=True))

      m_fac = m_mul**i_restart
      # Build pi in `dtype`: a default (float32) constant cannot be multiplied
      # with a float64 fraction.
      pi = constant_op.constant(math.pi, dtype=dtype.base_dtype)
      cosine_decayed = 0.5 * m_fac * (
          1.0 + math_ops.cos(pi * completed_fraction))
      decayed = (1 - alpha) * cosine_decayed + alpha

      return math_ops.multiply(initial_learning_rate, decayed, name=name)

  def get_config(self):
    """Returns the constructor arguments keyed by parameter name."""
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "first_decay_steps": self.first_decay_steps,
        "t_mul": self._t_mul,
        "m_mul": self._m_mul,
        "alpha": self.alpha,
        "name": self.name
    }
761
762
@keras_export("keras.experimental.LinearCosineDecay")
class LinearCosineDecay(LearningRateSchedule):
  """A LearningRateSchedule that uses a linear cosine decay schedule."""

  def __init__(
      self,
      initial_learning_rate,
      decay_steps,
      num_periods=0.5,
      alpha=0.0,
      beta=0.001,
      name=None):
    """Applies linear cosine decay to the learning rate.

    See [Bello et al., ICML2017] Neural Optimizer Search with RL.
    https://arxiv.org/abs/1709.07417

    For the idea of warm starts here controlled by `num_periods`,
    see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
    with Warm Restarts. https://arxiv.org/abs/1608.03983

    Note that linear cosine decay is more aggressive than cosine decay and
    larger initial learning rates can typically be used.

    When training a model, it is often recommended to lower the learning rate as
    the training progresses. This schedule applies a linear cosine decay
    function to an optimizer step, given a provided initial learning rate.
    It requires a `step` value to compute the decayed learning rate. You can
    just pass a TensorFlow variable that you increment at each training step.

    The schedule is a 1-arg callable that produces a decayed learning
    rate when passed the current optimizer step. This can be useful for changing
    the learning rate value across different invocations of optimizer functions.
    It is computed as:

    ```python
    def decayed_learning_rate(step):
      step = min(step, decay_steps)
      linear_decay = (decay_steps - step) / decay_steps
      cosine_decay = 0.5 * (
          1 + cos(pi * 2 * num_periods * step / decay_steps))
      decayed = (alpha + linear_decay) * cosine_decay + beta
      return initial_learning_rate * decayed
    ```

    Example usage:
    ```python
    decay_steps = 1000
    lr_decayed_fn = (
      tf.keras.experimental.LinearCosineDecay(
        initial_learning_rate, decay_steps))
    ```

    You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
    as the learning rate. The learning rate schedule is also serializable and
    deserializable using `tf.keras.optimizers.schedules.serialize` and
    `tf.keras.optimizers.schedules.deserialize`.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
        number. The initial learning rate.
      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
        Number of steps to decay over.
      num_periods: Number of periods in the cosine part of the decay.
        See computation above.
      alpha: See computation above.
      beta: See computation above.
      name: String.  Optional name of the operation.  Defaults to
        'LinearCosineDecay'.
    Returns:
      A 1-arg callable learning rate schedule that takes the current optimizer
      step and outputs the decayed learning rate, a scalar `Tensor` of the same
      type as `initial_learning_rate`.
    """
    super(LinearCosineDecay, self).__init__()

    # The arguments are stored untouched; conversion to tensors is deferred to
    # `__call__`.
    self.initial_learning_rate = initial_learning_rate
    self.decay_steps = decay_steps
    self.num_periods = num_periods
    self.alpha = alpha
    self.beta = beta
    self.name = name

  def __call__(self, step):
    """Computes the decayed learning rate for `step`.

    Args:
      step: The current optimizer step. It is cast to the dtype of
        `initial_learning_rate` and clipped to at most `decay_steps`.

    Returns:
      A `Tensor` holding `initial_learning_rate * decayed` (see the formula in
      `__init__`), in the dtype of `initial_learning_rate`.
    """
    with ops.name_scope(self.name, "LinearCosineDecay",
                        [self.initial_learning_rate, step]) as name:
      initial_learning_rate = ops.convert_to_tensor(
          self.initial_learning_rate, name="initial_learning_rate")
      # All arithmetic below is carried out in the learning rate's dtype.
      dtype = initial_learning_rate.dtype
      decay_steps = math_ops.cast(self.decay_steps, dtype)
      num_periods = math_ops.cast(self.num_periods, dtype)
      alpha = math_ops.cast(self.alpha, dtype)
      beta = math_ops.cast(self.beta, dtype)

      # Clip the step so the schedule stays constant once `decay_steps` is
      # reached.
      global_step_recomp = math_ops.cast(step, dtype)
      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
      linear_decayed = (decay_steps - global_step_recomp) / decay_steps
      completed_fraction = global_step_recomp / decay_steps
      fraction = 2.0 * num_periods * completed_fraction
      # Build pi in the working dtype: a default (float32) constant cannot be
      # multiplied with `fraction` when the learning rate is float64.
      pi = constant_op.constant(math.pi, dtype=dtype)
      cosine_decayed = 0.5 * (1.0 + math_ops.cos(pi * fraction))

      linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
      return math_ops.multiply(initial_learning_rate, linear_cosine_decayed,
                               name=name)

  def get_config(self):
    """Returns the constructor arguments of this schedule as a dict."""
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "decay_steps": self.decay_steps,
        "num_periods": self.num_periods,
        "alpha": self.alpha,
        "beta": self.beta,
        "name": self.name
    }
878
879
@keras_export("keras.experimental.NoisyLinearCosineDecay")
class NoisyLinearCosineDecay(LearningRateSchedule):
  """A LearningRateSchedule that uses a noisy linear cosine decay schedule."""

  def __init__(
      self,
      initial_learning_rate,
      decay_steps,
      initial_variance=1.0,
      variance_decay=0.55,
      num_periods=0.5,
      alpha=0.0,
      beta=0.001,
      name=None):
    """Applies noisy linear cosine decay to the learning rate.

    See [Bello et al., ICML2017] Neural Optimizer Search with RL.
    https://arxiv.org/abs/1709.07417

    For the idea of warm starts here controlled by `num_periods`,
    see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
    with Warm Restarts. https://arxiv.org/abs/1608.03983

    Note that linear cosine decay is more aggressive than cosine decay and
    larger initial learning rates can typically be used.

    When training a model, it is often recommended to lower the learning rate as
    the training progresses. This schedule applies a noisy linear cosine decay
    function to an optimizer step, given a provided initial learning rate.
    It requires a `step` value to compute the decayed learning rate. You can
    just pass a TensorFlow variable that you increment at each training step.

    The schedule is a 1-arg callable that produces a decayed learning
    rate when passed the current optimizer step. This can be useful for changing
    the learning rate value across different invocations of optimizer functions.
    It is computed as:

    ```python
    def decayed_learning_rate(step):
      step = min(step, decay_steps)
      linear_decay = (decay_steps - step) / decay_steps
      cosine_decay = 0.5 * (
          1 + cos(pi * 2 * num_periods * step / decay_steps))
      decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta
      return initial_learning_rate * decayed
    ```
    where eps_t is 0-centered gaussian noise with variance
    initial_variance / (1 + step) ** variance_decay, `step` being the value
    already clipped to `decay_steps`.

    Example usage:
    ```python
    decay_steps = 1000
    lr_decayed_fn = (
      tf.keras.experimental.NoisyLinearCosineDecay(
        initial_learning_rate, decay_steps))
    ```

    You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
    as the learning rate. The learning rate schedule is also serializable and
    deserializable using `tf.keras.optimizers.schedules.serialize` and
    `tf.keras.optimizers.schedules.deserialize`.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
        number. The initial learning rate.
      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
        Number of steps to decay over.
      initial_variance: initial variance for the noise. See computation above.
      variance_decay: decay for the noise's variance. See computation above.
      num_periods: Number of periods in the cosine part of the decay.
        See computation above.
      alpha: See computation above.
      beta: See computation above.
      name: String.  Optional name of the operation.  Defaults to
        'NoisyLinearCosineDecay'.
    Returns:
      A 1-arg callable learning rate schedule that takes the current optimizer
      step and outputs the decayed learning rate, a scalar `Tensor` of the same
      type as `initial_learning_rate`.
    """
    super(NoisyLinearCosineDecay, self).__init__()

    # The arguments are stored untouched; conversion to tensors is deferred to
    # `__call__`.
    self.initial_learning_rate = initial_learning_rate
    self.decay_steps = decay_steps
    self.initial_variance = initial_variance
    self.variance_decay = variance_decay
    self.num_periods = num_periods
    self.alpha = alpha
    self.beta = beta
    self.name = name

  def __call__(self, step):
    """Computes the noisy decayed learning rate for `step`.

    Args:
      step: The current optimizer step. It is cast to the dtype of
        `initial_learning_rate` and clipped to at most `decay_steps`.

    Returns:
      A `Tensor` holding `initial_learning_rate * decayed` (see the formula in
      `__init__`), in the dtype of `initial_learning_rate`. A fresh noise
      sample is drawn through `random_ops.random_normal`.
    """
    with ops.name_scope(self.name, "NoisyLinearCosineDecay",
                        [self.initial_learning_rate, step]) as name:
      initial_learning_rate = ops.convert_to_tensor(
          self.initial_learning_rate, name="initial_learning_rate")
      # All arithmetic below is carried out in the learning rate's dtype.
      dtype = initial_learning_rate.dtype
      decay_steps = math_ops.cast(self.decay_steps, dtype)
      initial_variance = math_ops.cast(self.initial_variance, dtype)
      variance_decay = math_ops.cast(self.variance_decay, dtype)
      num_periods = math_ops.cast(self.num_periods, dtype)
      alpha = math_ops.cast(self.alpha, dtype)
      beta = math_ops.cast(self.beta, dtype)

      # Clip the step so the schedule stops decaying once `decay_steps` is
      # reached.
      global_step_recomp = math_ops.cast(step, dtype)
      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
      linear_decayed = (decay_steps - global_step_recomp) / decay_steps

      # The noise variance shrinks as the (clipped) step grows.
      variance = initial_variance / (
          math_ops.pow(1.0 + global_step_recomp, variance_decay))
      std = math_ops.sqrt(variance)
      # Sample the noise in the working dtype: the float32 default would clash
      # with `std` and `linear_decayed` when the learning rate is float64.
      noise = random_ops.random_normal(
          linear_decayed.shape, stddev=std, dtype=dtype)
      noisy_linear_decayed = linear_decayed + noise

      completed_fraction = global_step_recomp / decay_steps
      fraction = 2.0 * num_periods * completed_fraction
      # Build pi in the working dtype for the same reason as the noise above.
      pi = constant_op.constant(math.pi, dtype=dtype)
      cosine_decayed = 0.5 * (1.0 + math_ops.cos(pi * fraction))
      noisy_linear_cosine_decayed = (
          (alpha + noisy_linear_decayed) * cosine_decayed + beta)

      return math_ops.multiply(
          initial_learning_rate, noisy_linear_cosine_decayed, name=name)

  def get_config(self):
    """Returns the constructor arguments of this schedule as a dict."""
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "decay_steps": self.decay_steps,
        "initial_variance": self.initial_variance,
        "variance_decay": self.variance_decay,
        "num_periods": self.num_periods,
        "alpha": self.alpha,
        "beta": self.beta,
        "name": self.name
    }
1015
1016
@keras_export("keras.optimizers.schedules.serialize")
def serialize(learning_rate_schedule):
  """Serializes a learning rate schedule.

  Thin wrapper: hands `learning_rate_schedule` to
  `generic_utils.serialize_keras_object` and returns its result unchanged.

  Args:
    learning_rate_schedule: The schedule object to serialize.

  Returns:
    Whatever `generic_utils.serialize_keras_object` produces for the schedule.
  """
  serialized = generic_utils.serialize_keras_object(learning_rate_schedule)
  return serialized
1020
1021
@keras_export("keras.optimizers.schedules.deserialize")
def deserialize(config, custom_objects=None):
  """Rebuilds a learning rate schedule from `config`.

  Thin wrapper around `generic_utils.deserialize_keras_object`. This module's
  globals are passed as `module_objects`, so names defined in this file are
  offered to the lookup.

  Args:
    config: The serialized schedule, forwarded as-is.
    custom_objects: Optional extra objects forwarded to the lookup. Defaults
      to `None`.

  Returns:
    Whatever `generic_utils.deserialize_keras_object` produces for `config`.
  """
  lookup_kwargs = {
      "module_objects": globals(),
      "custom_objects": custom_objects,
      "printable_module_name": "decay",
  }
  return generic_utils.deserialize_keras_object(config, **lookup_kwargs)
1029