1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various learning rate decay functions."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import math
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import ops
25from tensorflow.python.keras.utils import generic_utils
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import random_ops
29from tensorflow.python.util import nest
30from tensorflow.python.util.tf_export import keras_export
31
32
@keras_export("keras.optimizers.schedules.LearningRateSchedule")
class LearningRateSchedule(object):
  """Base class for serializable learning rate schedules.

  A `LearningRateSchedule` maps the current optimizer step to a learning
  rate. Instances can be passed as the learning rate of optimizers in
  `tf.keras.optimizers`, and can be serialized and deserialized with
  `tf.keras.optimizers.schedules.serialize` and
  `tf.keras.optimizers.schedules.deserialize`.

  Subclasses must implement `__call__(step)` and `get_config()`. The default
  `from_config` feeds the `get_config()` dict back into the constructor as
  keyword arguments, so the config keys must match `__init__` parameters.
  """

  # NOTE: this class does not use `abc.ABCMeta` as its metaclass, so the
  # `abc.abstractmethod` markers below are not enforced at instantiation time.
  # A subclass that forgets an override only fails when the method is called,
  # at which point the base implementation raises `NotImplementedError`.

  @abc.abstractmethod
  def __call__(self, step):
    """Returns the learning rate for optimizer step `step`."""
    raise NotImplementedError("Learning rate schedule must override __call__")

  @abc.abstractmethod
  def get_config(self):
    """Returns a dict of constructor arguments that recreates this schedule."""
    raise NotImplementedError("Learning rate schedule must override get_config")

  @classmethod
  def from_config(cls, config):
    """Builds a schedule of type `cls` from a `get_config()` dict.

    Args:
      config: A dict of keyword arguments, typically the output of
        `get_config()`.

    Returns:
      A new `cls` instance constructed as `cls(**config)`.
    """
    schedule = cls(**config)
    return schedule
62
63
@keras_export("keras.optimizers.schedules.ExponentialDecay")
class ExponentialDecay(LearningRateSchedule):
  """A LearningRateSchedule that uses an exponential decay schedule.

  When training a model, it is often useful to lower the learning rate as
  the training progresses. This schedule applies an exponential decay function
  to an optimizer step, given a provided initial learning rate.

  The schedule is a 1-arg callable that produces a decayed learning
  rate when passed the current optimizer step. This can be useful for changing
  the learning rate value across different invocations of optimizer functions.
  It is computed as:

  ```python
  def decayed_learning_rate(step):
    return initial_learning_rate * decay_rate ** (step / decay_steps)
  ```

  If the argument `staircase` is `True`, then `step / decay_steps` is
  rounded down to an integer (floor), and the decayed learning rate follows a
  staircase function.

  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
  as the learning rate.
  Example: When fitting a Keras model, decay every 100000 steps with a base
  of 0.96:

  ```python
  initial_learning_rate = 0.1
  lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate,
      decay_steps=100000,
      decay_rate=0.96,
      staircase=True)

  model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

  model.fit(data, labels, epochs=5)
  ```

  The learning rate schedule is also serializable and deserializable using
  `tf.keras.optimizers.schedules.serialize` and
  `tf.keras.optimizers.schedules.deserialize`.

  Returns:
    A 1-arg callable learning rate schedule that takes the current optimizer
    step and outputs the decayed learning rate, a scalar `Tensor` of the same
    type as `initial_learning_rate`.
  """

  def __init__(
      self,
      initial_learning_rate,
      decay_steps,
      decay_rate,
      staircase=False,
      name=None):
    """Stores the parameters of the exponential decay.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
        Python number.  The initial learning rate.
      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
        Must be positive.  The number of steps over which the rate is
        multiplied by `decay_rate` once.
      decay_rate: A scalar `float32` or `float64` `Tensor` or a
        Python number.  The decay rate.
      staircase: Boolean.  If `True`, decay the learning rate at discrete
        intervals.
      name: String.  Optional name of the operation.  Defaults to
        'ExponentialDecay'.
    """
    super(ExponentialDecay, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.decay_steps = decay_steps
    self.decay_rate = decay_rate
    self.staircase = staircase
    self.name = name

  def __call__(self, step):
    scope_name = self.name or "ExponentialDecay"
    with ops.name_scope_v2(scope_name) as name:
      lr0 = ops.convert_to_tensor_v2_with_dispatch(
          self.initial_learning_rate, name="initial_learning_rate")
      # Every other operand is cast to the initial learning rate's dtype.
      float_dtype = lr0.dtype
      steps_per_decay = math_ops.cast(self.decay_steps, float_dtype)
      base = math_ops.cast(self.decay_rate, float_dtype)

      exponent = math_ops.cast(step, float_dtype) / steps_per_decay
      if self.staircase:
        # Hold the rate constant within each decay period.
        exponent = math_ops.floor(exponent)
      return math_ops.multiply(lr0, math_ops.pow(base, exponent), name=name)

  def get_config(self):
    return dict(
        initial_learning_rate=self.initial_learning_rate,
        decay_steps=self.decay_steps,
        decay_rate=self.decay_rate,
        staircase=self.staircase,
        name=self.name)
167
168
@keras_export("keras.optimizers.schedules.PiecewiseConstantDecay")
class PiecewiseConstantDecay(LearningRateSchedule):
  """A LearningRateSchedule that uses a piecewise constant decay schedule.

  The function returns a 1-arg callable to compute the piecewise constant
  when passed the current optimizer step. This can be useful for changing the
  learning rate value across different invocations of optimizer functions.

  Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
    for the next 10000 steps, and 0.1 for any additional steps.

  ```python
  step = tf.Variable(0, trainable=False)
  boundaries = [100000, 110000]
  values = [1.0, 0.5, 0.1]
  learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
      boundaries, values)

  # Later, whenever we perform an optimization step, we pass in the step.
  learning_rate = learning_rate_fn(step)
  ```

  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
  as the learning rate. The learning rate schedule is also serializable and
  deserializable using `tf.keras.optimizers.schedules.serialize` and
  `tf.keras.optimizers.schedules.deserialize`.

  Returns:
    A 1-arg callable learning rate schedule that takes the current optimizer
    step and outputs the decayed learning rate, a scalar `Tensor` of the same
    type as the entries of `values`.

    The output of the 1-arg function that takes the `step`
    is `values[0]` when `step <= boundaries[0]`,
    `values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`, ...,
    and values[-1] when `step > boundaries[-1]`. If `boundaries` is empty,
    the output is always `values[0]`.
  """

  def __init__(
      self,
      boundaries,
      values,
      name=None):
    """Piecewise constant from boundaries and interval values.

    Args:
      boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
        increasing entries, and with all elements having the same type as the
        optimizer step. Entries of a different dtype are cast to the step's
        dtype when the schedule is called.
      values: A list of `Tensor`s or `float`s or `int`s that specifies the
        values for the intervals defined by `boundaries`. It should have one
        more element than `boundaries`, and all elements should have the same
        type.
      name: A string. Optional name of the operation. Defaults to
        'PiecewiseConstant'.

    Raises:
      ValueError: if the number of elements in the lists do not match.
    """
    super(PiecewiseConstantDecay, self).__init__()

    if len(boundaries) != len(values) - 1:
      raise ValueError(
          "The length of boundaries should be 1 less than the length of values")

    self.boundaries = boundaries
    self.values = values
    self.name = name

  def __call__(self, step):
    with ops.name_scope_v2(self.name or "PiecewiseConstant"):
      boundaries = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch,
                                      nest.flatten(self.boundaries))
      values = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch,
                                  nest.flatten(self.values))
      x_recomp = ops.convert_to_tensor_v2_with_dispatch(step)

      if not boundaries:
        # `__init__` accepts no boundaries with a single value: there is only
        # one interval, so the schedule is the constant `values[0]`. Without
        # this guard `boundaries[0]` below raises IndexError.
        return values[0]

      # Cast the boundaries to the step's dtype so the comparisons are
      # well-typed.
      step_dtype = x_recomp.dtype.base_dtype
      boundaries = [
          math_ops.cast(b, step_dtype)
          if b.dtype.base_dtype != step_dtype else b
          for b in boundaries
      ]

      pred_fn_pairs = []
      pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
      pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
      for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
        # Bind v as a default argument; a plain closure would late-bind to the
        # last loop value.
        pred = (x_recomp > low) & (x_recomp <= high)
        pred_fn_pairs.append((pred, lambda v=v: v))

      # The default isn't needed here because our conditions are mutually
      # exclusive and exhaustive, but tf.case requires it.
      default = lambda: values[0]
      return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)

  def get_config(self):
    return {
        "boundaries": self.boundaries,
        "values": self.values,
        "name": self.name
    }
269
270
@keras_export("keras.optimizers.schedules.PolynomialDecay")
class PolynomialDecay(LearningRateSchedule):
  """A LearningRateSchedule that uses a polynomial decay schedule.

  It is commonly observed that a monotonically decreasing learning rate, whose
  degree of change is carefully chosen, results in a better performing model.
  This schedule applies a polynomial decay function to an optimizer step,
  given a provided `initial_learning_rate`, to reach an `end_learning_rate`
  in the given `decay_steps`.

  It requires a `step` value to compute the decayed learning rate. You
  can just pass a TensorFlow variable that you increment at each training
  step.

  The schedule is a 1-arg callable that produces a decayed learning rate
  when passed the current optimizer step. This can be useful for changing the
  learning rate value across different invocations of optimizer functions.
  It is computed as:

  ```python
  def decayed_learning_rate(step):
    step = min(step, decay_steps)
    return ((initial_learning_rate - end_learning_rate) *
            (1 - step / decay_steps) ** (power)
           ) + end_learning_rate
  ```

  If `cycle` is True then a multiple of `decay_steps` is used: the smallest
  one that is greater than or equal to `step` (the multiplier is 1 when
  `step` is 0).

  ```python
  def decayed_learning_rate(step):
    decay_steps = decay_steps * max(1, ceil(step / decay_steps))
    return ((initial_learning_rate - end_learning_rate) *
            (1 - step / decay_steps) ** (power)
           ) + end_learning_rate
  ```

  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
  as the learning rate.
  Example: Fit a model while decaying from 0.1 to 0.01 in 10000 steps using
  sqrt (i.e. power=0.5):

  ```python
  ...
  starter_learning_rate = 0.1
  end_learning_rate = 0.01
  decay_steps = 10000
  learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
      starter_learning_rate,
      decay_steps,
      end_learning_rate,
      power=0.5)

  model.compile(optimizer=tf.keras.optimizers.SGD(
                    learning_rate=learning_rate_fn),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

  model.fit(data, labels, epochs=5)
  ```

  The learning rate schedule is also serializable and deserializable using
  `tf.keras.optimizers.schedules.serialize` and
  `tf.keras.optimizers.schedules.deserialize`.

  Returns:
    A 1-arg callable learning rate schedule that takes the current optimizer
    step and outputs the decayed learning rate, a scalar `Tensor` of the same
    type as `initial_learning_rate`.
  """

  def __init__(
      self,
      initial_learning_rate,
      decay_steps,
      end_learning_rate=0.0001,
      power=1.0,
      cycle=False,
      name=None):
    """Applies a polynomial decay to the learning rate.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
        Python number.  The initial learning rate.
      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
        Must be positive.  See the decay computation above.
      end_learning_rate: A scalar `float32` or `float64` `Tensor` or a
        Python number.  The minimal end learning rate.
      power: A scalar `float32` or `float64` `Tensor` or a
        Python number.  The power of the polynomial. Defaults to linear, 1.0.
      cycle: A boolean, whether or not it should cycle beyond decay_steps.
      name: String.  Optional name of the operation. Defaults to
        'PolynomialDecay'.
    """
    super(PolynomialDecay, self).__init__()

    self.initial_learning_rate = initial_learning_rate
    self.decay_steps = decay_steps
    self.end_learning_rate = end_learning_rate
    self.power = power
    self.cycle = cycle
    self.name = name

  def __call__(self, step):
    with ops.name_scope_v2(self.name or "PolynomialDecay") as name:
      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
          self.initial_learning_rate, name="initial_learning_rate")
      dtype = initial_learning_rate.dtype
      end_learning_rate = math_ops.cast(self.end_learning_rate, dtype)
      power = math_ops.cast(self.power, dtype)

      global_step_recomp = math_ops.cast(step, dtype)
      decay_steps_recomp = math_ops.cast(self.decay_steps, dtype)
      if self.cycle:
        # Stretch the decay period to the smallest multiple of decay_steps
        # that is >= global_step; at step 0 the multiplier is 1 so the period
        # never collapses to zero.
        # Both branches must yield `dtype` tensors: a bare Python 1.0 becomes
        # float32 and breaks graph-mode `cond` for float64 schedules, and
        # dividing by the raw `self.decay_steps` fails for integer Tensors.
        multiplier = control_flow_ops.cond(
            math_ops.equal(global_step_recomp, 0),
            lambda: constant_op.constant(1.0, dtype=dtype),
            lambda: math_ops.ceil(global_step_recomp / decay_steps_recomp))
        decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier)
      else:
        # Make sure that the global_step used is not bigger than decay_steps.
        global_step_recomp = math_ops.minimum(global_step_recomp,
                                              decay_steps_recomp)

      p = math_ops.divide(global_step_recomp, decay_steps_recomp)
      return math_ops.add(
          math_ops.multiply(initial_learning_rate - end_learning_rate,
                            math_ops.pow(1 - p, power)),
          end_learning_rate,
          name=name)

  def get_config(self):
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "decay_steps": self.decay_steps,
        "end_learning_rate": self.end_learning_rate,
        "power": self.power,
        "cycle": self.cycle,
        "name": self.name
    }
413
414
@keras_export("keras.optimizers.schedules.InverseTimeDecay")
class InverseTimeDecay(LearningRateSchedule):
  """A LearningRateSchedule that uses an inverse time decay schedule.

  When training a model, it is often useful to lower the learning rate as
  the training progresses. This schedule applies the inverse decay function
  to an optimizer step, given a provided initial learning rate.
  It requires a `step` value to compute the decayed learning rate. You can
  just pass a TensorFlow variable that you increment at each training step.

  The schedule is a 1-arg callable that produces a decayed learning
  rate when passed the current optimizer step. This can be useful for changing
  the learning rate value across different invocations of optimizer functions.
  It is computed as:

  ```python
  def decayed_learning_rate(step):
    return initial_learning_rate / (1 + decay_rate * step / decay_steps)
  ```

  or, if `staircase` is `True`, as:

  ```python
  def decayed_learning_rate(step):
    return initial_learning_rate / (1 + decay_rate * floor(step / decay_steps))
  ```

  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
  as the learning rate.
  Example: Fit a Keras model when decaying 1/t with a rate of 0.5:

  ```python
  ...
  initial_learning_rate = 0.1
  decay_steps = 1.0
  decay_rate = 0.5
  learning_rate_fn = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate, decay_steps, decay_rate)

  model.compile(optimizer=tf.keras.optimizers.SGD(
                    learning_rate=learning_rate_fn),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

  model.fit(data, labels, epochs=5)
  ```

  Returns:
    A 1-arg callable learning rate schedule that takes the current optimizer
    step and outputs the decayed learning rate, a scalar `Tensor` of the same
    type as `initial_learning_rate`.
  """

  def __init__(
      self,
      initial_learning_rate,
      decay_steps,
      decay_rate,
      staircase=False,
      name=None):
    """Stores the parameters of the inverse time decay.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a
        Python number.  The initial learning rate.
      decay_steps: A scalar `Tensor` or a Python number.  How often to apply
        decay; `step` is divided by this value.
      decay_rate: A scalar `Tensor` or a Python number.  The decay rate.
      staircase: Whether to apply decay in a discrete staircase, as opposed to
        continuous, fashion.
      name: String.  Optional name of the operation.  Defaults to
        'InverseTimeDecay'.
    """
    super(InverseTimeDecay, self).__init__()

    self.initial_learning_rate = initial_learning_rate
    self.decay_steps = decay_steps
    self.decay_rate = decay_rate
    self.staircase = staircase
    self.name = name

  def __call__(self, step):
    with ops.name_scope_v2(self.name or "InverseTimeDecay") as name:
      lr0 = ops.convert_to_tensor_v2_with_dispatch(
          self.initial_learning_rate, name="initial_learning_rate")
      # All operands share the initial learning rate's dtype.
      float_dtype = lr0.dtype
      period = math_ops.cast(self.decay_steps, float_dtype)
      rate = math_ops.cast(self.decay_rate, float_dtype)

      elapsed = math_ops.cast(step, float_dtype) / period
      if self.staircase:
        # Count only completed decay periods.
        elapsed = math_ops.floor(elapsed)
      one = math_ops.cast(constant_op.constant(1), float_dtype)
      denominator = math_ops.add(one, math_ops.multiply(rate, elapsed))
      return math_ops.divide(lr0, denominator, name=name)

  def get_config(self):
    return dict(
        initial_learning_rate=self.initial_learning_rate,
        decay_steps=self.decay_steps,
        decay_rate=self.decay_rate,
        staircase=self.staircase,
        name=self.name)
519
520
@keras_export("keras.optimizers.schedules.CosineDecay",
              "keras.experimental.CosineDecay")
class CosineDecay(LearningRateSchedule):
  """A LearningRateSchedule that uses a cosine decay schedule.

  See [Loshchilov & Hutter, ICLR2017](https://arxiv.org/abs/1608.03983),
  SGDR: Stochastic Gradient Descent with Warm Restarts.

  When training a model, it is often useful to lower the learning rate as
  the training progresses. This schedule applies a cosine decay function
  to an optimizer step, given a provided initial learning rate.
  It requires a `step` value to compute the decayed learning rate. You can
  just pass a TensorFlow variable that you increment at each training step.

  The schedule is a 1-arg callable that produces a decayed learning
  rate when passed the current optimizer step. This can be useful for changing
  the learning rate value across different invocations of optimizer functions.
  It is computed as:

  ```python
  def decayed_learning_rate(step):
    step = min(step, decay_steps)
    cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps))
    decayed = (1 - alpha) * cosine_decay + alpha
    return initial_learning_rate * decayed
  ```

  Example usage:
  ```python
  decay_steps = 1000
  lr_decayed_fn = tf.keras.optimizers.schedules.CosineDecay(
      initial_learning_rate, decay_steps)
  ```

  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
  as the learning rate. The learning rate schedule is also serializable and
  deserializable using `tf.keras.optimizers.schedules.serialize` and
  `tf.keras.optimizers.schedules.deserialize`.

  Returns:
    A 1-arg callable learning rate schedule that takes the current optimizer
    step and outputs the decayed learning rate, a scalar `Tensor` of the same
    type as `initial_learning_rate`.
  """

  def __init__(
      self,
      initial_learning_rate,
      decay_steps,
      alpha=0.0,
      name=None):
    """Applies cosine decay to the learning rate.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` Tensor or a
        Python number. The initial learning rate.
      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
        Number of steps to decay over.
      alpha: A scalar `float32` or `float64` Tensor or a Python number.
        Minimum learning rate value as a fraction of initial_learning_rate.
      name: String. Optional name of the operation.  Defaults to 'CosineDecay'.
    """
    super(CosineDecay, self).__init__()

    self.initial_learning_rate = initial_learning_rate
    self.decay_steps = decay_steps
    self.alpha = alpha
    self.name = name

  def __call__(self, step):
    with ops.name_scope_v2(self.name or "CosineDecay"):
      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
          self.initial_learning_rate, name="initial_learning_rate")
      dtype = initial_learning_rate.dtype
      decay_steps = math_ops.cast(self.decay_steps, dtype)
      # Cast alpha as well so a Tensor alpha of another float dtype works.
      alpha = math_ops.cast(self.alpha, dtype)

      global_step_recomp = math_ops.cast(step, dtype)
      # Past decay_steps the rate stays at its floor of alpha * initial rate.
      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
      completed_fraction = global_step_recomp / decay_steps
      # Build pi in `dtype`: a default (float32) constant cannot be multiplied
      # by a float64 `completed_fraction`.
      cosine_decayed = 0.5 * (1.0 + math_ops.cos(
          constant_op.constant(math.pi, dtype=dtype) * completed_fraction))

      decayed = (1 - alpha) * cosine_decayed + alpha
      return math_ops.multiply(initial_learning_rate, decayed)

  def get_config(self):
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "decay_steps": self.decay_steps,
        "alpha": self.alpha,
        "name": self.name
    }
613
614
@keras_export("keras.optimizers.schedules.CosineDecayRestarts",
              "keras.experimental.CosineDecayRestarts")
