/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/training_ops.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

#if TENSORFLOW_USE_ROCM

#include "rocm/include/hip/hip_complex.h"

#endif  // TENSORFLOW_USE_ROCM

// If any kernels involving complex sqrt/rsqrt are compiled with ROCm, the
// build completes without errors, but the resulting executable is unusable
// (it throws "no device code available for function" errors for completely
// unrelated kernels).
// We also can't cast to hipFloatComplex etc., because (as of 2020-01) HIP
// does not provide sqrt for complex.
// We therefore have no choice but to implement sqrt and rsqrt by hand.
template <typename T>
__device__ T impl_sqrt(T x) {
  return sqrt(x);
}
template <typename T>
__device__ T impl_rsqrt(T x) {
  return rsqrt(x);
}
template <>
__device__ Eigen::half impl_sqrt(Eigen::half x) {
  return __float2half(sqrt(__half2float(x)));
}
template <>
__device__ Eigen::half impl_rsqrt(Eigen::half x) {
  return __float2half(rsqrt(__half2float(x)));
}

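// Principal complex square root via the half-angle identities: for
// z = re + i*im with modulus |z|,
//   sqrt(z) = sqrt((|z| + re) / 2) + i * sign(im) * sqrt((|z| - re) / 2),
// which is what the specialization below computes (root2 is 1/sqrt(2)).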
template <class T>
__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
  T re = x.real(), im = x.imag();
  T mod_x = sqrt(re * re + im * im);
  const T root2 = 0.7071067811865475;
  // We pick the root with the same sign of the imaginary component as
  // the input.
  T root[2] = {T(sqrt(mod_x + re) * root2),
               T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
  // hcc/clang is really weird with its support of complex in device code;
  // for some reason it does not permit a 2-argument constructor
  return *(reinterpret_cast<std::complex<T>*>(&root));
}

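// Three-term Taylor expansion of 1 - sqrt(1 - x) around x = 0
// (x/2 + x^2/8 + x^3/16). Used below to avoid catastrophic cancellation
// when 1 + re*r or 1 - re*r would be computed from two nearly equal values.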
template <class T>
__device__ T rsqrt_helper(T x) {
  return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
}

template <class T>
__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
  T re = x.real(), im = x.imag();
  T r = rsqrt(re * re + im * im);
  T ar2 = re * r * r;
  const T root2 = 0.7071067811865475;
  T root[2];
  // With float, calculating 1 + re*r and 1 - re*r may result in excessive
  // errors due to subtraction of two close values, so we have to get fancy.
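  // Since (re*r)^2 + (im*r)^2 == 1, when re*r is close to -1 (resp. +1) the
  // quantity 1 + re*r (resp. 1 - re*r) equals 1 - sqrt(1 - (im*r)^2), which
  // rsqrt_helper(im*im*r*r) approximates without the cancellation.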
  root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
                          ? rsqrt_helper(im * im * r * r)
                          : max(T(0.0), 1 + re * r))) *
            root2;
  root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
                          ? rsqrt_helper(im * im * r * r)
                          : max(T(0.0), 1 - re * r))) *
            root2 * (im >= 0 ? -1. : 1.);
  return *(reinterpret_cast<std::complex<T>*>(&root));
}

template <typename T>
__device__ T impl_fabs(T x) {
  return fabs(x);
}
template <>
__device__ Eigen::half impl_fabs(Eigen::half x) {
  return __float2half(fabs(__half2float(x)));
}

template <typename T>
__device__ T impl_sign(T x) {
  return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
}

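// Sparse Adagrad: for each sampled row, accum += grad^2 (when update_slots)
// and var -= lr * grad / sqrt(accum), with epsilon added to the denominator
// when has_epsilon is true. Each iteration of the grid-stride loop handles
// one element of the gradient slice.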
template <typename T, typename Tindex, bool has_epsilon>
__global__ __launch_bounds__(1024) void SparseApplyAdagradKernel(
    T* var, T* accum, const T* lr, const T* epsilon, const T* grad,
    const Tindex* indices, Tindex param_rows, Tindex updates_size,
    Tindex indices_size, bool update_slots) {
  Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    Tindex indices_row = grad_index / col_size;
    Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T grad_i = grad[grad_index];
    const T lr_t = *lr;
    const T epsilon_t = *epsilon;

    if (update_slots) {
      accum_i += grad_i * grad_i;
    }
    if (has_epsilon) {
      var_i -= lr_t * grad_i / (sqrt(accum_i) + epsilon_t);
    } else {
      var_i -= lr_t * grad_i * impl_rsqrt(accum_i);
    }

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
  }
}

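// Sparse proximal Adagrad: accum += grad^2, then with the adaptive rate
// adaptive_lr = lr / sqrt(accum), take the step v = var - adaptive_lr * grad
// and apply the proximal operator for the L1/L2 regularizers:
//   var = sign(v) * max(|v| - adaptive_lr * max(l1, 0), 0)
//         / (1 + l2 * adaptive_lr).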
template <typename T, typename Tindex>
__global__ __launch_bounds__(1024) void SparseApplyProximalAdagradKernel(
    T* var, T* accum, const T* lr, const T* l1, const T* l2, const T* grad,
    const Tindex* indices, Tindex param_rows, Tindex updates_size,
    Tindex indices_size) {
  Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    Tindex indices_row = grad_index / col_size;
    Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T grad_i = grad[grad_index];
    const T lr_t = *lr;
    const T l1_t = *l1;
    const T l2_t = *l2;

    accum_i += grad_i * grad_i;
    T learning_rate = lr_t * impl_rsqrt(accum_i);
    // compute v = w - lr * grad.
    T prox_var_i = var_i - grad_i * learning_rate;
    // compute sign(v) * max(|v| - lr * max(l1, 0), 0)
    var_i = (prox_var_i >= 0 ? T(1.) : T(-1.)) *
            max(abs(prox_var_i) - learning_rate * max(l1_t, T(0)), T(0)) /
            (T(1.) + l2_t * learning_rate);

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
  }
}

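// Sparse FTRL-proximal update. The gradient optionally includes an L2
// shrinkage term; linear and accum are then updated as in the dense FTRL op,
// with a sqrt fast path when lr_power == -0.5. When multiply_linear_by_lr is
// true the linear slot stores lr-scaled values, so the l1 threshold and the
// denominator are scaled accordingly instead of dividing by lr.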
template <typename T, typename Tindex, bool has_l2_shrinkage>
__global__ void SparseApplyFtrlKernel(T* var, T* accum, T* linear, const T* lr,
                                      const T* l1, const T* l2,
                                      const T* l2_shrinkage, const T* lr_power,
                                      const T* grad, const Tindex* indices,
                                      Tindex param_rows, Tindex updates_size,
                                      Tindex indices_size,
                                      bool multiply_linear_by_lr) {
  const Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    const Tindex indices_row = grad_index / col_size;
    const Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    const Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T linear_i = linear[param_index];
    const T grad_i = grad[grad_index];
    const T lr_t = *lr;
    const T l1_t = *l1;
    const T l2_t = *l2;
    const T lr_power_t = *lr_power;

    const T grad_shr_i =
        has_l2_shrinkage ? grad_i + static_cast<T>(2) * (*l2_shrinkage) * var_i
                         : grad_i;
    const T new_accum_i = accum_i + grad_i * grad_i;
    const bool lr_power_is_neg_half = lr_power_t == static_cast<T>(-0.5);
    const T pow_new_accum = lr_power_is_neg_half
                                ? sqrt(new_accum_i)
                                : pow(new_accum_i, -lr_power_t);
    const T pow_accum =
        lr_power_is_neg_half ? sqrt(accum_i) : pow(accum_i, -lr_power_t);
    T linear_change = grad_shr_i * lr_t - (pow_new_accum - pow_accum) * var_i;
    if (!multiply_linear_by_lr) {
      linear_change /= lr_t;
    }
    linear_i += linear_change;

    T l1_mult = l1_t;
    if (multiply_linear_by_lr) {
      l1_mult *= lr_t;
    }
    const T l1_reg_adjust = max(min(linear_i, l1_mult), -l1_mult);
    const T x = l1_reg_adjust - linear_i;
    T y = pow_new_accum + static_cast<T>(2) * l2_t * lr_t;
    if (!multiply_linear_by_lr) {
      y /= lr_t;
    }
    var_i = x / y;
    accum_i = new_accum_i;

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
    linear[param_index] = linear_i;
  }
}

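// Dense Adam update. The bias corrections are folded into a single
// mul_factor = lr * sqrt(1 - beta2^t) / (1 - beta1^t); each thread then
// updates m, v and var for its elements in a grid-stride loop, optionally
// using the Nesterov variant of the step.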
template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdamKernel(
    int32 data_dim, T* var, T* m, T* v, const T* const beta1_power_,
    const T* const beta2_power_, const T* const lr_, const T* const beta1_,
    const T* const beta2_, const T* const epsilon_, const T* grad,
    bool use_nesterov) {
  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);
  eigen_assert(gridDim.y == 1);
  eigen_assert(gridDim.z == 1);

  const T mul_factor = (*lr_) * sqrt(static_cast<T>(1.0) - (*beta2_power_)) /
                       (static_cast<T>(1.0) - (*beta1_power_));
  const T epsilon = (*epsilon_);
  const T beta1 = (*beta1_);
  const T one_minus_beta1 = static_cast<T>(1.0) - (beta1);
  const T one_minus_beta2 = static_cast<T>(1.0) - (*beta2_);
  const int32 stripe = gridDim.x * blockDim.x;

  for (int32 i = blockIdx.x * blockDim.x + threadIdx.x; i < data_dim;
       i += stripe) {
    auto m_i = m[i];
    auto g_i = grad[i];
    auto v_i = v[i];

    m_i += one_minus_beta1 * (g_i - m_i);
    v_i += one_minus_beta2 * (g_i * g_i - v_i);
    if (use_nesterov) {
      var[i] -= mul_factor * (m_i * beta1 + one_minus_beta1 * g_i) /
                (epsilon + sqrt(v_i));
    } else {
      var[i] -= mul_factor * m_i / (epsilon + sqrt(v_i));
    }

    m[i] = m_i;
    v[i] = v_i;
  }
}

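// Sparse Keras-style momentum: accum = momentum * accum - lr * grad, then
// var += accum (or, with Nesterov, var += momentum * accum - lr * grad).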
template <typename T, typename Tindex>
__global__ __launch_bounds__(1024) void SparseApplyKerasMomentumKernel(
    T* var, T* accum, const T* lr, const T* grad, const Tindex* indices,
    const T* momentum, bool use_nesterov, Tindex param_rows,
    Tindex updates_size, Tindex indices_size) {
  Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    Tindex indices_row = grad_index / col_size;
    Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T grad_i = grad[grad_index];
    const T momentum_t = *momentum;
    const T lr_t = *lr;

    // Variable update computation.
    accum_i = momentum_t * accum_i - lr_t * grad_i;
    // static branching in cuda does not impact performance.
    if (use_nesterov) {
      var_i += (momentum_t * accum_i - lr_t * grad_i);
    } else {
      var_i += accum_i;
    }

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
  }
}

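// Dense gradient descent via Eigen expressions: var -= lr * grad, with the
// scalar lr broadcast over the flat variable.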
template <typename T>
struct ApplyGradientDescent<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    var.device(d) -= lr.reshape(single).broadcast(bcast) * grad;
  }
};

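// The element-wise kernels below are used by the functors further down when
// building for ROCm (the Eigen expression path is kept for CUDA); see the
// complex sqrt/rsqrt note near the top of this file.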
template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdagradKernel(GpuLaunchConfig cfg,
                                                           T* var, T* accum,
                                                           const T* lr,
                                                           const T* grad,
                                                           bool update_slots) {
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    if (update_slots) accum[i] += grad[i] * grad[i];
    var[i] -= lr[0] * grad[i] * impl_rsqrt(accum[i]);
  }
}

template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdagradV2Kernel(
    GpuLaunchConfig cfg, T* var, T* accum, const T* lr, const T* epsilon,
    const T* grad, bool update_slots) {
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    if (update_slots) accum[i] += grad[i] * grad[i];
    T update = grad[i] / (impl_sqrt(accum[i]) + epsilon[0]);
    var[i] -= lr[0] * update;
  }
}

template <typename T>
__global__ __launch_bounds__(1024) void ApplyProximalAdagradKernel(
    GpuLaunchConfig cfg, T* var, T* accum, const T* lr, const T* l1,
    const T* l2, const T* grad) {
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    accum[i] += grad[i] * grad[i];
    T lr_scaled = lr[0] * impl_rsqrt(accum[i]);
    T prox_var = var[i] - grad[i] * lr_scaled;
    var[i] = impl_sign(prox_var) *
             max(impl_fabs(prox_var) - lr_scaled * max(l1[0], T(0.f)), T(0.f)) /
             (T(1.f) + l2[0] * lr_scaled);
  }
}

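// Adadelta: accum = rho * accum + (1 - rho) * grad^2,
// update = sqrt(accum_update + eps) / sqrt(accum + eps) * grad,
// var -= lr * update, and
// accum_update = rho * accum_update + (1 - rho) * update^2.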
template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdadeltaKernel(
    GpuLaunchConfig cfg, T* var, T* accum, T* accum_update, const T* plr,
    const T* prho, const T* peps, const T* grad) {
  T rho = prho[0];
  T eps = peps[0];
  T lr = plr[0];
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    accum[i] = accum[i] * rho + grad[i] * grad[i] * (T(1.0) - rho);
    T update =
        impl_sqrt(accum_update[i] + eps) * grad[i] * impl_rsqrt(accum[i] + eps);
    var[i] -= update * lr;
    accum_update[i] = accum_update[i] * rho + update * update * (T(1.0) - rho);
  }
}

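// RMSProp: ms += (1 - rho) * (grad^2 - ms),
// mom = momentum * mom + lr * grad / sqrt(eps + ms), var -= mom.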
template <typename T>
__global__ __launch_bounds__(1024) void ApplyRMSPropKernel(
    GpuLaunchConfig cfg, T* var, T* ms, T* mom, const T* plr, const T* prho,
    const T* pmomentum, const T* peps, const T* grad) {
  T rho = prho[0];
  T eps = peps[0];
  T lr = plr[0];
  T momentum = pmomentum[0];
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    ms[i] += (T(1.0) - rho) * (grad[i] * grad[i] - ms[i]);
    mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(eps + ms[i]);
    var[i] -= mom[i];
  }
}

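// Centered RMSProp: in addition to ms, tracks the mean gradient mg and uses
// the centered estimate ms - mg^2 (plus eps) in the denominator.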
template <typename T>
__global__ __launch_bounds__(1024) void ApplyCenteredRMSPropKernel(
    GpuLaunchConfig cfg, T* var, T* mg, T* ms, T* mom, const T* plr,
    const T* prho, const T* pmomentum, const T* peps, const T* grad) {
  T rho = prho[0];
  T eps = peps[0];
  T lr = plr[0];
  T momentum = pmomentum[0];
  T one_minus_rho = T(1.0) - rho;
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    ms[i] += one_minus_rho * (grad[i] * grad[i] - ms[i]);
    mg[i] += one_minus_rho * (grad[i] - mg[i]);
    T denom = (ms[i] - mg[i] * mg[i]) + eps;
    mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(denom);
    var[i] -= mom[i];
  }
}

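// Helpers that forward a functor-style call (Eigen tensors and scalars) to a
// raw element-wise kernel: to_pointers() extracts device pointers from the
// tensor arguments, and wrap_kernel_call() builds a 1-D launch config over
// the flat size of `var` and launches the kernel with those pointers.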
namespace kernel_forward {
bool to_pointers(bool x) { return x; }
template <class T>
typename T::PointerType to_pointers(T& x) {
  return x.data();
}
template <class T>
typename T::ConstPointerType to_pointers(const T& x) {
  return x.data();
}

template <typename T, typename... CallerArgs, typename... KernelArgs>
void wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& d, T var,
                      CallerArgs... args) {
  int32 data_dim = var.dimension(0);
  auto config = GetGpuLaunchConfig(data_dim, d);
  TF_CHECK_OK(GpuLaunchKernel(func, config.block_count, config.thread_per_block,
                              0, d.stream(), config, var.data(),
                              to_pointers(args)...));
}
};  // namespace kernel_forward

using kernel_forward::wrap_kernel_call;

template <typename T>
struct ApplyAdagrad<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad, bool update_slots) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyAdagradKernel<T>, d, var, accum, lr, grad,
                     update_slots);
#else
    if (update_slots) {
      accum.device(d) += grad.square();
    }
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
#endif
  }
};

template <typename T>
struct ApplyAdagradV2<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool update_slots) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyAdagradV2Kernel<T>, d, var, accum, lr, epsilon, grad,
                     update_slots);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    if (update_slots) {
      accum.device(d) += grad.square();
    }
    const auto update =
        grad / (accum.sqrt() + epsilon.reshape(single).broadcast(bcast));
    var.device(d) -= lr.reshape(single).broadcast(bcast) * update;
#endif
  }
};

template <typename T, typename Tindex, bool has_epsilon>
struct SparseApplyAdagrad<GPUDevice, T, Tindex, has_epsilon> {
  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar epsilon,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
                    bool update_slots) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    if (grad_size == 0) {
      return Status::OK();
    }
    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
    return GpuLaunchKernel(
        SparseApplyAdagradKernel<T, Tindex, has_epsilon>, config.block_count,
        config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
        lr.data(), epsilon.data(), grad.data(), indices.data(), first_dim_size,
        grad_size, indices_size, update_slots);
  }
};

template <typename T>
struct ApplyProximalAdagrad<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyProximalAdagradKernel<T>, d, var, accum, lr, l1, l2,
                     grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    // Fobos update per paper with Adagrad learning rate.
    accum.device(d) += grad.square();
    // Adagrad learning rate.
    // The following is the GPU equivalent of the CPU version:
    // auto learning_rate = accum.constant(lr()) * accum.rsqrt();
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto l1_bcast = l1.reshape(single).broadcast(bcast);
    auto l2_bcast = l2.reshape(single).broadcast(bcast);
    auto learning_rate = lr_bcast * accum.rsqrt();
    auto prox_var = var;
    // compute v = w - lr * grad.
    prox_var.device(d) -= grad * learning_rate;
    // compute sign(v) * max(|v| - lr * max(l1, 0), 0)
    var.device(d) = prox_var.sign() *
                    (prox_var.abs() - learning_rate * l1_bcast.cwiseMax(T(0.f)))
                        .cwiseMax(T(0.f)) /
                    (var.constant(T(1.f)) + l2_bcast * learning_rate);
#endif
  }
};

template <typename T, typename Tindex>
struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex> {
  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar l1,
                    typename TTypes<T>::ConstScalar l2,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices,
                    int64 inner_dim) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    if (grad_size == 0) {
      return Status::OK();
    }
    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
    return GpuLaunchKernel(SparseApplyProximalAdagradKernel<T, Tindex>,
                           config.block_count, config.thread_per_block, 0,
                           d.stream(), var.data(), accum.data(), lr.data(),
                           l1.data(), l2.data(), grad.data(), indices.data(),
                           first_dim_size, grad_size, indices_size);
  }
};

template <typename T>
struct ApplyAdadelta<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat accum_update,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyAdadeltaKernel<T>, d, var, accum, accum_update, lr,
                     rho, epsilon, grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    accum.device(d) = accum * rho.reshape(single).broadcast(bcast) +
                      grad.square() * (grad.constant(T(1)) -
                                       rho.reshape(single).broadcast(bcast));
    const auto update =
        (accum_update + epsilon.reshape(single).broadcast(bcast)).sqrt() *
        (accum + epsilon.reshape(single).broadcast(bcast)).rsqrt() * grad;
    var.device(d) -= update * lr.reshape(single).broadcast(bcast);
    accum_update.device(d) =
        accum_update * rho.reshape(single).broadcast(bcast) +
        update.square() *
            (grad.constant(T(1)) - rho.reshape(single).broadcast(bcast));
#endif
  }
};

template <typename T>
struct ApplyFtrl<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar lr_power) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    auto l1_bcast = l1.reshape(single).broadcast(bcast);
    auto l2_bcast = l2.reshape(single).broadcast(bcast);
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
    const auto two = static_cast<T>(2.0);

    auto new_accum = accum + grad.square();
    auto accum_power = accum.binaryExpr(lr_power_bcast,
                                        Eigen::internal::scalar_pow_op<T, T>());
    auto new_accum_power = new_accum.binaryExpr(
        lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
    linear.device(d) += grad - (new_accum_power - accum_power) * var / lr_bcast;
    auto x = (l1_bcast * linear.sign() - linear);
    auto y = (new_accum_power / lr_bcast) + linear.constant(two) * l2_bcast;
    auto pre_shrink = x / y;
    var.device(d) = (linear.abs() > l1_bcast)
                        .select(pre_shrink, var.constant(static_cast<T>(0)));
    accum.device(d) += grad.square();
  }
};

template <typename T>
struct ApplyFtrlMultiplyLinearByLr<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar lr_power) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto l1_lr_bcast = (l1 * lr).reshape(single).broadcast(bcast);
    auto l2_lr_bcast = (l2 * lr).reshape(single).broadcast(bcast);
    auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
    const auto two = static_cast<T>(2.0);

    auto new_accum = accum + grad.square();
    auto accum_power = accum.binaryExpr(lr_power_bcast,
                                        Eigen::internal::scalar_pow_op<T, T>());
    auto new_accum_power = new_accum.binaryExpr(
        lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
    linear.device(d) += grad * lr_bcast - (new_accum_power - accum_power) * var;
    auto x = (l1_lr_bcast * linear.sign() - linear);
    auto y = new_accum_power + linear.constant(two) * l2_lr_bcast;
    auto pre_shrink = x / y;
    var.device(d) = (linear.abs() > l1_lr_bcast)
                        .select(pre_shrink, var.constant(static_cast<T>(0)));
    accum.device(d) += grad.square();
  }
};

template <typename T>
struct ApplyFtrlV2<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar l2_shrinkage,
                  typename TTypes<T>::ConstScalar lr_power) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    auto l1_bcast = l1.reshape(single).broadcast(bcast);
    auto l2_bcast = l2.reshape(single).broadcast(bcast);
    auto l2_shrinkage_bcast = l2_shrinkage.reshape(single).broadcast(bcast);
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
    const auto two = static_cast<T>(2.0);

    auto new_accum = accum + grad.square();
    auto accum_power = accum.binaryExpr(lr_power_bcast,
                                        Eigen::internal::scalar_pow_op<T, T>());
    auto new_accum_power = new_accum.binaryExpr(
        lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
    auto grad_with_shrinkage =
        grad + (var.constant(two) * l2_shrinkage_bcast * var);
    linear.device(d) +=
        grad_with_shrinkage - (new_accum_power - accum_power) * var / lr_bcast;
    auto x = (l1_bcast * linear.sign() - linear);
    auto y = (new_accum_power / lr_bcast) + linear.constant(two) * l2_bcast;
    auto pre_shrink = x / y;
    var.device(d) = (linear.abs() > l1_bcast)
                        .select(pre_shrink, var.constant(static_cast<T>(0)));
    accum.device(d) += grad.square();
  }
};

template <typename T>
struct ApplyFtrlV2MultiplyLinearByLr<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar l2_shrinkage,
                  typename TTypes<T>::ConstScalar lr_power) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    auto l2_shrinkage_bcast = l2_shrinkage.reshape(single).broadcast(bcast);
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto l1_lr_bcast = (l1 * lr).reshape(single).broadcast(bcast);
    auto l2_lr_bcast = (l2 * lr).reshape(single).broadcast(bcast);
    auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
    const auto two = static_cast<T>(2.0);

    auto new_accum = accum + grad.square();
    auto accum_power = accum.binaryExpr(lr_power_bcast,
                                        Eigen::internal::scalar_pow_op<T, T>());
    auto new_accum_power = new_accum.binaryExpr(
        lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
    auto grad_with_shrinkage =
        grad + (var.constant(two) * l2_shrinkage_bcast * var);
    linear.device(d) +=
        grad_with_shrinkage * lr_bcast - (new_accum_power - accum_power) * var;
    auto x = (l1_lr_bcast * linear.sign() - linear);
    auto y = new_accum_power + linear.constant(two) * l2_lr_bcast;
    auto pre_shrink = x / y;
    var.device(d) = (linear.abs() > l1_lr_bcast)
                        .select(pre_shrink, var.constant(static_cast<T>(0)));
    accum.device(d) += grad.square();
  }
};

template <typename T, typename Tindex, bool has_l2_shrinkage>
struct SparseApplyFtrl<GPUDevice, T, Tindex, has_l2_shrinkage> {
  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::Matrix linear,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar l1,
                    typename TTypes<T>::ConstScalar l2,
                    typename TTypes<T>::ConstScalar l2_shrinkage,
                    typename TTypes<T>::ConstScalar lr_power,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
                    bool multiply_linear_by_lr) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    if (grad_size == 0) {
      return Status::OK();
    }
    // The simpler overload of GetGpuLaunchConfig() would result in a "too many
    // resources requested for launch" error.
    auto* device_func = SparseApplyFtrlKernel<T, Tindex, has_l2_shrinkage>;
    GpuLaunchConfig config =
        GetGpuLaunchConfig(grad_size, d, device_func, 0, 0);
    return GpuLaunchKernel(
        device_func, config.block_count, config.thread_per_block, 0, d.stream(),
        /*var=*/var.data(),
        /*accum=*/accum.data(),
        /*linear=*/linear.data(), /*lr=*/lr.data(), /*l1=*/l1.data(),
        /*l2=*/l2.data(), /*l2_shrinkage=*/l2_shrinkage.data(),
        /*lr_power=*/lr_power.data(), /*grad=*/grad.data(),
        /*indices=*/indices.data(), /*param_rows=*/first_dim_size,
        /*updates_size=*/grad_size,
        /*indices_size=*/indices_size,
        /*multiply_linear_by_lr=*/multiply_linear_by_lr);
  }
};

template <typename T>
struct ApplyMomentum<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
    if (use_nesterov) {
      var.device(d) -= grad * lr.reshape(single).broadcast(bcast) +
                       accum * momentum.reshape(single).broadcast(bcast) *
                           lr.reshape(single).broadcast(bcast);
    } else {
      var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
    }
  }
};

template <typename T>
struct ApplyKerasMomentum<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    accum.device(d) = (accum * momentum.reshape(single).broadcast(bcast) -
                       grad * lr.reshape(single).broadcast(bcast));
    if (use_nesterov) {
      var.device(d) += (accum * momentum.reshape(single).broadcast(bcast) -
                        grad * lr.reshape(single).broadcast(bcast));
    } else {
      var.device(d) += accum;
    }
  }
};

template <typename T, typename Tindex>
struct SparseApplyKerasMomentum<GPUDevice, T, Tindex> {
  Tindex operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices,
                    typename TTypes<T>::ConstScalar momentum,
                    bool use_nesterov) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    if (grad_size != 0) {
      GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
      TF_CHECK_OK(GpuLaunchKernel(
          SparseApplyKerasMomentumKernel<T, Tindex>, config.block_count,
          config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
          lr.data(), grad.data(), indices.data(), momentum.data(), use_nesterov,
          first_dim_size, grad_size, indices_size));
    }
    return static_cast<Tindex>(-1);
  }
};

template <typename T>
struct ApplyAdam<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool use_nesterov) {
    int32 data_dim = grad.dimension(0);
    GpuLaunchConfig config = GetGpuLaunchConfig(data_dim, d);
    eigen_assert(static_cast<int64>(grad.dimension(0)) +
                     static_cast<int64>(config.block_count) *
                         static_cast<int64>(config.thread_per_block) <
                 std::numeric_limits<int32>::max());

    TF_CHECK_OK(GpuLaunchKernel(
        ApplyAdamKernel<T>, config.block_count, config.thread_per_block, 0,
        d.stream(), data_dim, var.data(), m.data(), v.data(),
        beta1_power.data(), beta2_power.data(), lr.data(), beta1.data(),
        beta2.data(), epsilon.data(), grad.data(), use_nesterov));
  }
};

template <typename T>
struct ApplyAdamWithAmsgrad<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::Flat vhat,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    m.device(d) =
        m + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
                (grad - m);
    v.device(d) =
        v + (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) *
                (grad.square() - v);
    vhat.device(d) = vhat.cwiseMax(v);

    var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() /
                      (beta1_power.constant(one) - beta1_power))
                         .reshape(single)
                         .broadcast(bcast) *
                     m /
                     (epsilon.reshape(single).broadcast(bcast) + vhat.sqrt());
  }
};

template <typename T>
struct ApplyAdaMax<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    m.device(d) +=
        (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
        (grad - m);
    v.device(d) =
        (beta2.reshape(single).broadcast(bcast) * v).cwiseMax(grad.abs());
    var.device(d) -= lr.reshape(single).broadcast(bcast) /
                     (beta1_power.constant(one) - beta1_power)
                         .reshape(single)
                         .broadcast(bcast) *
                     (m / (v + epsilon.reshape(single).broadcast(bcast)));
  }
};

template <typename T>
struct ApplyRMSProp<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyRMSPropKernel<T>, d, var, ms, mom, lr, rho, momentum,
                     epsilon, grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    ms.device(d) =
        ms + (rho.constant(one) - rho).reshape(single).broadcast(bcast) *
                 (grad.square() - ms);
    mom.device(d) =
        mom * momentum.reshape(single).broadcast(bcast) +
        lr.reshape(single).broadcast(bcast) * grad /
            ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
    var.device(d) -= mom;
#endif
  }
};

template <typename T>
struct ApplyCenteredRMSProp<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
                  typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyCenteredRMSPropKernel<T>, d, var, mg, ms, mom, lr,
                     rho, momentum, epsilon, grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    const auto one_minus_rho =
        (rho.constant(one) - rho).reshape(single).broadcast(bcast);
    ms.device(d) = ms + one_minus_rho * (grad.square() - ms);
    mg.device(d) = mg + one_minus_rho * (grad - mg);
    auto denom = (ms - mg.square()) + epsilon.reshape(single).broadcast(bcast);
    mom.device(d) = mom * momentum.reshape(single).broadcast(bcast) +
                    lr.reshape(single).broadcast(bcast) * grad / denom.sqrt();
    var.device(d) -= mom;
#endif
  }
};

template <typename T>
struct ApplyAddSign<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar alpha,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    // The following is the GPU equivalent of the CPU version:
    // m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    const auto one = static_cast<T>(1.0);
    auto beta_bcast = beta.reshape(single).broadcast(bcast);
    auto one_minus_beta =
        (beta.constant(one) - beta).reshape(single).broadcast(bcast);
    m.device(d) = m * beta_bcast + grad * one_minus_beta;

    // The following is the GPU equivalent of the CPU version:
    // var.device(d) -= lr() * (alpha() + sign_decay() * sign_gm) * grad;
    auto sign_gm = grad.sign() * m.sign();
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto alpha_bcast = alpha.reshape(single).broadcast(bcast);
    auto sign_decay_bcast = sign_decay.reshape(single).broadcast(bcast);
    var.device(d) -=
        lr_bcast * (alpha_bcast + sign_decay_bcast * sign_gm) * grad;
  }
};

template <typename T>
struct ApplyPowerSign<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar logbase,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    // The following is the GPU equivalent of the CPU version:
    // m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    const auto one = static_cast<T>(1.0);
    auto beta_bcast = beta.reshape(single).broadcast(bcast);
    auto one_minus_beta =
        (beta.constant(one) - beta).reshape(single).broadcast(bcast);
    m.device(d) = m * beta_bcast + grad * one_minus_beta;

    // The following is the GPU equivalent of the CPU version:
    // auto grad_scale = (logbase() * sign_decay() * sign_gm).exp();
    // var.device(d) -= lr() * grad_scale * grad;
    auto sign_gm = grad.sign() * m.sign();
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto logbase_bcast = logbase.reshape(single).broadcast(bcast);
    auto sign_decay_bcast = sign_decay.reshape(single).broadcast(bcast);
    auto grad_scale = (logbase_bcast * sign_decay_bcast * sign_gm).exp();
    var.device(d) -= lr_bcast * grad_scale * grad;
  }
};

}  // namespace functor

template struct functor::ApplyGradientDescent<GPUDevice, Eigen::half>;
template struct functor::ApplyGradientDescent<GPUDevice, float>;
template struct functor::ApplyGradientDescent<GPUDevice, double>;
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyGradientDescent<GPUDevice, complex64>;
template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
#endif

template struct functor::ApplyAdagrad<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagrad<GPUDevice, float>;
template struct functor::ApplyAdagrad<GPUDevice, double>;
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyAdagrad<GPUDevice, complex64>;
template struct functor::ApplyAdagrad<GPUDevice, complex128>;
#endif

template struct functor::ApplyAdagradV2<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagradV2<GPUDevice, float>;
template struct functor::ApplyAdagradV2<GPUDevice, double>;
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
#endif

#define EXPLICITLY_INSTANTIATE_FUNCTOR(T)                             \
  template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
                                              /*has_epsilon=*/false>; \
  template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
                                              /*has_epsilon=*/false>; \
  template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
                                              /*has_epsilon=*/true>;  \
  template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
                                              /*has_epsilon=*/true>
EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
EXPLICITLY_INSTANTIATE_FUNCTOR(float);
EXPLICITLY_INSTANTIATE_FUNCTOR(double);
#undef EXPLICITLY_INSTANTIATE_FUNCTOR

template struct functor::ApplyProximalAdagrad<GPUDevice, Eigen::half>;
template struct functor::ApplyProximalAdagrad<GPUDevice, float>;
template struct functor::ApplyProximalAdagrad<GPUDevice, double>;

template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
                                                    int32>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
                                                    int64>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int32>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int64>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int32>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int64>;

template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
template struct functor::ApplyAdadelta<GPUDevice, float>;
template struct functor::ApplyAdadelta<GPUDevice, double>;
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyAdadelta<GPUDevice, complex64>;
template struct functor::ApplyAdadelta<GPUDevice, complex128>;
#endif

template struct functor::ApplyFtrl<GPUDevice, Eigen::half>;
template struct functor::ApplyFtrl<GPUDevice, float>;
template struct functor::ApplyFtrl<GPUDevice, double>;

template struct functor::ApplyFtrlMultiplyLinearByLr<GPUDevice, Eigen::half>;
template struct functor::ApplyFtrlMultiplyLinearByLr<GPUDevice, float>;
template struct functor::ApplyFtrlMultiplyLinearByLr<GPUDevice, double>;

template struct functor::ApplyFtrlV2<GPUDevice, Eigen::half>;
template struct functor::ApplyFtrlV2<GPUDevice, float>;
template struct functor::ApplyFtrlV2<GPUDevice, double>;

template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, Eigen::half>;
template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, float>;
template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, double>;

#define EXPLICITLY_INSTANTIATE_FUNCTOR(T)                               \
  template struct functor::SparseApplyFtrl<GPUDevice, T, int32,         \
                                           /*has_l2_shrinkage=*/false>; \
  template struct functor::SparseApplyFtrl<GPUDevice, T, int64,         \
                                           /*has_l2_shrinkage=*/false>; \
  template struct functor::SparseApplyFtrl<GPUDevice, T, int32,         \
                                           /*has_l2_shrinkage=*/true>;  \
  template struct functor::SparseApplyFtrl<GPUDevice, T, int64,         \
                                           /*has_l2_shrinkage=*/true>
EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
EXPLICITLY_INSTANTIATE_FUNCTOR(float);
EXPLICITLY_INSTANTIATE_FUNCTOR(double);
#undef EXPLICITLY_INSTANTIATE_FUNCTOR

template struct functor::ApplyMomentum<GPUDevice, Eigen::half>;
template struct functor::ApplyMomentum<GPUDevice, float>;
template struct functor::ApplyMomentum<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
template struct functor::ApplyMomentum<GPUDevice, complex64>;
template struct functor::ApplyMomentum<GPUDevice, complex128>;
#endif

template struct functor::ApplyKerasMomentum<GPUDevice, Eigen::half>;
template struct functor::ApplyKerasMomentum<GPUDevice, float>;
template struct functor::ApplyKerasMomentum<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
template struct functor::ApplyKerasMomentum<GPUDevice, complex64>;
template struct functor::ApplyKerasMomentum<GPUDevice, complex128>;
#endif

template struct functor::SparseApplyKerasMomentum<GPUDevice, Eigen::half,
                                                  int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, Eigen::half,
                                                  int64>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, float, int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, float, int64>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, double, int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, double, int64>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
template struct functor::SparseApplyKerasMomentum<GPUDevice, complex64, int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, complex64, int64>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, complex128, int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, complex128, int64>;
#endif

template struct functor::ApplyAdam<GPUDevice, Eigen::half>;
template struct functor::ApplyAdam<GPUDevice, float>;
template struct functor::ApplyAdam<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
template struct functor::ApplyAdam<GPUDevice, complex64>;
template struct functor::ApplyAdam<GPUDevice, complex128>;
#endif

template struct functor::ApplyAdamWithAmsgrad<GPUDevice, Eigen::half>;
template struct functor::ApplyAdamWithAmsgrad<GPUDevice, float>;
template struct functor::ApplyAdamWithAmsgrad<GPUDevice, double>;

template struct functor::ApplyAdaMax<GPUDevice, Eigen::half>;
template struct functor::ApplyAdaMax<GPUDevice, float>;
template struct functor::ApplyAdaMax<GPUDevice, double>;

template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyRMSProp<GPUDevice, float>;
template struct functor::ApplyRMSProp<GPUDevice, double>;
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyRMSProp<GPUDevice, complex64>;
template struct functor::ApplyRMSProp<GPUDevice, complex128>;
#endif

template struct functor::ApplyCenteredRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, float>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, double>;
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyCenteredRMSProp<GPUDevice, complex64>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, complex128>;
#endif

template struct functor::ApplyAddSign<GPUDevice, Eigen::half>;
template struct functor::ApplyAddSign<GPUDevice, float>;
template struct functor::ApplyAddSign<GPUDevice, double>;

template struct functor::ApplyPowerSign<GPUDevice, Eigen::half>;
template struct functor::ApplyPowerSign<GPUDevice, float>;
template struct functor::ApplyPowerSign<GPUDevice, double>;

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM