# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various learning rate decay functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from tensorflow.python.eager import context
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.exponential_decay"])
def exponential_decay(learning_rate,
                      global_step,
                      decay_steps,
                      decay_rate,
                      staircase=False,
                      name=None):
  """Applies exponential decay to the learning rate.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies an exponential decay function
  to a provided initial learning rate. It requires a `global_step` value to
  compute the decayed learning rate. You can just pass a TensorFlow variable
  that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  decayed_learning_rate = learning_rate *
                          decay_rate ^ (global_step / decay_steps)
  ```

  If the argument `staircase` is `True`, then `global_step / decay_steps` is an
  integer division and the decayed learning rate follows a staircase function.

  Example: decay every 100000 steps with a base of 0.96:

  ```python
  ...
  global_step = tf.Variable(0, trainable=False)
  starter_learning_rate = 0.1
  learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
                                             100000, 0.96, staircase=True)
  # Passing global_step to minimize() will increment it at each step.
  learning_step = (
      tf.train.GradientDescentOptimizer(learning_rate)
      .minimize(...my loss..., global_step=global_step)
  )
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` `Tensor` or a
      Python number. The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation. Must not be negative.
    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
      Must be positive. See the decay computation above.
    decay_rate: A scalar `float32` or `float64` `Tensor` or a
      Python number. The decay rate.
    staircase: Boolean. If `True`, decay the learning rate at discrete
      intervals.
    name: String. Optional name of the operation. Defaults to
      'ExponentialDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.ExponentialDecay(learning_rate,
                                                       decay_steps,
                                                       decay_rate,
                                                       staircase=staircase,
                                                       name=name)
  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr
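# A small numeric illustration of the exponential decay above (values are
# illustrative, not defaults): with learning_rate=0.1, decay_rate=0.96 and
# decay_steps=100000,
#   step      0: 0.1 * 0.96 ** (0 / 100000)      = 0.1
#   step 100000: 0.1 * 0.96 ** (100000 / 100000) = 0.096
#   step 200000: 0.1 * 0.96 ** 2                 = 0.09216
# With staircase=True the exponent is floor(global_step / decay_steps), so the
# rate stays at 0.1 for the first 100000 steps and then drops to 0.096.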


@tf_export(v1=["train.piecewise_constant_decay", "train.piecewise_constant"])
def piecewise_constant(x, boundaries, values, name=None):
  """Piecewise constant from boundaries and interval values.

  Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
  for the next 10000 steps, and 0.1 for any additional steps.

  ```python
  global_step = tf.Variable(0, trainable=False)
  boundaries = [100000, 110000]
  values = [1.0, 0.5, 0.1]
  learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

  # Later, whenever we perform an optimization step, we increment global_step.
  ```

  Args:
    x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
      `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
    boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
      increasing entries, and with all elements having the same type as `x`.
    values: A list of `Tensor`s or `float`s or `int`s that specifies the
      values for the intervals defined by `boundaries`. It should have one
      more element than `boundaries`, and all elements should have the same
      type.
    name: A string. Optional name of the operation. Defaults to
      'PiecewiseConstant'.

  Returns:
    A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
    `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
    and `values[-1]` when `x > boundaries[-1]`.

  Raises:
    ValueError: if types of `x` and `boundaries` do not match, or types of all
      `values` do not match or the number of elements in the lists does not
      match.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.PiecewiseConstantDecay(
      boundaries, values, name=name)
  if not context.executing_eagerly():
    decayed_lr = decayed_lr(x)
  else:
    decayed_lr = functools.partial(decayed_lr, x)
  return decayed_lr
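# As an illustration of the boundary semantics above (values are arbitrary),
# with boundaries=[10, 20] and values=[1.0, 0.5, 0.1]:
#   x = 10 -> 1.0  (x <= boundaries[0])
#   x = 11 -> 0.5  (boundaries[0] < x <= boundaries[1])
#   x = 20 -> 0.5
#   x = 21 -> 0.1  (x > boundaries[1])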


@tf_export(v1=["train.polynomial_decay"])
def polynomial_decay(learning_rate,
                     global_step,
                     decay_steps,
                     end_learning_rate=0.0001,
                     power=1.0,
                     cycle=False,
                     name=None):
  """Applies a polynomial decay to the learning rate.

  It is commonly observed that a monotonically decreasing learning rate, whose
  degree of change is carefully chosen, results in a better performing model.
  This function applies a polynomial decay function to a provided initial
  `learning_rate` to reach an `end_learning_rate` in the given `decay_steps`.

  It requires a `global_step` value to compute the decayed learning rate. You
  can just pass a TensorFlow variable that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  global_step = min(global_step, decay_steps)
  decayed_learning_rate = (learning_rate - end_learning_rate) *
                          (1 - global_step / decay_steps) ^ (power) +
                          end_learning_rate
  ```

  If `cycle` is True, then a multiple of `decay_steps` is used, the first one
  that is bigger than `global_step`.

  ```python
  decay_steps = decay_steps * ceil(global_step / decay_steps)
  decayed_learning_rate = (learning_rate - end_learning_rate) *
                          (1 - global_step / decay_steps) ^ (power) +
                          end_learning_rate
  ```

  Example: decay from 0.1 to 0.01 in 10000 steps using sqrt (i.e. power=0.5):

  ```python
  ...
  global_step = tf.Variable(0, trainable=False)
  starter_learning_rate = 0.1
  end_learning_rate = 0.01
  decay_steps = 10000
  learning_rate = tf.train.polynomial_decay(starter_learning_rate, global_step,
                                            decay_steps, end_learning_rate,
                                            power=0.5)
  # Passing global_step to minimize() will increment it at each step.
  learning_step = (
      tf.train.GradientDescentOptimizer(learning_rate)
      .minimize(...my loss..., global_step=global_step)
  )
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` `Tensor` or a
      Python number. The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation. Must not be negative.
    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
      Must be positive. See the decay computation above.
    end_learning_rate: A scalar `float32` or `float64` `Tensor` or a
      Python number. The minimal end learning rate.
    power: A scalar `float32` or `float64` `Tensor` or a
      Python number. The power of the polynomial. Defaults to linear, 1.0.
    cycle: A boolean, whether or not it should cycle beyond decay_steps.
    name: String. Optional name of the operation. Defaults to
      'PolynomialDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.PolynomialDecay(
      learning_rate,
      decay_steps,
      end_learning_rate=end_learning_rate,
      power=power,
      cycle=cycle,
      name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr
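# A worked example of the polynomial decay above (values are illustrative):
# with learning_rate=0.1, end_learning_rate=0.01, decay_steps=10000 and
# power=0.5, at global_step=2500:
#   (0.1 - 0.01) * (1 - 2500 / 10000) ** 0.5 + 0.01
#     = 0.09 * 0.75 ** 0.5 + 0.01
#     ~= 0.0879
# With cycle=True and global_step=15000, decay_steps is first replaced by
# 10000 * ceil(15000 / 10000) = 20000 and the same formula is then applied.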


@tf_export(v1=["train.natural_exp_decay"])
def natural_exp_decay(learning_rate,
                      global_step,
                      decay_steps,
                      decay_rate,
                      staircase=False,
                      name=None):
  """Applies natural exponential decay to the initial learning rate.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies an exponential decay function
  to a provided initial learning rate. It requires a `global_step` value to
  compute the decayed learning rate. You can just pass a TensorFlow variable
  that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  decayed_learning_rate = learning_rate * exp(-decay_rate * global_step /
                                              decay_steps)
  ```

  or, if `staircase` is `True`, as:

  ```python
  decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step /
                                              decay_steps))
  ```

  Example: decay exponentially with a rate of 0.5:

  ```python
  ...
  global_step = tf.Variable(0, trainable=False)
  learning_rate = 0.1
  decay_steps = 5
  k = 0.5
  learning_rate = tf.train.natural_exp_decay(learning_rate, global_step,
                                             decay_steps, k)

  # Passing global_step to minimize() will increment it at each step.
  learning_step = (
      tf.train.GradientDescentOptimizer(learning_rate)
      .minimize(...my loss..., global_step=global_step)
  )
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` `Tensor` or a
      Python number. The initial learning rate.
    global_step: A Python number. Global step to use for the decay
      computation. Must not be negative.
    decay_steps: How often to apply decay.
    decay_rate: A Python number. The decay rate.
    staircase: Whether to apply decay in a discrete staircase, as opposed to
      continuous, fashion.
    name: String. Optional name of the operation. Defaults to
      'ExponentialTimeDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  natural_exp_rate = math_ops.exp(math_ops.negative(decay_rate))
  decayed_lr = learning_rate_schedule.ExponentialDecay(
      learning_rate, decay_steps, natural_exp_rate, staircase=staircase,
      name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr
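# A worked example of the natural exponential decay above (values are
# illustrative): with learning_rate=0.1, decay_rate=0.5 and decay_steps=5,
# at global_step=10:
#   0.1 * exp(-0.5 * 10 / 5) = 0.1 * exp(-1.0) ~= 0.0368
# With staircase=True the exponent uses floor(global_step / decay_steps); at
# global_step=7 that is floor(7 / 5) = 1, i.e. 0.1 * exp(-0.5) ~= 0.0607,
# instead of the continuous 0.1 * exp(-0.7) ~= 0.0497.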


@tf_export(v1=["train.inverse_time_decay"])
def inverse_time_decay(learning_rate,
                       global_step,
                       decay_steps,
                       decay_rate,
                       staircase=False,
                       name=None):
  """Applies inverse time decay to the initial learning rate.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies an inverse decay function
  to a provided initial learning rate. It requires a `global_step` value to
  compute the decayed learning rate. You can just pass a TensorFlow variable
  that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  decayed_learning_rate = learning_rate / (1 + decay_rate * global_step /
                                           decay_steps)
  ```

  or, if `staircase` is `True`, as:

  ```python
  decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step /
                                           decay_steps))
  ```

  Example: decay 1/t with a rate of 0.5:

  ```python
  ...
  global_step = tf.Variable(0, trainable=False)
  learning_rate = 0.1
  decay_steps = 1.0
  decay_rate = 0.5
  learning_rate = tf.train.inverse_time_decay(learning_rate, global_step,
                                              decay_steps, decay_rate)

  # Passing global_step to minimize() will increment it at each step.
  learning_step = (
      tf.train.GradientDescentOptimizer(learning_rate)
      .minimize(...my loss..., global_step=global_step)
  )
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` `Tensor` or a
      Python number. The initial learning rate.
    global_step: A Python number. Global step to use for the decay
      computation. Must not be negative.
    decay_steps: How often to apply decay.
    decay_rate: A Python number. The decay rate.
    staircase: Whether to apply decay in a discrete staircase, as opposed to
      continuous, fashion.
    name: String. Optional name of the operation. Defaults to
      'InverseTimeDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.InverseTimeDecay(
      learning_rate,
      decay_steps,
      decay_rate,
      staircase=staircase,
      name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr
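# A worked example of the inverse time decay above (values are illustrative):
# with learning_rate=0.1, decay_rate=0.5 and decay_steps=1.0, the decayed rate
# after k steps is 0.1 / (1 + 0.5 * k):
#   step 0 -> 0.1, step 1 -> ~0.0667, step 2 -> 0.05, step 10 -> ~0.0167.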


@tf_export(v1=["train.cosine_decay"])
def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0,
                 name=None):
  """Applies cosine decay to the learning rate.

  See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
  with Warm Restarts. https://arxiv.org/abs/1608.03983

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies a cosine decay function
  to a provided initial learning rate. It requires a `global_step` value to
  compute the decayed learning rate. You can just pass a TensorFlow variable
  that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  global_step = min(global_step, decay_steps)
  cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps))
  decayed = (1 - alpha) * cosine_decay + alpha
  decayed_learning_rate = learning_rate * decayed
  ```

  Example usage:

  ```python
  decay_steps = 1000
  lr_decayed = cosine_decay(learning_rate, global_step, decay_steps)
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation.
    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
      Number of steps to decay over.
    alpha: A scalar `float32` or `float64` Tensor or a Python number.
      Minimum learning rate value as a fraction of learning_rate.
    name: String. Optional name of the operation. Defaults to 'CosineDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.CosineDecay(
      learning_rate, decay_steps, alpha=alpha, name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr
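# A worked example of the cosine decay above (values are illustrative): with
# decay_steps=1000 and alpha=0.0, the learning rate multiplier is
#   0.5 * (1 + cos(pi * global_step / 1000))
# which is 1.0 at step 0, 0.5 at step 500, and 0.0 at step 1000, where it
# stays because global_step is clipped to decay_steps. A nonzero alpha keeps
# the multiplier from decaying below alpha.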


@tf_export(v1=["train.cosine_decay_restarts"])
def cosine_decay_restarts(learning_rate,
                          global_step,
                          first_decay_steps,
                          t_mul=2.0,
                          m_mul=1.0,
                          alpha=0.0,
                          name=None):
  """Applies cosine decay with restarts to the learning rate.

  See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
  with Warm Restarts. https://arxiv.org/abs/1608.03983

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies a cosine decay function with
  restarts to a provided initial learning rate. It requires a `global_step`
  value to compute the decayed learning rate. You can just pass a TensorFlow
  variable that you increment at each training step.

  The function returns the decayed learning rate while taking into account
  possible warm restarts. The learning rate multiplier first decays
  from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
  restart is performed. Each new warm restart runs for `t_mul` times more steps
  and with `m_mul` times smaller initial learning rate.

  Example usage:

  ```python
  first_decay_steps = 1000
  lr_decayed = cosine_decay_restarts(learning_rate, global_step,
                                     first_decay_steps)
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation.
    first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python
      number. Number of steps to decay over.
    t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
      Used to derive the number of iterations in the i-th period.
    m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
      Used to derive the initial learning rate of the i-th period.
    alpha: A scalar `float32` or `float64` Tensor or a Python number.
      Minimum learning rate value as a fraction of the learning_rate.
    name: String. Optional name of the operation. Defaults to 'SGDRDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.CosineDecayRestarts(
      learning_rate,
      first_decay_steps,
      t_mul=t_mul,
      m_mul=m_mul,
      alpha=alpha,
      name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr
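# An illustration of the restart schedule above (values are illustrative):
# with first_decay_steps=1000, t_mul=2.0 and m_mul=0.5, the multiplier decays
# from 1.0 towards alpha over the first 1000 steps, restarts at 0.5 and decays
# over the next 2000 steps, restarts at 0.25 over the next 4000 steps, and so
# on.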


@tf_export(v1=["train.linear_cosine_decay"])
def linear_cosine_decay(learning_rate,
                        global_step,
                        decay_steps,
                        num_periods=0.5,
                        alpha=0.0,
                        beta=0.001,
                        name=None):
  """Applies linear cosine decay to the learning rate.

  See [Bello et al., ICML2017] Neural Optimizer Search with RL.
  https://arxiv.org/abs/1709.07417

  For the idea of warm starts here controlled by `num_periods`,
  see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
  with Warm Restarts. https://arxiv.org/abs/1608.03983

  Note that linear cosine decay is more aggressive than cosine decay and
  larger initial learning rates can typically be used.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies a linear cosine decay
  function to a provided initial learning rate. It requires a `global_step`
  value to compute the decayed learning rate. You can just pass a TensorFlow
  variable that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  global_step = min(global_step, decay_steps)
  linear_decay = (decay_steps - global_step) / decay_steps
  cosine_decay = 0.5 * (
      1 + cos(pi * 2 * num_periods * global_step / decay_steps))
  decayed = (alpha + linear_decay) * cosine_decay + beta
  decayed_learning_rate = learning_rate * decayed
  ```

  Example usage:

  ```python
  decay_steps = 1000
  lr_decayed = linear_cosine_decay(learning_rate, global_step, decay_steps)
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation.
    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
      Number of steps to decay over.
    num_periods: Number of periods in the cosine part of the decay.
      See computation above.
    alpha: See computation above.
    beta: See computation above.
    name: String. Optional name of the operation. Defaults to
      'LinearCosineDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.LinearCosineDecay(
      learning_rate,
      decay_steps,
      num_periods=num_periods,
      alpha=alpha,
      beta=beta,
      name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr
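# A worked example of the linear cosine decay above (values are illustrative):
# with decay_steps=1000, num_periods=0.5, alpha=0.0 and beta=0.001, at
# global_step=500:
#   linear_decay = (1000 - 500) / 1000 = 0.5
#   cosine_decay = 0.5 * (1 + cos(pi * 2 * 0.5 * 500 / 1000)) = 0.5
#   decayed      = (0.0 + 0.5) * 0.5 + 0.001 = 0.251
# so the learning rate at that step is 0.251 * learning_rate.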


@tf_export(v1=["train.noisy_linear_cosine_decay"])
def noisy_linear_cosine_decay(learning_rate,
                              global_step,
                              decay_steps,
                              initial_variance=1.0,
                              variance_decay=0.55,
                              num_periods=0.5,
                              alpha=0.0,
                              beta=0.001,
                              name=None):
  """Applies noisy linear cosine decay to the learning rate.

  See [Bello et al., ICML2017] Neural Optimizer Search with RL.
  https://arxiv.org/abs/1709.07417

  For the idea of warm starts here controlled by `num_periods`,
  see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
  with Warm Restarts. https://arxiv.org/abs/1608.03983

  Note that linear cosine decay is more aggressive than cosine decay and
  larger initial learning rates can typically be used.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses. This function applies a noisy linear cosine decay
  function to a provided initial learning rate. It requires a `global_step`
  value to compute the decayed learning rate. You can just pass a TensorFlow
  variable that you increment at each training step.

  The function returns the decayed learning rate. It is computed as:

  ```python
  global_step = min(global_step, decay_steps)
  linear_decay = (decay_steps - global_step) / decay_steps
  cosine_decay = 0.5 * (
      1 + cos(pi * 2 * num_periods * global_step / decay_steps))
  decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta
  decayed_learning_rate = learning_rate * decayed
  ```

  where `eps_t` is 0-centered Gaussian noise with variance
  `initial_variance / (1 + global_step) ** variance_decay`.

  Example usage:

  ```python
  decay_steps = 1000
  lr_decayed = noisy_linear_cosine_decay(
      learning_rate, global_step, decay_steps)
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation.
    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
      Number of steps to decay over.
    initial_variance: initial variance for the noise. See computation above.
    variance_decay: decay for the noise's variance. See computation above.
    num_periods: Number of periods in the cosine part of the decay.
      See computation above.
    alpha: See computation above.
    beta: See computation above.
    name: String. Optional name of the operation. Defaults to
      'NoisyLinearCosineDecay'.

  Returns:
    A scalar `Tensor` of the same type as `learning_rate`. The decayed
    learning rate.

  Raises:
    ValueError: if `global_step` is not supplied.

  @compatibility(eager)
  When eager execution is enabled, this function returns a function which in
  turn returns the decayed learning rate Tensor. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.
  @end_compatibility
  """
  decayed_lr = learning_rate_schedule.NoisyLinearCosineDecay(
      learning_rate,
      decay_steps,
      initial_variance=initial_variance,
      variance_decay=variance_decay,
      num_periods=num_periods,
      alpha=alpha,
      beta=beta,
      name=name)

  if not context.executing_eagerly():
    decayed_lr = decayed_lr(global_step)
  else:
    decayed_lr = functools.partial(decayed_lr, global_step)
  return decayed_lr
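# The noisy variant above follows the same shape as linear_cosine_decay, with
# eps_t added to the linear term. As an illustration (values are arbitrary),
# with initial_variance=1.0 and variance_decay=0.55 the noise standard
# deviation at global_step=1000 is sqrt(1.0 / (1 + 1000) ** 0.55) ~= 0.15, so
# the decayed multiplier fluctuates around the deterministic linear cosine
# value and the fluctuations shrink as training progresses.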