class CosineDecayRestarts(LearningRateSchedule):
  """A LearningRateSchedule that uses a cosine decay schedule with restarts.

  See [Loshchilov & Hutter, ICLR2017](https://arxiv.org/abs/1608.03983),
  SGDR: Stochastic Gradient Descent with Warm Restarts.

  When training a model, it is often useful to lower the learning rate as
  the training progresses. This schedule applies a cosine decay function with
  restarts to an optimizer step, given a provided initial learning rate.
  It requires a `step` value to compute the decayed learning rate. You can
  just pass a TensorFlow variable that you increment at each training step.

  The schedule is a 1-arg callable that produces a decayed learning
  rate when passed the current optimizer step. This can be useful for changing
  the learning rate value across different invocations of optimizer functions.

  The learning rate multiplier first decays
  from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
  restart is performed. Each new warm restart runs for `t_mul` times more
  steps and with `m_mul` times smaller initial learning rate.

  Example usage:
  ```python
  first_decay_steps = 1000
  lr_decayed_fn = (
    tf.keras.optimizers.schedules.CosineDecayRestarts(
        initial_learning_rate,
        first_decay_steps))
  ```

  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
  as the learning rate. The learning rate schedule is also serializable and
  deserializable using `tf.keras.optimizers.schedules.serialize` and
  `tf.keras.optimizers.schedules.deserialize`.

  Returns:
    A 1-arg callable learning rate schedule that takes the current optimizer
    step and outputs the decayed learning rate, a scalar `Tensor` of the same
    type as `initial_learning_rate`.
  """

  def __init__(
      self,
      initial_learning_rate,
      first_decay_steps,
      t_mul=2.0,
      m_mul=1.0,
      alpha=0.0,
      name=None):
    """Applies cosine decay with restarts to the learning rate.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
        number. The initial learning rate.
      first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python
        number. Number of steps to decay over.
      t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
        Used to derive the number of iterations in the i-th period.
      m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
        Used to derive the initial learning rate of the i-th period.
      alpha: A scalar `float32` or `float64` Tensor or a Python number.
        Minimum learning rate value as a fraction of the initial_learning_rate.
      name: String. Optional name of the operation.  Defaults to 'SGDRDecay'.
    """
    super(CosineDecayRestarts, self).__init__()

    self.initial_learning_rate = initial_learning_rate
    self.first_decay_steps = first_decay_steps
    self._t_mul = t_mul
    self._m_mul = m_mul
    self.alpha = alpha
    self.name = name

  def __call__(self, step):
    with ops.name_scope_v2(self.name or "SGDRDecay") as name:
      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
          self.initial_learning_rate, name="initial_learning_rate")
      dtype = initial_learning_rate.dtype
      first_decay_steps = math_ops.cast(self.first_decay_steps, dtype)
      alpha = math_ops.cast(self.alpha, dtype)
      t_mul = math_ops.cast(self._t_mul, dtype)
      m_mul = math_ops.cast(self._m_mul, dtype)

      global_step_recomp = math_ops.cast(step, dtype)
      completed_fraction = global_step_recomp / first_decay_steps

      def compute_step(completed_fraction, geometric=False):
        """Returns (restart index, fraction completed within that period)."""
        if geometric:
          # Period i lasts t_mul**i first-periods; invert the geometric sum
          # (1 - t_mul**i) / (1 - t_mul) to find the current period index.
          i_restart = math_ops.floor(
              math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
              math_ops.log(t_mul))

          sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
          completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart

        else:
          # All periods have equal length when t_mul == 1.
          i_restart = math_ops.floor(completed_fraction)
          completed_fraction -= i_restart

        return i_restart, completed_fraction

      i_restart, completed_fraction = control_flow_ops.cond(
          math_ops.equal(t_mul, 1.0),
          lambda: compute_step(completed_fraction, geometric=False),
          lambda: compute_step(completed_fraction, geometric=True))

      m_fac = m_mul**i_restart
      # Build pi in `dtype`: a default (float32) constant cannot be multiplied
      # by a float64 `completed_fraction`.
      cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
          constant_op.constant(math.pi, dtype=dtype) * completed_fraction))
      decayed = (1 - alpha) * cosine_decayed + alpha

      return math_ops.multiply(initial_learning_rate, decayed, name=name)

  def get_config(self):
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "first_decay_steps": self.first_decay_steps,
        "t_mul": self._t_mul,
        "m_mul": self._m_mul,
        "alpha": self.alpha,
        "name": self.name
    }
740
741
742# Note: this code is still used by V1 APIs.
class LinearCosineDecay(LearningRateSchedule):
  """A LearningRateSchedule that uses a linear cosine decay schedule.

  See [Bello et al., ICML2017] Neural Optimizer Search with RL.
  https://arxiv.org/abs/1709.07417

  For the idea of warm starts here controlled by `num_periods`,
  see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
  with Warm Restarts. https://arxiv.org/abs/1608.03983

  Note that linear cosine decay is more aggressive than cosine decay and
  larger initial learning rates can typically be used.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This schedule applies a linear cosine decay
  function to an optimizer step, given a provided initial learning rate.
  It requires a `step` value to compute the decayed learning rate. You can
  just pass a TensorFlow variable that you increment at each training step.

  The schedule is a 1-arg callable that produces a decayed learning
  rate when passed the current optimizer step. This can be useful for changing
  the learning rate value across different invocations of optimizer functions.
  It is computed as:

  ```python
  def decayed_learning_rate(step):
    step = min(step, decay_steps)
    linear_decay = (decay_steps - step) / decay_steps
    cosine_decay = 0.5 * (
        1 + cos(pi * 2 * num_periods * step / decay_steps))
    decayed = (alpha + linear_decay) * cosine_decay + beta
    return initial_learning_rate * decayed
  ```

  Example usage:
  ```python
  decay_steps = 1000
  lr_decayed_fn = (
    tf.keras.experimental.LinearCosineDecay(
      initial_learning_rate, decay_steps))
  ```

  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
  as the learning rate. The learning rate schedule is also serializable and
  deserializable using `tf.keras.optimizers.schedules.serialize` and
  `tf.keras.optimizers.schedules.deserialize`.

  Returns:
    A 1-arg callable learning rate schedule that takes the current optimizer
    step and outputs the decayed learning rate, a scalar `Tensor` of the same
    type as `initial_learning_rate`.
  """

  def __init__(
      self,
      initial_learning_rate,
      decay_steps,
      num_periods=0.5,
      alpha=0.0,
      beta=0.001,
      name=None):
    """Applies linear cosine decay to the learning rate.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
        number. The initial learning rate.
      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
        Number of steps to decay over.
      num_periods: Number of periods in the cosine part of the decay.
        See computation above.
      alpha: See computation above.
      beta: See computation above.
      name: String.  Optional name of the operation.  Defaults to
        'LinearCosineDecay'.
    """
    super(LinearCosineDecay, self).__init__()

    self.initial_learning_rate = initial_learning_rate
    self.decay_steps = decay_steps
    self.num_periods = num_periods
    self.alpha = alpha
    self.beta = beta
    self.name = name

  def __call__(self, step):
    """Computes the decayed learning rate for `step`.

    Args:
      step: The current optimizer step (a `Tensor` or Python number). It is
        cast to the dtype of `initial_learning_rate` and clipped to
        `decay_steps`.

    Returns:
      A `Tensor` with the same dtype as `initial_learning_rate`.
    """
    with ops.name_scope_v2(self.name or "LinearCosineDecay") as name:
      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
          self.initial_learning_rate, name="initial_learning_rate")
      # Every other hyperparameter is computed in the learning rate's dtype.
      dtype = initial_learning_rate.dtype
      decay_steps = math_ops.cast(self.decay_steps, dtype)
      num_periods = math_ops.cast(self.num_periods, dtype)
      alpha = math_ops.cast(self.alpha, dtype)
      beta = math_ops.cast(self.beta, dtype)

      # After `decay_steps` the schedule holds its final value.
      global_step_recomp = math_ops.cast(step, dtype)
      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
      linear_decayed = (decay_steps - global_step_recomp) / decay_steps
      completed_fraction = global_step_recomp / decay_steps
      fraction = 2.0 * num_periods * completed_fraction
      # Build pi in `dtype`: a bare `constant(math.pi)` is float32 and cannot
      # be multiplied with a float64 `fraction`.
      pi = constant_op.constant(math.pi, dtype=dtype)
      cosine_decayed = 0.5 * (1.0 + math_ops.cos(pi * fraction))

      linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
      return math_ops.multiply(initial_learning_rate, linear_cosine_decayed,
                               name=name)

  def get_config(self):
    """Returns the constructor arguments of this schedule as a dict."""
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "decay_steps": self.decay_steps,
        "num_periods": self.num_periods,
        "alpha": self.alpha,
        "beta": self.beta,
        "name": self.name
    }
858
859
860# Note: this code is still used by V1 APIs.
class NoisyLinearCosineDecay(LearningRateSchedule):
  """A LearningRateSchedule that uses a noisy linear cosine decay schedule.

  See [Bello et al., ICML2017] Neural Optimizer Search with RL.
  https://arxiv.org/abs/1709.07417

  For the idea of warm starts here controlled by `num_periods`,
  see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
  with Warm Restarts. https://arxiv.org/abs/1608.03983

  Note that linear cosine decay is more aggressive than cosine decay and
  larger initial learning rates can typically be used.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This schedule applies a noisy linear cosine decay
  function to an optimizer step, given a provided initial learning rate.
  It requires a `step` value to compute the decayed learning rate. You can
  just pass a TensorFlow variable that you increment at each training step.

  The schedule is a 1-arg callable that produces a decayed learning
  rate when passed the current optimizer step. This can be useful for changing
  the learning rate value across different invocations of optimizer functions.
  It is computed as:

  ```python
  def decayed_learning_rate(step):
    step = min(step, decay_steps)
    linear_decay = (decay_steps - step) / decay_steps
    cosine_decay = 0.5 * (
        1 + cos(pi * 2 * num_periods * step / decay_steps))
    decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta
    return initial_learning_rate * decayed
  ```
  where eps_t is 0-centered gaussian noise with variance
  initial_variance / (1 + step) ** variance_decay, using the clipped `step`.

  Example usage:
  ```python
  decay_steps = 1000
  lr_decayed_fn = (
    tf.keras.experimental.NoisyLinearCosineDecay(
      initial_learning_rate, decay_steps))
  ```

  You can pass this schedule directly into a `tf.keras.optimizers.Optimizer`
  as the learning rate. The learning rate schedule is also serializable and
  deserializable using `tf.keras.optimizers.schedules.serialize` and
  `tf.keras.optimizers.schedules.deserialize`.

  Returns:
    A 1-arg callable learning rate schedule that takes the current optimizer
    step and outputs the decayed learning rate, a scalar `Tensor` of the same
    type as `initial_learning_rate`.
  """

  def __init__(
      self,
      initial_learning_rate,
      decay_steps,
      initial_variance=1.0,
      variance_decay=0.55,
      num_periods=0.5,
      alpha=0.0,
      beta=0.001,
      name=None):
    """Applies noisy linear cosine decay to the learning rate.

    Args:
      initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python
        number. The initial learning rate.
      decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
        Number of steps to decay over.
      initial_variance: initial variance for the noise. See computation above.
      variance_decay: decay for the noise's variance. See computation above.
      num_periods: Number of periods in the cosine part of the decay.
        See computation above.
      alpha: See computation above.
      beta: See computation above.
      name: String.  Optional name of the operation.  Defaults to
        'NoisyLinearCosineDecay'.
    """
    super(NoisyLinearCosineDecay, self).__init__()

    self.initial_learning_rate = initial_learning_rate
    self.decay_steps = decay_steps
    self.initial_variance = initial_variance
    self.variance_decay = variance_decay
    self.num_periods = num_periods
    self.alpha = alpha
    self.beta = beta
    self.name = name

  def __call__(self, step):
    """Computes the (randomly perturbed) decayed learning rate for `step`.

    Args:
      step: The current optimizer step (a `Tensor` or Python number). It is
        cast to the dtype of `initial_learning_rate` and clipped to
        `decay_steps`.

    Returns:
      A `Tensor` with the same dtype as `initial_learning_rate`. A fresh
      noise sample is drawn via `random_ops.random_normal` on every call.
    """
    with ops.name_scope_v2(self.name or "NoisyLinearCosineDecay") as name:
      initial_learning_rate = ops.convert_to_tensor_v2_with_dispatch(
          self.initial_learning_rate, name="initial_learning_rate")
      # Every other hyperparameter is computed in the learning rate's dtype.
      dtype = initial_learning_rate.dtype
      decay_steps = math_ops.cast(self.decay_steps, dtype)
      initial_variance = math_ops.cast(self.initial_variance, dtype)
      variance_decay = math_ops.cast(self.variance_decay, dtype)
      num_periods = math_ops.cast(self.num_periods, dtype)
      alpha = math_ops.cast(self.alpha, dtype)
      beta = math_ops.cast(self.beta, dtype)

      # After `decay_steps` the schedule (and noise variance) stop changing.
      global_step_recomp = math_ops.cast(step, dtype)
      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
      linear_decayed = (decay_steps - global_step_recomp) / decay_steps
      variance = initial_variance / (
          math_ops.pow(1.0 + global_step_recomp, variance_decay))
      std = math_ops.sqrt(variance)
      # Sample the noise in `dtype`: random_normal defaults to float32, which
      # breaks both the `stddev` conversion and the addition for float64.
      noise = random_ops.random_normal(
          linear_decayed.shape, stddev=std, dtype=dtype)
      noisy_linear_decayed = linear_decayed + noise

      completed_fraction = global_step_recomp / decay_steps
      fraction = 2.0 * num_periods * completed_fraction
      # Build pi in `dtype` for the same float32/float64 reason as above.
      pi = constant_op.constant(math.pi, dtype=dtype)
      cosine_decayed = 0.5 * (1.0 + math_ops.cos(pi * fraction))
      noisy_linear_cosine_decayed = (
          (alpha + noisy_linear_decayed) * cosine_decayed + beta)

      return math_ops.multiply(
          initial_learning_rate, noisy_linear_cosine_decayed, name=name)

  def get_config(self):
    """Returns the constructor arguments of this schedule as a dict."""
    return {
        "initial_learning_rate": self.initial_learning_rate,
        "decay_steps": self.decay_steps,
        "initial_variance": self.initial_variance,
        "variance_decay": self.variance_decay,
        "num_periods": self.num_periods,
        "alpha": self.alpha,
        "beta": self.beta,
        "name": self.name
    }
996
997
@keras_export("keras.optimizers.schedules.serialize")
def serialize(learning_rate_schedule):
  """Serializes a learning rate schedule.

  Args:
    learning_rate_schedule: The schedule object to serialize.

  Returns:
    Whatever `generic_utils.serialize_keras_object` produces for the schedule.
  """
  serialized_schedule = generic_utils.serialize_keras_object(
      learning_rate_schedule)
  return serialized_schedule
1001
1002
@keras_export("keras.optimizers.schedules.deserialize")
def deserialize(config, custom_objects=None):
  """Reconstructs a learning rate schedule from its serialized config.

  Args:
    config: The serialized representation, as produced by `serialize`.
    custom_objects: Optional mapping of extra names to objects, forwarded to
      `generic_utils.deserialize_keras_object`.

  Returns:
    The object returned by `generic_utils.deserialize_keras_object`.
  """
  # This module's namespace is offered as the pool of known objects, so the
  # schedule classes defined here can be looked up.
  known_objects = globals()
  return generic_utils.deserialize_keras_object(
      config,
      module_objects=known_objects,
      custom_objects=custom_objects,
      printable_module_name="decay",
  )
1010