# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ftrl-proximal optimizer for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.optimizers.Ftrl')
class Ftrl(optimizer_v2.OptimizerV2):
  r"""Optimizer that implements the FTRL algorithm.

  See Algorithm 1 of this [paper](
  https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
  This version supports both online L2 (the L2 penalty given in the paper
  above) and shrinkage-type L2 (the addition of an L2 penalty to the loss
  function).

  Initialization:
  $t = 0$
  $n_{0} = 0$
  $\sigma_{0} = 0$
  $z_{0} = 0$

  Update ($i$ is the variable index, $\alpha$ is the learning rate):
  $t = t + 1$
  $n_{t,i} = n_{t-1,i} + g_{t,i}^{2}$
  $\sigma_{t,i} = (\sqrt{n_{t,i}} - \sqrt{n_{t-1,i}}) / \alpha$
  $z_{t,i} = z_{t-1,i} + g_{t,i} - \sigma_{t,i} * w_{t,i}$
  $w_{t,i} = -((\beta + \sqrt{n_{t,i}}) / \alpha + \lambda_{2})^{-1} *
  (z_{t,i} - sgn(z_{t,i}) * \lambda_{1})$ if $|z_{t,i}| > \lambda_{1}$,
  else $w_{t,i} = 0$.

  Check the documentation for the `l2_shrinkage_regularization_strength`
  parameter for more details when shrinkage is enabled, in which case the
  gradient is replaced with a gradient with shrinkage.
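
  A minimal usage sketch (the hyperparameter values below are illustrative
  only, and `model` is assumed to be an already-built `tf.keras.Model`):

  ```python
  opt = tf.keras.optimizers.Ftrl(
      learning_rate=0.05,
      l1_regularization_strength=0.01,
      l2_regularization_strength=0.01)
  model.compile(optimizer=opt, loss='binary_crossentropy')
  ```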
  """

  def __init__(self,
               learning_rate,
               learning_rate_power=-0.5,
               initial_accumulator_value=0.1,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0,
               name='Ftrl',
               l2_shrinkage_regularization_strength=0.0,
               **kwargs):
    r"""Construct a new FTRL optimizer.

    Args:
      learning_rate: A float value or a constant float `Tensor`.
      learning_rate_power: A float value, must be less than or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate.
      initial_accumulator_value: The starting value for accumulators.
        Only zero or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or
        equal to zero.
      l2_regularization_strength: A float value, must be greater than or
        equal to zero.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "Ftrl".
      l2_shrinkage_regularization_strength: A float value, must be greater than
        or equal to zero. This differs from L2 above in that the L2 above is a
        stabilization penalty, whereas this L2 shrinkage is a magnitude
        penalty. The FTRL formulation can be written as:
        w_{t+1} = argmin_w(\hat{g}_{1:t} w + L1*||w||_1 + L2*||w||_2^2), where
        \hat{g} = g + (2*L2_shrinkage*w), and g is the gradient of the loss
        function w.r.t. the weights w.
        Specifically, in the absence of L1 regularization, it is equivalent to
        the following update rule:
          w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t -
                    2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t
        where lr_t is the learning rate at step t.
        When input is sparse, shrinkage will only happen on the active weights.
      **kwargs: Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
        `lr`, `decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
        gradients by value; `decay` is included for backward compatibility to
        allow time-inverse decay of the learning rate; `lr` is included for
        backward compatibility, but using `learning_rate` is recommended
        instead.

    Raises:
      ValueError: If one of the arguments is invalid.

    References:
      See [paper](
      https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
    """
    super(Ftrl, self).__init__(name, **kwargs)

    if initial_accumulator_value < 0.0:
      raise ValueError(
          'initial_accumulator_value %f needs to be positive or zero' %
          initial_accumulator_value)
    if learning_rate_power > 0.0:
      raise ValueError('learning_rate_power %f needs to be negative or zero' %
                       learning_rate_power)
    if l1_regularization_strength < 0.0:
      raise ValueError(
          'l1_regularization_strength %f needs to be positive or zero' %
          l1_regularization_strength)
    if l2_regularization_strength < 0.0:
      raise ValueError(
          'l2_regularization_strength %f needs to be positive or zero' %
          l2_regularization_strength)
    if l2_shrinkage_regularization_strength < 0.0:
      raise ValueError(
          'l2_shrinkage_regularization_strength %f needs to be positive'
          ' or zero' % l2_shrinkage_regularization_strength)

    self._set_hyper('learning_rate', learning_rate)
    self._set_hyper('decay', self._initial_decay)
    self._set_hyper('learning_rate_power', learning_rate_power)
    self._set_hyper('l1_regularization_strength', l1_regularization_strength)
    self._set_hyper('l2_regularization_strength', l2_regularization_strength)
    self._initial_accumulator_value = initial_accumulator_value
    self._l2_shrinkage_regularization_strength = (
        l2_shrinkage_regularization_strength)

  def _create_slots(self, var_list):
    # Create the "accum" and "linear" slots.
    for var in var_list:
      dtype = var.dtype.base_dtype
      init = init_ops.constant_initializer(
          self._initial_accumulator_value, dtype=dtype)
      self.add_slot(var, 'accumulator', init)
      self.add_slot(var, 'linear')
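
  # For reference, a per-coordinate sketch of the update equations from the
  # class docstring, written as illustrative pseudocode (assuming beta = 0,
  # since this constructor exposes no beta parameter; the real work is done by
  # the fused `training_ops` kernels called below). Here `n` is the
  # 'accumulator' slot, `z` the 'linear' slot, `w` the variable, `alpha` the
  # decayed learning rate, and `l1`/`l2` the regularization strengths:
  #
  #   n_new = n + g * g
  #   sigma = (sqrt(n_new) - sqrt(n)) / alpha
  #   z += g - sigma * w
  #   if abs(z) > l1:
  #     w = -(z - sgn(z) * l1) / (sqrt(n_new) / alpha + l2)
  #   else:
  #     w = 0.0
  #   n = n_new
  #
  # When `l2_shrinkage_regularization_strength > 0`, the `_v2` kernels are used
  # and `g` is replaced by the shrinkage gradient described in `__init__`.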
  def _resource_apply_dense(self, grad, var):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    learning_rate_power = self._get_hyper('learning_rate_power', var_dtype)
    l1_regularization_strength = self._get_hyper('l1_regularization_strength',
                                                 var_dtype)
    l2_regularization_strength = self._get_hyper('l2_regularization_strength',
                                                 var_dtype)
    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          lr_t,
          l1_regularization_strength,
          l2_regularization_strength,
          learning_rate_power,
          use_locking=self._use_locking)
    else:
      return training_ops.resource_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          lr_t,
          l1_regularization_strength,
          l2_regularization_strength,
          math_ops.cast(self._l2_shrinkage_regularization_strength, var_dtype),
          learning_rate_power,
          use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    learning_rate_power = self._get_hyper('learning_rate_power', var_dtype)
    l1_regularization_strength = self._get_hyper('l1_regularization_strength',
                                                 var_dtype)
    l2_regularization_strength = self._get_hyper('l2_regularization_strength',
                                                 var_dtype)
    accum = self.get_slot(var, 'accumulator')
    linear = self.get_slot(var, 'linear')
    if self._l2_shrinkage_regularization_strength <= 0.0:
      return training_ops.resource_sparse_apply_ftrl(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          lr_t,
          l1_regularization_strength,
          l2_regularization_strength,
          learning_rate_power,
          use_locking=self._use_locking)
    else:
      return training_ops.resource_sparse_apply_ftrl_v2(
          var.handle,
          accum.handle,
          linear.handle,
          grad,
          indices,
          lr_t,
          l1_regularization_strength,
          l2_regularization_strength,
          math_ops.cast(self._l2_shrinkage_regularization_strength, var_dtype),
          learning_rate_power,
          use_locking=self._use_locking)

  def get_config(self):
    config = super(Ftrl, self).get_config()
    config.update({
        'learning_rate':
            self._serialize_hyperparameter('learning_rate'),
        'decay':
            self._serialize_hyperparameter('decay'),
        'initial_accumulator_value':
            self._initial_accumulator_value,
        'learning_rate_power':
            self._serialize_hyperparameter('learning_rate_power'),
        'l1_regularization_strength':
            self._serialize_hyperparameter('l1_regularization_strength'),
        'l2_regularization_strength':
            self._serialize_hyperparameter('l2_regularization_strength'),
        'l2_shrinkage_regularization_strength':
            self._l2_shrinkage_regularization_strength,
    })
    return config