/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/bias_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/redux_functor.h"
#include "tensorflow/core/util/tensor_format.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

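// Fills in the batch/height/width/depth/channel sizes of `value_tensor`
// according to `data_format`; dimensions that are not present default to 1.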
void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
                      int32* batch, int32* height, int32* width, int32* depth,
                      int32* channel) {
  *batch = 1;
  *height = 1;
  *width = 1;
  *depth = 1;
  *channel = 1;
  if (data_format == FORMAT_NHWC) {
    int32 channel_dim = value_tensor.dims() - 1;
    *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
    for (int32 i = 0; i < channel_dim; i++) {
      *batch *= static_cast<int32>(value_tensor.dim_size(i));
    }
  } else if (data_format == FORMAT_NCHW) {
    *batch = static_cast<int32>(value_tensor.dim_size(0));
    *channel = static_cast<int32>(value_tensor.dim_size(1));
    *height = static_cast<int32>(value_tensor.dim_size(2));
    if (value_tensor.dims() > 3) {
      *width = static_cast<int32>(value_tensor.dim_size(3));
    }
    if (value_tensor.dims() > 4) {
      *depth = static_cast<int32>(value_tensor.dim_size(4));
    }
  }
}

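// By default, accumulate in the same type as the data itself.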
template <class T>
struct AccumulatorType {
  typedef T type;
};

// float is faster on the CPU than half, and also more precise,
// so use float for the temporary accumulators.
template <>
struct AccumulatorType<Eigen::half> {
  typedef float type;
};

}  // namespace

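// BiasAdd/BiasAddV1 kernel: adds a rank-1 bias along the channel dimension of
// the input. For example, with an NHWC input of shape [N, H, W, C] and a bias
// of shape [C], bias[c] is added to every element whose channel index is c.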
template <typename Device, typename T>
class BiasOp : public BinaryOp<T> {
 public:
  explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& bias = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                errors::InvalidArgument("Biases must be 1D: ",
                                        bias.shape().DebugString()));

    // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
    // used.
    size_t channel_dim;
    if (data_format_ == FORMAT_NCHW) {
      channel_dim = 1;  // NCHW always has the channel dim at index 1
                        // (for 3-, 4-, or 5-dimensional data).
    } else {
      channel_dim = input.shape().dims() - 1;  // End of code by intel_tf.
    }

    OP_REQUIRES(
        context,
        bias.shape().dim_size(0) == input.shape().dim_size(channel_dim),
        errors::InvalidArgument(
            "Must provide as many biases as the channel dimension "
            "of the input tensor: ",
            bias.shape().DebugString(), " vs. ", input.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
    if (input.NumElements() == 0) return;

    // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
    // used.
    if (data_format_ == FORMAT_NCHW) {
      int32 batch, height, width, depth, channel;
      GetBiasValueDims(input, data_format_, &batch, &height, &width, &depth,
                       &channel);
      switch (input.shape().dims()) {
        case 3: {
          Eigen::DSizes<int32, 3> three_dims(1, channel, 1);
          Eigen::DSizes<int32, 3> broad_cast_dims(batch, 1, height);
          const Device& d = context->eigen_device<Device>();
          output->tensor<T, 3>().device(d) =
              input.tensor<T, 3>() + bias.tensor<T, 1>()
                                         .reshape(three_dims)
                                         .broadcast(broad_cast_dims);
        } break;
        case 4: {
          Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1);
          Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width);
          const Device& d = context->eigen_device<Device>();
          output->tensor<T, 4>().device(d) =
              input.tensor<T, 4>() +
              bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims);
        } break;
        case 5: {
          Eigen::DSizes<int32, 5> five_dims(1, channel, 1, 1, 1);
          Eigen::DSizes<int32, 5> broad_cast_dims(batch, 1, height, width,
                                                  depth);
          const Device& d = context->eigen_device<Device>();
          output->tensor<T, 5>().device(d) =
              input.tensor<T, 5>() +
              bias.tensor<T, 1>().reshape(five_dims).broadcast(broad_cast_dims);
        } break;
        default:
          OP_REQUIRES(context, false,
                      errors::InvalidArgument("Only ranks up to 5 supported: ",
                                              input.shape().DebugString()));
      }
      return;
    }  // End of code by intel_tf.

    switch (input.shape().dims()) {
      case 2:
        Compute<2>(context, input, bias, output);
        break;
      case 3:
        Compute<3>(context, input, bias, output);
        break;
      case 4:
        Compute<4>(context, input, bias, output);
        break;
      case 5:
        Compute<5>(context, input, bias, output);
        break;
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument("Only ranks up to 5 supported: ",
                                            input.shape().DebugString()));
    }
  }

  // Add biases to an input tensor of rank Dims using the Bias functor.
  template <int Dims>
  void Compute(OpKernelContext* ctx, const Tensor& input, const Tensor& bias,
               Tensor* output) {
    functor::Bias<Device, T, Dims> functor;
    functor(ctx->eigen_device<Device>(), input.tensor<T, Dims>(), bias.vec<T>(),
            output->tensor<T, Dims>());
  }

 private:
  TensorFormat data_format_;
};

#define REGISTER_KERNEL(type)                                         \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
      BiasOp<CPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAddV1").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BiasOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

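// BiasAddGrad kernel: the gradient of BiasAdd is the incoming gradient summed
// over every dimension except the channel dimension.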
template <typename Device, typename T>
class BiasGradOp : public OpKernel {
 public:
  explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& output_backprop = context->input(0);

    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        output_backprop.shape().DebugString()));

    OP_REQUIRES(
        context,
        FastBoundsCheck(output_backprop.NumElements(),
                        std::numeric_limits<int32>::max()),
        errors::InvalidArgument("BiasGrad requires tensor size <= int32 max"));

    int32 batch, height, width, depth, channel;
    GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width,
                     &depth, &channel);
    Tensor* output = nullptr;
    TensorShape output_shape{channel};
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    if (channel == 0) {
      return;  // Nothing to do
    } else if (output_backprop.NumElements() == 0) {
      // Eigen often crashes by design on empty tensors, but setZero is safe
      output->template flat<T>().setZero();
    } else {
      // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
      // used.
      using AccumT = typename AccumulatorType<T>::type;
      if (data_format_ == FORMAT_NCHW) {
        const functor::ReduceMiddleDimensions<
            T, AccumT, T, Eigen::internal::scalar_sum_op<AccumT>,
            Eigen::internal::SumReducer<T>>
            redux;
        Eigen::DSizes<Eigen::Index, 3> three_dims(batch, channel,
                                                  height * width * depth);
        redux(context->eigen_device<Device>(), three_dims, output_backprop,
              output, 1);
      } else {
        const functor::ReduceOuterDimensions<
            T, AccumT, T, Eigen::internal::scalar_sum_op<AccumT>>
            redux;

        Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width * depth,
                                                channel);
        redux(context->eigen_device<Device>(), two_dims, output_backprop,
              output);
      }
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the CPU implementations.
#define REGISTER_KERNEL(type)                                           \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BiasGradOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
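// GPU specialization of BiasOp: dispatches to the BiasGPU kernel declared in
// bias_op_gpu.h.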
template <typename T>
class BiasOp<GPUDevice, T> : public BinaryOp<T> {
 public:
  typedef GPUDevice Device;
  explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& bias = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                errors::InvalidArgument("Biases must be 1D: ",
                                        bias.shape().DebugString()));
    int32 batch, height, width, depth, channel;
    GetBiasValueDims(input, data_format_, &batch, &height, &width, &depth,
                     &channel);
    OP_REQUIRES(context, bias.shape().dim_size(0) == channel,
                errors::InvalidArgument(
                    "Must provide as many biases as the channel dimension "
                    "of the input tensor: ",
                    bias.shape().DebugString(), " vs. ", channel, " in ",
                    input.shape().DebugString()));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
    if (input.NumElements() > 0) {
      BiasGPU<T>::compute(context->template eigen_device<Device>(),
                          input.flat<T>().data(), bias.flat<T>().data(),
                          output->flat<T>().data(), batch, width, height,
                          depth, channel, data_format_);
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(type)                                     \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
      BiasOp<GPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAddV1").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BiasOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(int32);
#undef REGISTER_GPU_KERNEL

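// Autotuning group name under which BiasAddGrad timing results are cached.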
struct BiasGradAutotuneGroup {
  static string name() { return "BiasGrad"; }
};

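// Records which BiasAddGrad algorithm (the native CUDA kernel or the cub-based
// reduction) the autotuner selected for a given set of parameters.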
class BiasAddGradGPUConfig {
 public:
  BiasAddGradGPUConfig() : mode_(BiasAddGradGPUMode::kReduction) {}
  string ToString() const {
    if (mode_ == BiasAddGradGPUMode::kNative) {
      return "native CUDA kernel.";
    }
    if (mode_ == BiasAddGradGPUMode::kReduction) {
      return "cub reduction kernel.";
    }
    return "unknown kernel.";
  }
  BiasAddGradGPUMode get_mode() const { return mode_; }
  void set_mode(BiasAddGradGPUMode val) { mode_ = val; }

  bool operator==(const BiasAddGradGPUConfig& other) const {
    return this->mode_ == other.get_mode();
  }

  bool operator!=(const BiasAddGradGPUConfig& other) const {
    return !(*this == other);
  }

 private:
  BiasAddGradGPUMode mode_;
};

// Encapsulate all the shape information that is used in bias add grad
// operations.
class BiasAddParams {
 public:
  // We use a list to maintain both the shape value and the order (data format).
  using SpatialArray = gtl::InlinedVector<int64, 4>;
  BiasAddParams(const SpatialArray& in_shape, TensorFormat data_format,
                DataType dtype, int device_id)
      : in_shape_(in_shape),
        data_format_(data_format),
        dtype_(dtype),
        device_id_(device_id) {
    for (int64 val : in_shape_) {
      hash_code_ = Hash64Combine(hash_code_, val);
    }
    hash_code_ = Hash64Combine(hash_code_, data_format);
    hash_code_ = Hash64Combine(hash_code_, dtype);
    hash_code_ = Hash64Combine(hash_code_, device_id);
  }
  bool operator==(const BiasAddParams& other) const {
    return this->get_data_as_tuple() == other.get_data_as_tuple();
  }

  bool operator!=(const BiasAddParams& other) const {
    return !(*this == other);
  }
  uint64 hash() const { return hash_code_; }

  string ToString() const {
    // clang-format off
    return strings::StrCat(
        "(", absl::StrJoin(in_shape_, ", "), "), ",
        data_format_, ", ", dtype_, ", ", device_id_);
    // clang-format on
  }

 protected:
  using ParamsDataType = std::tuple<SpatialArray, TensorFormat, DataType, int>;

  ParamsDataType get_data_as_tuple() const {
    return std::make_tuple(in_shape_, data_format_, dtype_, device_id_);
  }

  uint64 hash_code_ = 0;

 private:
  SpatialArray in_shape_;
  TensorFormat data_format_;
  DataType dtype_;
  int device_id_;
};

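// Process-wide cache mapping BiasAddParams to the best BiasAddGrad algorithm
// found so far on the corresponding device.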
typedef AutoTuneSingleton<BiasGradAutotuneGroup, BiasAddParams,
                          BiasAddGradGPUConfig>
    AutotuneBiasGrad;

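// GPU specialization of BiasGradOp: autotunes between a custom CUDA kernel and
// a generic reduction, caching the winner per shape/format/dtype/device.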
template <typename T>
class BiasGradOp<GPUDevice, T> : public OpKernel {
 public:
  typedef GPUDevice Device;
  explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NCHW;
    }
  }

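  // Reduces output_backprop over all non-channel dimensions with the
  // hand-written BiasGradGPU kernel.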
  void ComputeWithCustomKernel(OpKernelContext* context,
                               const Tensor& output_backprop, int32 batch,
                               int32 width, int32 height, int32 depth,
                               int32 channel, Tensor* output) {
    BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
                            output_backprop.template flat<T>().data(),
                            output->flat<T>().data(), batch, width, height,
                            depth, channel, data_format_);
  }

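  // Reduces output_backprop with generic row/column reductions: one pass for
  // NHWC, two passes (over HW, then over N) for NCHW.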
  void ComputeWithReduceSum(OpKernelContext* context,
                            const Tensor& output_backprop, int32 batch,
                            int32 width, int32 height, int32 depth,
                            int32 channel, Tensor* output) {
    if (data_format_ == FORMAT_NCHW) {
      int32 row_count = batch * channel;
      int32 col_count = height * width * depth;
      Tensor temp_grad_outputs;
      // For 'NCHW' format, we perform reduction twice: first HW, then N.
      TensorShape temp_grad_output_shape{row_count, col_count};
      OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
                                                     temp_grad_output_shape,
                                                     &temp_grad_outputs));
      BiasGradGPU<T>::DoRowReduction(
          context, temp_grad_outputs.flat<T>().data(),
          output_backprop.template flat<T>().data(), row_count, col_count);

      row_count = batch;
      col_count = channel;
      BiasGradGPU<T>::DoColReduction(context, output->flat<T>().data(),
                                     temp_grad_outputs.flat<T>().data(),
                                     row_count, col_count);
    } else {
      // For 'NHWC', we simply apply reduction once on NHW.
      int32 row_count = batch * height * width * depth;
      int32 col_count = channel;
      BiasGradGPU<T>::DoColReduction(
          context, const_cast<T*>(output->flat<T>().data()),
          reinterpret_cast<const T*>(output_backprop.template flat<T>().data()),
          row_count, col_count);
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& output_backprop = context->input(0);

    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        output_backprop.shape().DebugString()));
    int32 batch, height, width, depth, channel;
    GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width,
                     &depth, &channel);
    Tensor* output = nullptr;
    TensorShape output_shape{channel};
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (channel == 0) return;
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
    se::DeviceMemoryBase output_ptr(output->flat<T>().data(),
                                    output->NumElements() * sizeof(T));
    stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
    if (output_backprop.NumElements() <= 0) return;

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = output_backprop.dtype();
    BiasAddParams bias_parameters = {
        {batch, height * width * depth, channel},
        data_format_,
        dtype,
        device_id,
    };

    // Autotune between the two algorithms: the custom kernel and the
    // reduction-based one.
    BiasAddGradGPUConfig algo_config;
    if (!AutotuneBiasGrad::GetInstance()->Find(bias_parameters, &algo_config)) {
      BiasGradGPUProfileResult best_result;
      // Initialize the timer.
      perftools::gputools::Timer timer(stream->parent());
      stream->InitTimer(&timer);
      stream->ThenStartTimer(&timer);
      ComputeWithCustomKernel(context, output_backprop, batch, width, height,
                              depth, channel, output);
      stream->ThenStopTimer(&timer);
      uint64 elapsed_microseconds = timer.Microseconds();
      VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
              << " Native algo latency: " << elapsed_microseconds;
      if (elapsed_microseconds < best_result.elapsed_time()) {
        best_result.set_algorithm(BiasAddGradGPUMode::kNative);
        best_result.set_elapsed_time(elapsed_microseconds);
      }

      // Try reduction and profile.
      stream->ThenStartTimer(&timer);
      ComputeWithReduceSum(context, output_backprop, batch, width, height,
                           depth, channel, output);
      stream->ThenStopTimer(&timer);

      elapsed_microseconds = timer.Microseconds();
      VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
              << " Reduction algo latency: " << elapsed_microseconds;
      if (elapsed_microseconds < best_result.elapsed_time()) {
        best_result.set_algorithm(BiasAddGradGPUMode::kReduction);
        best_result.set_elapsed_time(elapsed_microseconds);
      }

      algo_config.set_mode(best_result.algorithm());
      AutotuneBiasGrad::GetInstance()->Insert(bias_parameters, algo_config);

      // Results are already available during autotune, so no need to continue.
      return;
    }

    // Choose the best algorithm based on autotune results.
    if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) {
      ComputeWithReduceSum(context, output_backprop, batch, width, height,
                           depth, channel, output);
    } else {
      // Default to the customized kernel.
      ComputeWithCustomKernel(context, output_backprop, batch, width, height,
                              depth, channel, output);
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BiasGradOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow