/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/training_ops.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

#if TENSORFLOW_USE_ROCM

#include "rocm/include/hip/hip_complex.h"

#endif  // TENSORFLOW_USE_ROCM
// If any kernels involving complex sqrt/rsqrt are compiled with ROCm, the
// build completes without errors, but the resulting executable ends up
// unusable (throwing "no device code available for function" errors for
// completely unrelated kernels).
// We also can't cast to hipFloatComplex etc. because (as of 2020-01) HIP does
// not provide sqrt for complex.
// We have no choice but to implement sqrt and rsqrt by hand.
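// The complex specializations below use the standard half-angle identity:
// for z = re + im*i with |z| = sqrt(re^2 + im^2),
//   sqrt(z) = sqrt((|z| + re) / 2) + sign(im) * sqrt((|z| - re) / 2) * i,
// i.e. the root whose imaginary part has the same sign as the input's.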
template <typename T>
__device__ T impl_sqrt(T x) {
  return sqrt(x);
}
template <typename T>
__device__ T impl_rsqrt(T x) {
  return rsqrt(x);
}
template <>
__device__ Eigen::half impl_sqrt(Eigen::half x) {
  return __float2half(sqrt(__half2float(x)));
}
template <>
__device__ Eigen::half impl_rsqrt(Eigen::half x) {
  return __float2half(rsqrt(__half2float(x)));
}

template <class T>
__device__ std::complex<T> impl_sqrt(std::complex<T> x) {
  T re = x.real(), im = x.imag();
  T mod_x = sqrt(re * re + im * im);
  const T root2 = 0.7071067811865475;
  // We pick the root with the same sign of the imaginary component as
  // the input.
  T root[2] = {T(sqrt(mod_x + re) * root2),
               T(sqrt(mod_x - re) * root2 * (im >= 0 ? 1. : -1.))};
  // hcc/clang is really weird with its support of complex in device code;
  // for some reason it does not permit a 2-argument constructor
  return *(reinterpret_cast<std::complex<T>*>(&root));
}

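// rsqrt_helper(x) is the low-order Taylor expansion of 1 - sqrt(1 - x) around
// x = 0 (x/2 + x^2/8 + x^3/16). impl_rsqrt below uses it to evaluate 1 + re*r
// (or 1 - re*r) without catastrophic cancellation when re*r is close to -1
// (or +1), using the fact that (re*r)^2 + (im*r)^2 == 1.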
template <class T>
__device__ T rsqrt_helper(T x) {
  return 0.5 * x + 0.125 * x * x + 0.0625 * x * x * x;
}

template <class T>
__device__ std::complex<T> impl_rsqrt(std::complex<T> x) {
  T re = x.real(), im = x.imag();
  T r = rsqrt(re * re + im * im);
  T ar2 = re * r * r;
  const T root2 = 0.7071067811865475;
  T root[2];
  // With float, calculating 1+re*r and 1-re*r may result in excessive errors
  // due to subtraction of two close values. We have to get fancy.
  root[0] = sqrt(r * ((std::is_same<T, float>::value && re * r < -0.98)
                          ? rsqrt_helper(im * im * r * r)
                          : max(T(0.0), 1 + re * r))) *
            root2;
  root[1] = sqrt(r * ((std::is_same<T, float>::value && re * r > 0.98)
                          ? rsqrt_helper(im * im * r * r)
                          : max(T(0.0), 1 - re * r))) *
            root2 * (im >= 0 ? -1. : 1.);
  return *(reinterpret_cast<std::complex<T>*>(&root));
}

template <typename T>
__device__ T impl_fabs(T x) {
  return fabs(x);
}
template <>
__device__ Eigen::half impl_fabs(Eigen::half x) {
  return __float2half(fabs(__half2float(x)));
}

template <typename T>
__device__ T impl_sign(T x) {
  return x == T(0) ? T(0) : x < T(0) ? T(-1) : T(1);
}

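// Adagrad sparse update. For each row selected by `indices`:
//   accum += grad * grad                               (if update_slots)
//   var   -= lr * grad / (sqrt(accum) + epsilon)       (has_epsilon)
//   var   -= lr * grad * rsqrt(accum)                  (otherwise)
// Rows with out-of-range indices are skipped.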
template <typename T, typename Tindex, bool has_epsilon>
__global__ __launch_bounds__(1024) void SparseApplyAdagradKernel(
    T* var, T* accum, const T* lr, const T* epsilon, const T* grad,
    const Tindex* indices, Tindex param_rows, Tindex updates_size,
    Tindex indices_size, bool update_slots) {
  Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    Tindex indices_row = grad_index / col_size;
    Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T grad_i = grad[grad_index];
    const T lr_t = *lr;
    const T epsilon_t = *epsilon;

    if (update_slots) {
      accum_i += grad_i * grad_i;
    }
    if (has_epsilon) {
      var_i -= lr_t * grad_i / (sqrt(accum_i) + epsilon_t);
    } else {
      var_i -= lr_t * grad_i * impl_rsqrt(accum_i);
    }

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
  }
}

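// Proximal Adagrad sparse update. For each row selected by `indices`:
//   accum    += grad * grad
//   lr_scaled = lr * rsqrt(accum)
//   prox_var  = var - grad * lr_scaled
//   var = sign(prox_var) * max(|prox_var| - lr_scaled * max(l1, 0), 0)
//         / (1 + l2 * lr_scaled)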
template <typename T, typename Tindex>
__global__ __launch_bounds__(1024) void SparseApplyProximalAdagradKernel(
    T* var, T* accum, const T* lr, const T* l1, const T* l2, const T* grad,
    const Tindex* indices, Tindex param_rows, Tindex updates_size,
    Tindex indices_size) {
  Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    Tindex indices_row = grad_index / col_size;
    Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T grad_i = grad[grad_index];
    const T lr_t = *lr;
    const T l1_t = *l1;
    const T l2_t = *l2;

    accum_i += grad_i * grad_i;
    T learning_rate = lr_t * impl_rsqrt(accum_i);
    // compute v = w - lr * grad.
    T prox_var_i = var_i - grad_i * learning_rate;
    // compute sign(v) * max(|v| - lr * max(l1, 0), 0)
    var_i = (prox_var_i >= 0 ? T(1.) : T(-1.)) *
            max(abs(prox_var_i) - learning_rate * max(l1_t, T(0)), T(0)) /
            (T(1.) + l2_t * learning_rate);

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
  }
}

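// FTRL-Proximal sparse update. For each row selected by `indices`: the
// gradient is optionally augmented with the L2-shrinkage term, `accum`
// accumulates squared gradients, `linear` accumulates the FTRL linear term
// (optionally pre-multiplied by lr), and `var` is recomputed from the
// L1/L2-regularized proximal step below.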
template <typename T, typename Tindex, bool has_l2_shrinkage>
__global__ void SparseApplyFtrlKernel(T* var, T* accum, T* linear, const T* lr,
                                      const T* l1, const T* l2,
                                      const T* l2_shrinkage, const T* lr_power,
                                      const T* grad, const Tindex* indices,
                                      Tindex param_rows, Tindex updates_size,
                                      Tindex indices_size,
                                      bool multiply_linear_by_lr) {
  const Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    const Tindex indices_row = grad_index / col_size;
    const Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    const Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T linear_i = linear[param_index];
    const T grad_i = grad[grad_index];
    const T lr_t = *lr;
    const T l1_t = *l1;
    const T l2_t = *l2;
    const T lr_power_t = *lr_power;

    const T grad_shr_i =
        has_l2_shrinkage ? grad_i + static_cast<T>(2) * (*l2_shrinkage) * var_i
                         : grad_i;
    const T new_accum_i = accum_i + grad_i * grad_i;
    const bool lr_power_is_neg_half = lr_power_t == static_cast<T>(-0.5);
    const T pow_new_accum = lr_power_is_neg_half
                                ? sqrt(new_accum_i)
                                : pow(new_accum_i, -lr_power_t);
    const T pow_accum =
        lr_power_is_neg_half ? sqrt(accum_i) : pow(accum_i, -lr_power_t);
    T linear_change = grad_shr_i * lr_t - (pow_new_accum - pow_accum) * var_i;
    if (!multiply_linear_by_lr) {
      linear_change /= lr_t;
    }
    linear_i += linear_change;

    T l1_mult = l1_t;
    if (multiply_linear_by_lr) {
      l1_mult *= lr_t;
    }
    const T l1_reg_adjust = max(min(linear_i, l1_mult), -l1_mult);
    const T x = l1_reg_adjust - linear_i;
    T y = pow_new_accum + static_cast<T>(2) * l2_t * lr_t;
    if (!multiply_linear_by_lr) {
      y /= lr_t;
    }
    var_i = x / y;
    accum_i = new_accum_i;

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
    linear[param_index] = linear_i;
  }
}

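// Adam update (optionally Nesterov-accelerated), with the bias correction
// folded into mul_factor = lr * sqrt(1 - beta2^t) / (1 - beta1^t):
//   m += (1 - beta1) * (grad - m)
//   v += (1 - beta2) * (grad^2 - v)
//   var -= mul_factor * m / (epsilon + sqrt(v))                  (plain)
//   var -= mul_factor * (beta1 * m + (1 - beta1) * grad)
//          / (epsilon + sqrt(v))                                 (use_nesterov)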
template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdamKernel(
    int32 data_dim, T* var, T* m, T* v, const T* const beta1_power_,
    const T* const beta2_power_, const T* const lr_, const T* const beta1_,
    const T* const beta2_, const T* const epsilon_, const T* grad,
    bool use_nesterov) {
  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);
  eigen_assert(gridDim.y == 1);
  eigen_assert(gridDim.z == 1);

  const T mul_factor = (*lr_) * sqrt(static_cast<T>(1.0) - (*beta2_power_)) /
                       (static_cast<T>(1.0) - (*beta1_power_));
  const T epsilon = (*epsilon_);
  const T beta1 = (*beta1_);
  const T one_minus_beta1 = static_cast<T>(1.0) - (beta1);
  const T one_minus_beta2 = static_cast<T>(1.0) - (*beta2_);
  const int32 stripe = gridDim.x * blockDim.x;

  for (int32 i = blockIdx.x * blockDim.x + threadIdx.x; i < data_dim;
       i += stripe) {
    auto m_i = m[i];
    auto g_i = grad[i];
    auto v_i = v[i];

    m_i += one_minus_beta1 * (g_i - m_i);
    v_i += one_minus_beta2 * (g_i * g_i - v_i);
    if (use_nesterov) {
      var[i] -= mul_factor * (m_i * beta1 + one_minus_beta1 * g_i) /
                (epsilon + sqrt(v_i));
    } else {
      var[i] -= mul_factor * m_i / (epsilon + sqrt(v_i));
    }

    m[i] = m_i;
    v[i] = v_i;
  }
}

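// Keras-style momentum sparse update. For each row selected by `indices`:
//   accum = momentum * accum - lr * grad
//   var  += momentum * accum - lr * grad   (use_nesterov)
//   var  += accum                          (otherwise)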
template <typename T, typename Tindex>
__global__ __launch_bounds__(1024) void SparseApplyKerasMomentumKernel(
    T* var, T* accum, const T* lr, const T* grad, const Tindex* indices,
    const T* momentum, bool use_nesterov, Tindex param_rows,
    Tindex updates_size, Tindex indices_size) {
  Tindex col_size = updates_size / indices_size;
  GPU_1D_KERNEL_LOOP(grad_index, updates_size) {
    Tindex indices_row = grad_index / col_size;
    Tindex param_row = indices[indices_row];
    if (param_row < 0 || param_row >= param_rows) {
      // Ignore indices that are out of range.
      continue;
    }

    // Compute the index of var and accum.
    Tindex param_index = param_row * col_size + (grad_index % col_size);

    // Read variables.
    T var_i = var[param_index];
    T accum_i = accum[param_index];
    T grad_i = grad[grad_index];
    const T momentum_t = *momentum;
    const T lr_t = *lr;

    // Variable update computation.
    accum_i = momentum_t * accum_i - lr_t * grad_i;
    // Static branching in CUDA does not impact performance.
    if (use_nesterov) {
      var_i += (momentum_t * accum_i - lr_t * grad_i);
    } else {
      var_i += accum_i;
    }

    // Write update back to variables.
    var[param_index] = var_i;
    accum[param_index] = accum_i;
  }
}

template <typename T>
struct ApplyGradientDescent<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    var.device(d) -= lr.reshape(single).broadcast(bcast) * grad;
  }
};

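// The element-wise kernels below implement the dense Adagrad, AdagradV2,
// ProximalAdagrad, Adadelta, RMSProp and CenteredRMSProp updates. They are
// launched through kernel_forward::wrap_kernel_call (defined further down) in
// the TENSORFLOW_USE_ROCM branches of the corresponding functors.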
template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdagradKernel(GpuLaunchConfig cfg,
                                                           T* var, T* accum,
                                                           const T* lr,
                                                           const T* grad,
                                                           bool update_slots) {
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    if (update_slots) accum[i] += grad[i] * grad[i];
    var[i] -= lr[0] * grad[i] * impl_rsqrt(accum[i]);
  }
}

template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdagradV2Kernel(
    GpuLaunchConfig cfg, T* var, T* accum, const T* lr, const T* epsilon,
    const T* grad, bool update_slots) {
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    if (update_slots) accum[i] += grad[i] * grad[i];
    T update = grad[i] / (impl_sqrt(accum[i]) + epsilon[0]);
    var[i] -= lr[0] * update;
  }
}

template <typename T>
__global__ __launch_bounds__(1024) void ApplyProximalAdagradKernel(
    GpuLaunchConfig cfg, T* var, T* accum, const T* lr, const T* l1,
    const T* l2, const T* grad) {
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    accum[i] += grad[i] * grad[i];
    T lr_scaled = lr[0] * impl_rsqrt(accum[i]);
    T prox_var = var[i] - grad[i] * lr_scaled;
    var[i] = impl_sign(prox_var) *
             max(impl_fabs(prox_var) - lr_scaled * max(l1[0], T(0.f)), T(0.f)) /
             (T(1.f) + l2[0] * lr_scaled);
  }
}

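// Adadelta update:
//   accum        = rho * accum + (1 - rho) * grad^2
//   update       = sqrt(accum_update + eps) * grad * rsqrt(accum + eps)
//   var         -= lr * update
//   accum_update = rho * accum_update + (1 - rho) * update^2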
template <typename T>
__global__ __launch_bounds__(1024) void ApplyAdadeltaKernel(
    GpuLaunchConfig cfg, T* var, T* accum, T* accum_update, const T* plr,
    const T* prho, const T* peps, const T* grad) {
  T rho = prho[0];
  T eps = peps[0];
  T lr = plr[0];
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    accum[i] = accum[i] * rho + grad[i] * grad[i] * (T(1.0) - rho);
    T update =
        impl_sqrt(accum_update[i] + eps) * grad[i] * impl_rsqrt(accum[i] + eps);
    var[i] -= update * lr;
    accum_update[i] = accum_update[i] * rho + update * update * (T(1.0) - rho);
  }
}

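// RMSProp update:
//   ms  += (1 - rho) * (grad^2 - ms)
//   mom  = momentum * mom + lr * grad * rsqrt(eps + ms)
//   var -= mom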
template <typename T>
__global__ __launch_bounds__(1024) void ApplyRMSPropKernel(
    GpuLaunchConfig cfg, T* var, T* ms, T* mom, const T* plr, const T* prho,
    const T* pmomentum, const T* peps, const T* grad) {
  T rho = prho[0];
  T eps = peps[0];
  T lr = plr[0];
  T momentum = pmomentum[0];
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    ms[i] += (T(1.0) - rho) * (grad[i] * grad[i] - ms[i]);
    mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(eps + ms[i]);
    var[i] -= mom[i];
  }
}

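// Centered RMSProp update: like RMSProp, but the denominator is centered by a
// running mean of the gradient:
//   mg  += (1 - rho) * (grad - mg)
//   ms  += (1 - rho) * (grad^2 - ms)
//   mom  = momentum * mom + lr * grad * rsqrt(ms - mg^2 + eps)
//   var -= mom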
template <typename T>
__global__ __launch_bounds__(1024) void ApplyCenteredRMSPropKernel(
    GpuLaunchConfig cfg, T* var, T* mg, T* ms, T* mom, const T* plr,
    const T* prho, const T* pmomentum, const T* peps, const T* grad) {
  T rho = prho[0];
  T eps = peps[0];
  T lr = plr[0];
  T momentum = pmomentum[0];
  T one_minus_rho = T(1.0) - rho;
  GPU_1D_KERNEL_LOOP(i, cfg.virtual_thread_count) {
    ms[i] += one_minus_rho * (grad[i] * grad[i] - ms[i]);
    mg[i] += one_minus_rho * (grad[i] - mg[i]);
    T denom = (ms[i] - mg[i] * mg[i]) + eps;
    mom[i] = mom[i] * momentum + lr * grad[i] * impl_rsqrt(denom);
    var[i] -= mom[i];
  }
}

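// Helpers that adapt the functor arguments to the raw-pointer kernels above:
// to_pointers() unwraps an Eigen tensor argument into its device pointer
// (passing bool flags through unchanged), and wrap_kernel_call() builds a
// launch configuration for the flattened variable and launches the kernel.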
namespace kernel_forward {
bool to_pointers(bool x) { return x; }
template <class T>
typename T::PointerType to_pointers(T& x) {
  return x.data();
}
template <class T>
typename T::ConstPointerType to_pointers(const T& x) {
  return x.data();
}

template <typename T, typename... CallerArgs, typename... KernelArgs>
void wrap_kernel_call(void (*func)(KernelArgs...), const GPUDevice& d, T var,
                      CallerArgs... args) {
  int32 data_dim = var.dimension(0);
  auto config = GetGpuLaunchConfig(data_dim, d);
  TF_CHECK_OK(GpuLaunchKernel(func, config.block_count, config.thread_per_block,
                              0, d.stream(), config, var.data(),
                              to_pointers(args)...));
}
};  // namespace kernel_forward

using kernel_forward::wrap_kernel_call;

template <typename T>
struct ApplyAdagrad<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad, bool update_slots) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyAdagradKernel<T>, d, var, accum, lr, grad,
                     update_slots);
#else
    if (update_slots) {
      accum.device(d) += grad.square();
    }
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
#endif
  }
};

template <typename T>
struct ApplyAdagradV2<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool update_slots) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyAdagradV2Kernel<T>, d, var, accum, lr, epsilon, grad,
                     update_slots);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    if (update_slots) {
      accum.device(d) += grad.square();
    }
    const auto update =
        grad / (accum.sqrt() + epsilon.reshape(single).broadcast(bcast));
    var.device(d) -= lr.reshape(single).broadcast(bcast) * update;
#endif
  }
};

template <typename T, typename Tindex, bool has_epsilon>
struct SparseApplyAdagrad<GPUDevice, T, Tindex, has_epsilon> {
  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar epsilon,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
                    bool update_slots) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    if (grad_size == 0) {
      return Status::OK();
    }
    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
    return GpuLaunchKernel(
        SparseApplyAdagradKernel<T, Tindex, has_epsilon>, config.block_count,
        config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
        lr.data(), epsilon.data(), grad.data(), indices.data(), first_dim_size,
        grad_size, indices_size, update_slots);
  }
};

template <typename T>
struct ApplyProximalAdagrad<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyProximalAdagradKernel<T>, d, var, accum, lr, l1, l2,
                     grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    // Fobos update per paper with Adagrad learning rate.
    accum.device(d) += grad.square();
    // Adagrad learning rate.
    // The following is the GPU equivalent of the CPU version:
    // auto learning_rate = accum.constant(lr()) * accum.rsqrt();
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto l1_bcast = l1.reshape(single).broadcast(bcast);
    auto l2_bcast = l2.reshape(single).broadcast(bcast);
    auto learning_rate = lr_bcast * accum.rsqrt();
    auto prox_var = var;
    // compute v = w - lr * grad.
    prox_var.device(d) -= grad * learning_rate;
    // compute sign(v) * max(|v| - lr * max(l1, 0), 0)
    var.device(d) = prox_var.sign() *
                    (prox_var.abs() - learning_rate * l1_bcast.cwiseMax(T(0.f)))
                        .cwiseMax(T(0.f)) /
                    (var.constant(T(1.f)) + l2_bcast * learning_rate);
#endif
  }
};

template <typename T, typename Tindex>
struct SparseApplyProximalAdagrad<GPUDevice, T, Tindex> {
  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar l1,
                    typename TTypes<T>::ConstScalar l2,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices,
                    int64 inner_dim) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    if (grad_size == 0) {
      return Status::OK();
    }
    GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
    return GpuLaunchKernel(SparseApplyProximalAdagradKernel<T, Tindex>,
                           config.block_count, config.thread_per_block, 0,
                           d.stream(), var.data(), accum.data(), lr.data(),
                           l1.data(), l2.data(), grad.data(), indices.data(),
                           first_dim_size, grad_size, indices_size);
  }
};

template <typename T>
struct ApplyAdadelta<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat accum_update,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyAdadeltaKernel<T>, d, var, accum, accum_update, lr,
                     rho, epsilon, grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    accum.device(d) = accum * rho.reshape(single).broadcast(bcast) +
                      grad.square() * (grad.constant(T(1)) -
                                       rho.reshape(single).broadcast(bcast));
    const auto update =
        (accum_update + epsilon.reshape(single).broadcast(bcast)).sqrt() *
        (accum + epsilon.reshape(single).broadcast(bcast)).rsqrt() * grad;
    var.device(d) -= update * lr.reshape(single).broadcast(bcast);
    accum_update.device(d) =
        accum_update * rho.reshape(single).broadcast(bcast) +
        update.square() *
            (grad.constant(T(1)) - rho.reshape(single).broadcast(bcast));
#endif
  }
};

template <typename T>
struct ApplyFtrl<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar lr_power) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    auto l1_bcast = l1.reshape(single).broadcast(bcast);
    auto l2_bcast = l2.reshape(single).broadcast(bcast);
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
    const auto two = static_cast<T>(2.0);

    auto new_accum = accum + grad.square();
    auto accum_power = accum.binaryExpr(lr_power_bcast,
                                        Eigen::internal::scalar_pow_op<T, T>());
    auto new_accum_power = new_accum.binaryExpr(
        lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
    linear.device(d) += grad - (new_accum_power - accum_power) * var / lr_bcast;
    auto x = (l1_bcast * linear.sign() - linear);
    auto y = (new_accum_power / lr_bcast) + linear.constant(two) * l2_bcast;
    auto pre_shrink = x / y;
    var.device(d) = (linear.abs() > l1_bcast)
                        .select(pre_shrink, var.constant(static_cast<T>(0)));
    accum.device(d) += grad.square();
  }
};

template <typename T>
struct ApplyFtrlMultiplyLinearByLr<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar lr_power) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto l1_lr_bcast = (l1 * lr).reshape(single).broadcast(bcast);
    auto l2_lr_bcast = (l2 * lr).reshape(single).broadcast(bcast);
    auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
    const auto two = static_cast<T>(2.0);

    auto new_accum = accum + grad.square();
    auto accum_power = accum.binaryExpr(lr_power_bcast,
                                        Eigen::internal::scalar_pow_op<T, T>());
    auto new_accum_power = new_accum.binaryExpr(
        lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
    linear.device(d) += grad * lr_bcast - (new_accum_power - accum_power) * var;
    auto x = (l1_lr_bcast * linear.sign() - linear);
    auto y = new_accum_power + linear.constant(two) * l2_lr_bcast;
    auto pre_shrink = x / y;
    var.device(d) = (linear.abs() > l1_lr_bcast)
                        .select(pre_shrink, var.constant(static_cast<T>(0)));
    accum.device(d) += grad.square();
  }
};

template <typename T>
struct ApplyFtrlV2<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar l2_shrinkage,
                  typename TTypes<T>::ConstScalar lr_power) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    auto l1_bcast = l1.reshape(single).broadcast(bcast);
    auto l2_bcast = l2.reshape(single).broadcast(bcast);
    auto l2_shrinkage_bcast = l2_shrinkage.reshape(single).broadcast(bcast);
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
    const auto two = static_cast<T>(2.0);

    auto new_accum = accum + grad.square();
    auto accum_power = accum.binaryExpr(lr_power_bcast,
                                        Eigen::internal::scalar_pow_op<T, T>());
    auto new_accum_power = new_accum.binaryExpr(
        lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
    auto grad_with_shrinkage =
        grad + (var.constant(two) * l2_shrinkage_bcast * var);
    linear.device(d) +=
        grad_with_shrinkage - (new_accum_power - accum_power) * var / lr_bcast;
    auto x = (l1_bcast * linear.sign() - linear);
    auto y = (new_accum_power / lr_bcast) + linear.constant(two) * l2_bcast;
    auto pre_shrink = x / y;
    var.device(d) = (linear.abs() > l1_bcast)
                        .select(pre_shrink, var.constant(static_cast<T>(0)));
    accum.device(d) += grad.square();
  }
};

template <typename T>
struct ApplyFtrlV2MultiplyLinearByLr<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar l2_shrinkage,
                  typename TTypes<T>::ConstScalar lr_power) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    auto l2_shrinkage_bcast = l2_shrinkage.reshape(single).broadcast(bcast);
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto l1_lr_bcast = (l1 * lr).reshape(single).broadcast(bcast);
    auto l2_lr_bcast = (l2 * lr).reshape(single).broadcast(bcast);
    auto lr_power_bcast = -lr_power.reshape(single).broadcast(bcast);
    const auto two = static_cast<T>(2.0);

    auto new_accum = accum + grad.square();
    auto accum_power = accum.binaryExpr(lr_power_bcast,
                                        Eigen::internal::scalar_pow_op<T, T>());
    auto new_accum_power = new_accum.binaryExpr(
        lr_power_bcast, Eigen::internal::scalar_pow_op<T, T>());
    auto grad_with_shrinkage =
        grad + (var.constant(two) * l2_shrinkage_bcast * var);
    linear.device(d) +=
        grad_with_shrinkage * lr_bcast - (new_accum_power - accum_power) * var;
    auto x = (l1_lr_bcast * linear.sign() - linear);
    auto y = new_accum_power + linear.constant(two) * l2_lr_bcast;
    auto pre_shrink = x / y;
    var.device(d) = (linear.abs() > l1_lr_bcast)
                        .select(pre_shrink, var.constant(static_cast<T>(0)));
    accum.device(d) += grad.square();
  }
};

template <typename T, typename Tindex, bool has_l2_shrinkage>
struct SparseApplyFtrl<GPUDevice, T, Tindex, has_l2_shrinkage> {
  Status operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::Matrix linear,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar l1,
                    typename TTypes<T>::ConstScalar l2,
                    typename TTypes<T>::ConstScalar l2_shrinkage,
                    typename TTypes<T>::ConstScalar lr_power,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices, int64 inner_dim,
                    bool multiply_linear_by_lr) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    if (grad_size == 0) {
      return Status::OK();
    }
    // The simpler overload of GetGpuLaunchConfig() would result in a "too many
    // resources requested for launch" error.
    auto* device_func = SparseApplyFtrlKernel<T, Tindex, has_l2_shrinkage>;
    GpuLaunchConfig config =
        GetGpuLaunchConfig(grad_size, d, device_func, 0, 0);
    return GpuLaunchKernel(
        device_func, config.block_count, config.thread_per_block, 0,
        d.stream(),
        /*var=*/var.data(),
        /*accum=*/accum.data(),
        /*linear=*/linear.data(), /*lr=*/lr.data(), /*l1=*/l1.data(),
        /*l2=*/l2.data(), /*l2_shrinkage=*/l2_shrinkage.data(),
        /*lr_power=*/lr_power.data(), /*grad=*/grad.data(),
        /*indices=*/indices.data(), /*param_rows=*/first_dim_size,
        /*updates_size=*/grad_size,
        /*indices_size=*/indices_size,
        /*multiply_linear_by_lr=*/multiply_linear_by_lr);
  }
};

template <typename T>
struct ApplyMomentum<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
    if (use_nesterov) {
      var.device(d) -= grad * lr.reshape(single).broadcast(bcast) +
                       accum * momentum.reshape(single).broadcast(bcast) *
                           lr.reshape(single).broadcast(bcast);
    } else {
      var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
    }
  }
};

template <typename T>
struct ApplyKerasMomentum<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    accum.device(d) = (accum * momentum.reshape(single).broadcast(bcast) -
                       grad * lr.reshape(single).broadcast(bcast));
    if (use_nesterov) {
      var.device(d) += (accum * momentum.reshape(single).broadcast(bcast) -
                        grad * lr.reshape(single).broadcast(bcast));
    } else {
      var.device(d) += accum;
    }
  }
};

template <typename T, typename Tindex>
struct SparseApplyKerasMomentum<GPUDevice, T, Tindex> {
  Tindex operator()(const GPUDevice& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices,
                    typename TTypes<T>::ConstScalar momentum,
                    bool use_nesterov) {
    const Tindex first_dim_size = var.dimension(0);
    const Tindex grad_size = grad.size();
    const Tindex indices_size = indices.size();
    if (grad_size != 0) {
      GpuLaunchConfig config = GetGpuLaunchConfig(grad_size, d);
      TF_CHECK_OK(GpuLaunchKernel(
          SparseApplyKerasMomentumKernel<T, Tindex>, config.block_count,
          config.thread_per_block, 0, d.stream(), var.data(), accum.data(),
          lr.data(), grad.data(), indices.data(), momentum.data(),
          use_nesterov, first_dim_size, grad_size, indices_size));
    }
    return static_cast<Tindex>(-1);
  }
};

template <typename T>
struct ApplyAdam<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool use_nesterov) {
    int32 data_dim = grad.dimension(0);
    GpuLaunchConfig config = GetGpuLaunchConfig(data_dim, d);
    eigen_assert(static_cast<int64>(grad.dimension(0)) +
                     static_cast<int64>(config.block_count) *
                         static_cast<int64>(config.thread_per_block) <
                 std::numeric_limits<int32>::max());

    TF_CHECK_OK(GpuLaunchKernel(
        ApplyAdamKernel<T>, config.block_count, config.thread_per_block, 0,
        d.stream(), data_dim, var.data(), m.data(), v.data(),
        beta1_power.data(), beta2_power.data(), lr.data(), beta1.data(),
        beta2.data(), epsilon.data(), grad.data(), use_nesterov));
  }
};

template <typename T>
struct ApplyAdamWithAmsgrad<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::Flat vhat,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    m.device(d) =
        m + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
                (grad - m);
    v.device(d) =
        v + (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) *
                (grad.square() - v);
    vhat.device(d) = vhat.cwiseMax(v);

    var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() /
                      (beta1_power.constant(one) - beta1_power))
                         .reshape(single)
                         .broadcast(bcast) *
                     m /
                     (epsilon.reshape(single).broadcast(bcast) + vhat.sqrt());
  }
};

template <typename T>
struct ApplyAdaMax<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    m.device(d) +=
        (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
        (grad - m);
    v.device(d) =
        (beta2.reshape(single).broadcast(bcast) * v).cwiseMax(grad.abs());
    var.device(d) -= lr.reshape(single).broadcast(bcast) /
                     (beta1_power.constant(one) - beta1_power)
                         .reshape(single)
                         .broadcast(bcast) *
                     (m / (v + epsilon.reshape(single).broadcast(bcast)));
  }
};

template <typename T>
struct ApplyRMSProp<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyRMSPropKernel<T>, d, var, ms, mom, lr, rho, momentum,
                     epsilon, grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    ms.device(d) =
        ms + (rho.constant(one) - rho).reshape(single).broadcast(bcast) *
                 (grad.square() - ms);
    mom.device(d) =
        mom * momentum.reshape(single).broadcast(bcast) +
        lr.reshape(single).broadcast(bcast) * grad /
            ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
    var.device(d) -= mom;
#endif
  }
};

template <typename T>
struct ApplyCenteredRMSProp<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
                  typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
#if TENSORFLOW_USE_ROCM
    wrap_kernel_call(ApplyCenteredRMSPropKernel<T>, d, var, mg, ms, mom, lr,
                     rho, momentum, epsilon, grad);
#else
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    const auto one_minus_rho =
        (rho.constant(one) - rho).reshape(single).broadcast(bcast);
    ms.device(d) = ms + one_minus_rho * (grad.square() - ms);
    mg.device(d) = mg + one_minus_rho * (grad - mg);
    auto denom = (ms - mg.square()) + epsilon.reshape(single).broadcast(bcast);
    mom.device(d) = mom * momentum.reshape(single).broadcast(bcast) +
                    lr.reshape(single).broadcast(bcast) * grad / denom.sqrt();
    var.device(d) -= mom;
#endif
  }
};

template <typename T>
struct ApplyAddSign<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar alpha,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    // The following is the GPU equivalent of the CPU version:
    // m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    const auto one = static_cast<T>(1.0);
    auto beta_bcast = beta.reshape(single).broadcast(bcast);
    auto one_minus_beta =
        (beta.constant(one) - beta).reshape(single).broadcast(bcast);
    m.device(d) = m * beta_bcast + grad * one_minus_beta;

    // The following is the GPU equivalent of the CPU version:
    // var.device(d) -= lr() * (alpha() + sign_decay() * sign_gm) * grad;
    auto sign_gm = grad.sign() * m.sign();
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto alpha_bcast = alpha.reshape(single).broadcast(bcast);
    auto sign_decay_bcast = sign_decay.reshape(single).broadcast(bcast);
    var.device(d) -=
        lr_bcast * (alpha_bcast + sign_decay_bcast * sign_gm) * grad;
  }
};

template <typename T>
struct ApplyPowerSign<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar logbase,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    // The following is the GPU equivalent of the CPU version:
    // m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    const auto one = static_cast<T>(1.0);
    auto beta_bcast = beta.reshape(single).broadcast(bcast);
    auto one_minus_beta =
        (beta.constant(one) - beta).reshape(single).broadcast(bcast);
    m.device(d) = m * beta_bcast + grad * one_minus_beta;

    // The following is the GPU equivalent of the CPU version:
    // auto grad_scale = (logbase() * sign_decay() * sign_gm).exp();
    // var.device(d) -= lr() * grad_scale * grad;
    auto sign_gm = grad.sign() * m.sign();
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto logbase_bcast = logbase.reshape(single).broadcast(bcast);
    auto sign_decay_bcast = sign_decay.reshape(single).broadcast(bcast);
    auto grad_scale = (logbase_bcast * sign_decay_bcast * sign_gm).exp();
    var.device(d) -= lr_bcast * grad_scale * grad;
  }
};

}  // namespace functor

template struct functor::ApplyGradientDescent<GPUDevice, Eigen::half>;
template struct functor::ApplyGradientDescent<GPUDevice, float>;
template struct functor::ApplyGradientDescent<GPUDevice, double>;
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyGradientDescent<GPUDevice, complex64>;
template struct functor::ApplyGradientDescent<GPUDevice, complex128>;
#endif

template struct functor::ApplyAdagrad<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagrad<GPUDevice, float>;
template struct functor::ApplyAdagrad<GPUDevice, double>;
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyAdagrad<GPUDevice, complex64>;
template struct functor::ApplyAdagrad<GPUDevice, complex128>;
#endif

template struct functor::ApplyAdagradV2<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagradV2<GPUDevice, float>;
template struct functor::ApplyAdagradV2<GPUDevice, double>;
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyAdagradV2<GPUDevice, complex64>;
template struct functor::ApplyAdagradV2<GPUDevice, complex128>;
#endif

#define EXPLICITLY_INSTANTIATE_FUNCTOR(T)                              \
  template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
                                              /*has_epsilon=*/false>; \
  template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
                                              /*has_epsilon=*/false>; \
  template struct functor::SparseApplyAdagrad<GPUDevice, T, int32,    \
                                              /*has_epsilon=*/true>;  \
  template struct functor::SparseApplyAdagrad<GPUDevice, T, int64,    \
                                              /*has_epsilon=*/true>
EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
EXPLICITLY_INSTANTIATE_FUNCTOR(float);
EXPLICITLY_INSTANTIATE_FUNCTOR(double);
#undef EXPLICITLY_INSTANTIATE_FUNCTOR

template struct functor::ApplyProximalAdagrad<GPUDevice, Eigen::half>;
template struct functor::ApplyProximalAdagrad<GPUDevice, float>;
template struct functor::ApplyProximalAdagrad<GPUDevice, double>;

template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
                                                    int32>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, Eigen::half,
                                                    int64>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int32>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, float, int64>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int32>;
template struct functor::SparseApplyProximalAdagrad<GPUDevice, double, int64>;

template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
template struct functor::ApplyAdadelta<GPUDevice, float>;
template struct functor::ApplyAdadelta<GPUDevice, double>;
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyAdadelta<GPUDevice, complex64>;
template struct functor::ApplyAdadelta<GPUDevice, complex128>;
#endif

template struct functor::ApplyFtrl<GPUDevice, Eigen::half>;
template struct functor::ApplyFtrl<GPUDevice, float>;
template struct functor::ApplyFtrl<GPUDevice, double>;

template struct functor::ApplyFtrlMultiplyLinearByLr<GPUDevice, Eigen::half>;
template struct functor::ApplyFtrlMultiplyLinearByLr<GPUDevice, float>;
template struct functor::ApplyFtrlMultiplyLinearByLr<GPUDevice, double>;

template struct functor::ApplyFtrlV2<GPUDevice, Eigen::half>;
template struct functor::ApplyFtrlV2<GPUDevice, float>;
template struct functor::ApplyFtrlV2<GPUDevice, double>;

template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, Eigen::half>;
template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, float>;
template struct functor::ApplyFtrlV2MultiplyLinearByLr<GPUDevice, double>;

#define EXPLICITLY_INSTANTIATE_FUNCTOR(T)                                  \
  template struct functor::SparseApplyFtrl<GPUDevice, T, int32,           \
                                           /*has_l2_shrinkage=*/false>;   \
  template struct functor::SparseApplyFtrl<GPUDevice, T, int64,           \
                                           /*has_l2_shrinkage=*/false>;   \
  template struct functor::SparseApplyFtrl<GPUDevice, T, int32,           \
                                           /*has_l2_shrinkage=*/true>;    \
  template struct functor::SparseApplyFtrl<GPUDevice, T, int64,           \
                                           /*has_l2_shrinkage=*/true>
EXPLICITLY_INSTANTIATE_FUNCTOR(Eigen::half);
EXPLICITLY_INSTANTIATE_FUNCTOR(float);
EXPLICITLY_INSTANTIATE_FUNCTOR(double);
#undef EXPLICITLY_INSTANTIATE_FUNCTOR

template struct functor::ApplyMomentum<GPUDevice, Eigen::half>;
template struct functor::ApplyMomentum<GPUDevice, float>;
template struct functor::ApplyMomentum<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
template struct functor::ApplyMomentum<GPUDevice, complex64>;
template struct functor::ApplyMomentum<GPUDevice, complex128>;
#endif

template struct functor::ApplyKerasMomentum<GPUDevice, Eigen::half>;
template struct functor::ApplyKerasMomentum<GPUDevice, float>;
template struct functor::ApplyKerasMomentum<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
template struct functor::ApplyKerasMomentum<GPUDevice, complex64>;
template struct functor::ApplyKerasMomentum<GPUDevice, complex128>;
#endif

template struct functor::SparseApplyKerasMomentum<GPUDevice, Eigen::half,
                                                  int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, Eigen::half,
                                                  int64>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, float, int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, float, int64>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, double, int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, double, int64>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
template struct functor::SparseApplyKerasMomentum<GPUDevice, complex64, int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, complex64, int64>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, complex128, int32>;
template struct functor::SparseApplyKerasMomentum<GPUDevice, complex128, int64>;
#endif

template struct functor::ApplyAdam<GPUDevice, Eigen::half>;
template struct functor::ApplyAdam<GPUDevice, float>;
template struct functor::ApplyAdam<GPUDevice, double>;
#if !defined(TENSORFLOW_USE_NVCC) && \
    !defined(TENSORFLOW_USE_ROCM)  // TODO(b/143684500): Eigen to support
                                   // complex sqrt
template struct functor::ApplyAdam<GPUDevice, complex64>;
template struct functor::ApplyAdam<GPUDevice, complex128>;
#endif

template struct functor::ApplyAdamWithAmsgrad<GPUDevice, Eigen::half>;
template struct functor::ApplyAdamWithAmsgrad<GPUDevice, float>;
template struct functor::ApplyAdamWithAmsgrad<GPUDevice, double>;

template struct functor::ApplyAdaMax<GPUDevice, Eigen::half>;
template struct functor::ApplyAdaMax<GPUDevice, float>;
template struct functor::ApplyAdaMax<GPUDevice, double>;

template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyRMSProp<GPUDevice, float>;
template struct functor::ApplyRMSProp<GPUDevice, double>;
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyRMSProp<GPUDevice, complex64>;
template struct functor::ApplyRMSProp<GPUDevice, complex128>;
#endif

template struct functor::ApplyCenteredRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, float>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, double>;
#ifndef TENSORFLOW_USE_NVCC  // TODO(b/143684500): Eigen to support
                             // complex sqrt
template struct functor::ApplyCenteredRMSProp<GPUDevice, complex64>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, complex128>;
#endif

template struct functor::ApplyAddSign<GPUDevice, Eigen::half>;
template struct functor::ApplyAddSign<GPUDevice, float>;
template struct functor::ApplyAddSign<GPUDevice, double>;

template struct functor::ApplyPowerSign<GPUDevice, Eigen::half>;
template struct functor::ApplyPowerSign<GPUDevice, float>;
template struct functor::ApplyPowerSign<GPUDevice, double>;

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM