/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/bias_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/redux_functor.h"
#include "tensorflow/core/util/tensor_format.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

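// Extracts (batch, height, width, depth, channel) sizes from `value_tensor`
// according to `data_format`. For NHWC the channel is the last dimension and
// all leading dimensions are folded into `batch`; for NCHW the dimensions are
// read positionally, with `width` and `depth` defaulting to 1 for lower-rank
// inputs.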
void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
                      int32* batch, int32* height, int32* width, int32* depth,
                      int32* channel) {
  *batch = 1;
  *height = 1;
  *width = 1;
  *depth = 1;
  *channel = 1;
  if (data_format == FORMAT_NHWC) {
    int32 channel_dim = value_tensor.dims() - 1;
    *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
    for (int32 i = 0; i < channel_dim; i++) {
      *batch *= static_cast<int32>(value_tensor.dim_size(i));
    }
  } else if (data_format == FORMAT_NCHW) {
    *batch = static_cast<int32>(value_tensor.dim_size(0));
    *channel = static_cast<int32>(value_tensor.dim_size(1));
    *height = static_cast<int32>(value_tensor.dim_size(2));
    if (value_tensor.dims() > 3) {
      *width = static_cast<int32>(value_tensor.dim_size(3));
    }
    if (value_tensor.dims() > 4) {
      *depth = static_cast<int32>(value_tensor.dim_size(4));
    }
  }
}

template <class T>
struct AccumulatorType {
  typedef T type;
};

// float is faster on the CPU than half, and also more precise,
// so use float for the temporary accumulators.
template <>
struct AccumulatorType<Eigen::half> {
  typedef float type;
};

}  // namespace

template <typename Device, typename T>
class BiasOp : public BinaryOp<T> {
 public:
  explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& bias = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                errors::InvalidArgument("Biases must be 1D: ",
                                        bias.shape().DebugString()));

    // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
    // used.
    size_t channel_dim;
    if (data_format_ == FORMAT_NCHW) {
      channel_dim = 1;  // NCHW always has the channel dim at index 1 (for
                        // rank-3, -4, and -5 inputs).
    } else {
      channel_dim = input.shape().dims() - 1;  // End of code by intel_tf.
    }

    OP_REQUIRES(
        context,
        bias.shape().dim_size(0) == input.shape().dim_size(channel_dim),
        errors::InvalidArgument(
            "Must provide as many biases as the channel dimension "
            "of the input tensor: ",
            bias.shape().DebugString(), " vs. ", input.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
    if (input.NumElements() == 0) return;

    // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
    // used.
    if (data_format_ == FORMAT_NCHW) {
      int32 batch, height, width, depth, channel;
      GetBiasValueDims(input, data_format_, &batch, &height, &width, &depth,
                       &channel);
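      // The 1-D bias is reshaped to (1, C, 1, ...) and broadcast across the
      // batch and spatial dimensions, then added elementwise to the input.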
      switch (input.shape().dims()) {
        case 3: {
          Eigen::DSizes<int32, 3> three_dims(1, channel, 1);
          Eigen::DSizes<int32, 3> broad_cast_dims(batch, 1, height);
          const Device& d = context->eigen_device<Device>();
          output->tensor<T, 3>().device(d) =
              input.tensor<T, 3>() + bias.tensor<T, 1>()
                                         .reshape(three_dims)
                                         .broadcast(broad_cast_dims);
        } break;
        case 4: {
          Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1);
          Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width);
          const Device& d = context->eigen_device<Device>();
          output->tensor<T, 4>().device(d) =
              input.tensor<T, 4>() +
              bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims);
        } break;
        case 5: {
          Eigen::DSizes<int32, 5> five_dims(1, channel, 1, 1, 1);
          Eigen::DSizes<int32, 5> broad_cast_dims(batch, 1, height, width,
                                                  depth);
          const Device& d = context->eigen_device<Device>();
          output->tensor<T, 5>().device(d) =
              input.tensor<T, 5>() +
              bias.tensor<T, 1>().reshape(five_dims).broadcast(broad_cast_dims);
        } break;
        default:
          OP_REQUIRES(context, false,
                      errors::InvalidArgument("Only ranks up to 5 supported: ",
                                              input.shape().DebugString()));
      }
      return;
    }  // End of code by intel_tf.

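    // NHWC (and rank-2) inputs fall through to the Bias functor, which
    // broadcasts the bias over the innermost (channel) dimension.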
    switch (input.shape().dims()) {
      case 2:
        Compute<2>(context, input, bias, output);
        break;
      case 3:
        Compute<3>(context, input, bias, output);
        break;
      case 4:
        Compute<4>(context, input, bias, output);
        break;
      case 5:
        Compute<5>(context, input, bias, output);
        break;
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument("Only ranks up to 5 supported: ",
                                            input.shape().DebugString()));
    }
  }

  // Add biases to an input tensor of rank Dims using the Bias functor.
  template <int Dims>
  void Compute(OpKernelContext* ctx, const Tensor& input, const Tensor& bias,
               Tensor* output) {
    functor::Bias<Device, T, Dims> functor;
    functor(ctx->eigen_device<Device>(), input.tensor<T, Dims>(), bias.vec<T>(),
            output->tensor<T, Dims>());
  }

 private:
  TensorFormat data_format_;
};

#define REGISTER_KERNEL(type)                                         \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
      BiasOp<CPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAddV1").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BiasOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

template <typename Device, typename T>
class BiasGradOp : public OpKernel {
 public:
  explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& output_backprop = context->input(0);

    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        output_backprop.shape().DebugString()));

    OP_REQUIRES(
        context,
        FastBoundsCheck(output_backprop.NumElements(),
                        std::numeric_limits<int32>::max()),
        errors::InvalidArgument("BiasGrad requires tensor size <= int32 max"));

    int32 batch, height, width, depth, channel;
    GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width,
                     &depth, &channel);
    Tensor* output = nullptr;
    TensorShape output_shape{channel};
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    if (channel == 0) {
      return;  // Nothing to do
    } else if (output_backprop.NumElements() == 0) {
      // Eigen often crashes by design on empty tensors, but setZero is safe
      output->template flat<T>().setZero();
    } else {
      // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
      // used.
      using AccumT = typename AccumulatorType<T>::type;
      if (data_format_ == FORMAT_NCHW) {
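        // View the backprop as (batch, channel, spatial) and sum over the
        // outer and inner dimensions, keeping only the middle (channel)
        // dimension.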
        const functor::ReduceMiddleDimensions<
            T, AccumT, T, Eigen::internal::scalar_sum_op<AccumT>,
            Eigen::internal::SumReducer<T>>
            redux;
        Eigen::DSizes<Eigen::Index, 3> three_dims(batch, channel,
                                                  height * width * depth);
        redux(context->eigen_device<Device>(), three_dims, output_backprop,
              output, 1);
      } else {
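        // View the backprop as (batch * spatial, channel) and sum over the
        // outer dimension, producing one value per channel.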
        const functor::ReduceOuterDimensions<
            T, AccumT, T, Eigen::internal::scalar_sum_op<AccumT>>
            redux;

        Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width * depth,
                                                channel);
        redux(context->eigen_device<Device>(), two_dims, output_backprop,
              output);
      }
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the CPU implementations.
#define REGISTER_KERNEL(type)                                           \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BiasGradOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
class BiasOp<GPUDevice, T> : public BinaryOp<T> {
 public:
  typedef GPUDevice Device;
  explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& bias = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                errors::InvalidArgument("Biases must be 1D: ",
                                        bias.shape().DebugString()));
    int32 batch, height, width, depth, channel;
    GetBiasValueDims(input, data_format_, &batch, &height, &width, &depth,
                     &channel);
    OP_REQUIRES(context, bias.shape().dim_size(0) == channel,
                errors::InvalidArgument(
                    "Must provide as many biases as the channel dimension "
                    "of the input tensor: ",
                    bias.shape().DebugString(), " vs. ", channel, " in ",
                    input.shape().DebugString()));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
    if (input.NumElements() > 0) {
      BiasGPU<T>::compute(context->template eigen_device<Device>(),
                          input.flat<T>().data(), bias.flat<T>().data(),
                          output->flat<T>().data(), batch, width, height, depth,
                          channel, data_format_);
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(type)                                     \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
      BiasOp<GPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAddV1").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BiasOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(int32);
#undef REGISTER_GPU_KERNEL

struct BiasGradAutotuneGroup {
  static string name() { return "BiasGrad"; }
};

class BiasAddGradGPUConfig {
 public:
  BiasAddGradGPUConfig() : mode_(BiasAddGradGPUMode::kReduction) {}
  string ToString() const {
    if (mode_ == BiasAddGradGPUMode::kNative) {
      return "native CUDA kernel.";
    }
    if (mode_ == BiasAddGradGPUMode::kReduction) {
      return "cub reduction kernel.";
    }
    return "unknown kernel.";
  }
  BiasAddGradGPUMode get_mode() const { return mode_; }
  void set_mode(BiasAddGradGPUMode val) { mode_ = val; }

  bool operator==(const BiasAddGradGPUConfig& other) const {
    return this->mode_ == other.get_mode();
  }

  bool operator!=(const BiasAddGradGPUConfig& other) const {
    return !(*this == other);
  }

 private:
  BiasAddGradGPUMode mode_;
};

// Encapsulates all the shape information that is used in bias add grad
// operations.
class BiasAddParams {
 public:
  // We use an inlined vector to maintain both the shape values and their
  // order (data format).
  using SpatialArray = gtl::InlinedVector<int64, 4>;
  BiasAddParams(const SpatialArray& in_shape, TensorFormat data_format,
                DataType dtype, int device_id)
      : in_shape_(in_shape),
        data_format_(data_format),
        dtype_(dtype),
        device_id_(device_id) {
    for (int64 val : in_shape_) {
      hash_code_ = Hash64Combine(hash_code_, val);
    }
    hash_code_ = Hash64Combine(hash_code_, data_format);
    hash_code_ = Hash64Combine(hash_code_, dtype);
    hash_code_ = Hash64Combine(hash_code_, device_id);
  }
  bool operator==(const BiasAddParams& other) const {
    return this->get_data_as_tuple() == other.get_data_as_tuple();
  }

  bool operator!=(const BiasAddParams& other) const {
    return !(*this == other);
  }
  uint64 hash() const { return hash_code_; }

  string ToString() const {
    // clang-format off
    return strings::StrCat(
        "(", absl::StrJoin(in_shape_, ", "), "), ",
        data_format_, ", ", dtype_, ", ", device_id_);
    // clang-format on
  }

 protected:
  using ParamsDataType = std::tuple<SpatialArray, TensorFormat, DataType, int>;

  ParamsDataType get_data_as_tuple() const {
    return std::make_tuple(in_shape_, data_format_, dtype_, device_id_);
  }

  uint64 hash_code_ = 0;

 private:
  SpatialArray in_shape_;
  TensorFormat data_format_;
  DataType dtype_;
  int device_id_;
};

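// Process-wide cache that maps a BiasAddParams key to the best-performing
// BiasAddGradGPUConfig found by the autotuning below.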
typedef AutoTuneSingleton<BiasGradAutotuneGroup, BiasAddParams,
                          BiasAddGradGPUConfig>
    AutotuneBiasGrad;

template <typename T>
class BiasGradOp<GPUDevice, T> : public OpKernel {
 public:
  typedef GPUDevice Device;
  explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
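      // Note: unlike the other bias kernels in this file, this op falls back
      // to NCHW when no data_format attribute is provided.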
      data_format_ = FORMAT_NCHW;
    }
  }

  void ComputeWithCustomKernel(OpKernelContext* context,
                               const Tensor& output_backprop, int32 batch,
                               int32 width, int32 height, int32 depth,
                               int32 channel, Tensor* output) {
    BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
                            output_backprop.template flat<T>().data(),
                            output->flat<T>().data(), batch, width, height,
                            depth, channel, data_format_);
  }

  void ComputeWithReduceSum(OpKernelContext* context,
                            const Tensor& output_backprop, int32 batch,
                            int32 width, int32 height, int32 depth,
                            int32 channel, Tensor* output) {
    if (data_format_ == FORMAT_NCHW) {
      int32 row_count = batch * channel;
      int32 col_count = height * width * depth;
      Tensor temp_grad_outputs;
      // For 'NCHW' format, we perform reduction twice: first HW, then N.
      TensorShape temp_grad_output_shape{row_count, col_count};
      OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
                                                     temp_grad_output_shape,
                                                     &temp_grad_outputs));
      BiasGradGPU<T>::DoRowReduction(
          context, temp_grad_outputs.flat<T>().data(),
          output_backprop.template flat<T>().data(), row_count, col_count);

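      // Second pass: treat the per-row partial sums as a (batch, channel)
      // matrix and reduce over the batch rows to get one value per channel.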
      row_count = batch;
      col_count = channel;
      BiasGradGPU<T>::DoColReduction(context, output->flat<T>().data(),
                                     temp_grad_outputs.flat<T>().data(),
                                     row_count, col_count);
    } else {
      // For 'NHWC', we simply apply reduction once on NHW.
      int32 row_count = batch * height * width * depth;
      int32 col_count = channel;
      BiasGradGPU<T>::DoColReduction(
          context, const_cast<T*>(output->flat<T>().data()),
          reinterpret_cast<const T*>(output_backprop.template flat<T>().data()),
          row_count, col_count);
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& output_backprop = context->input(0);

    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        output_backprop.shape().DebugString()));
    int32 batch, height, width, depth, channel;
    GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width,
                     &depth, &channel);
    Tensor* output = nullptr;
    TensorShape output_shape{channel};
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (channel == 0) return;
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
    se::DeviceMemoryBase output_ptr(output->flat<T>().data(),
                                    output->NumElements() * sizeof(T));
    stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
    if (output_backprop.NumElements() <= 0) return;

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = output_backprop.dtype();
    BiasAddParams bias_parameters = {
        {batch, height * width * depth, channel},
        data_format_,
        dtype,
        device_id,
    };

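    // On the first run for a given shape/format/dtype we time both candidate
    // kernels, cache the faster one in AutotuneBiasGrad, and return the result
    // that was just computed.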
    // Autotune between the two algorithms: the custom CUDA kernel and the
    // cub reduction.
    BiasAddGradGPUConfig algo_config;
    if (!AutotuneBiasGrad::GetInstance()->Find(bias_parameters, &algo_config)) {
      BiasGradGPUProfileResult best_result;
      // Initialize the timer.
      perftools::gputools::Timer timer(stream->parent());
      stream->InitTimer(&timer);
      stream->ThenStartTimer(&timer);
      ComputeWithCustomKernel(context, output_backprop, batch, width, height,
                              depth, channel, output);
      stream->ThenStopTimer(&timer);
      uint64 elapsed_microseconds = timer.Microseconds();
      VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
              << " Native algo latency: " << elapsed_microseconds;
      if (elapsed_microseconds < best_result.elapsed_time()) {
        best_result.set_algorithm(BiasAddGradGPUMode::kNative);
        best_result.set_elapsed_time(elapsed_microseconds);
      }

      // Try reduction and profile.
      stream->ThenStartTimer(&timer);
      ComputeWithReduceSum(context, output_backprop, batch, width, height,
                           depth, channel, output);
      stream->ThenStopTimer(&timer);

      elapsed_microseconds = timer.Microseconds();
      VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
              << " Reduction algo latency: " << elapsed_microseconds;
      if (elapsed_microseconds < best_result.elapsed_time()) {
        best_result.set_algorithm(BiasAddGradGPUMode::kReduction);
        best_result.set_elapsed_time(elapsed_microseconds);
      }

      algo_config.set_mode(best_result.algorithm());
      AutotuneBiasGrad::GetInstance()->Insert(bias_parameters, algo_config);

      // Results are already available during autotune, so no need to continue.
      return;
    }

    // Choose the best algorithm based on autotune results.
    if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) {
      ComputeWithReduceSum(context, output_backprop, batch, width, height,
                           depth, channel, output);
    } else {
      // Default to the customized kernel.
      ComputeWithCustomKernel(context, output_backprop, batch, width, height,
                              depth, channel, output);
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BiasGradOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow