/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#include "tensorflow/core/kernels/conv_grad_input_ops.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/protobuf/autotuning.pb.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// To be used inside depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;

// GPU definitions.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// The slow version (but compiles for GPU)

// A dummy type to group forward backward data autotune results together.
struct ConvBackwardDataAutoTuneGroup {
  static string name() { return "ConvBwdData"; }
};
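// Caches the best algorithm per ConvParameters key, so that when autotuning is
// enabled the algorithm search in the launcher below runs at most once per
// distinct convolution configuration on a given device.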
typedef AutoTuneSingleton<ConvBackwardDataAutoTuneGroup, ConvParameters,
                          se::dnn::AlgorithmConfig>
    AutoTuneConvBwdData;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on GPU
// for int32 inputs.
template <>
struct LaunchConv2DBackpropInputOp<GPUDevice, int32> {
  void operator()(OpKernelContext* ctx, bool use_cudnn,
                  bool cudnn_use_autotune, const Tensor& out_backprop,
                  const Tensor& filter, int row_dilation, int col_dilation,
                  int row_stride, int col_stride, const Padding& padding,
                  const std::vector<int64>& explicit_paddings,
                  Tensor* in_backprop, TensorFormat data_format) {
    LaunchConv2DBackpropInputOpImpl<GPUDevice, int32> launcher;
    launcher(ctx, use_cudnn, cudnn_use_autotune, out_backprop, filter,
             row_dilation, col_dilation, row_stride, col_stride, padding,
             explicit_paddings, in_backprop, data_format);
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename T>
void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& out_backprop, const Tensor& filter, int row_dilation,
    int col_dilation, int row_stride, int col_stride, const Padding& padding,
    const std::vector<int64>& explicit_paddings, Tensor* in_backprop,
    TensorFormat data_format) {
  using se::dnn::AlgorithmConfig;
  using se::dnn::AlgorithmDesc;
  using se::dnn::ProfileResult;

  std::vector<int32> strides(4, 1);
  std::vector<int32> dilations(4, 1);
  auto input_h = GetTensorDimIndex(data_format, 'H');
  auto input_w = GetTensorDimIndex(data_format, 'W');
  strides[input_h] = row_stride;
  strides[input_w] = col_stride;
  dilations[input_h] = row_dilation;
  dilations[input_w] = col_dilation;
  TensorShape input_shape = in_backprop->shape();

  const TensorShape& filter_shape = filter.shape();
  ConvBackpropDimensions dims;
  OP_REQUIRES_OK(
      ctx, ConvBackpropComputeDimensionsV2(
               "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2, input_shape,
               filter_shape, out_backprop.shape(), dilations, strides, padding,
               explicit_paddings, data_format, &dims));

  int64 padding_top = -1, padding_bottom = -1;
  int64 padding_left = -1, padding_right = -1;
  if (padding == EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &padding_top,
                             &padding_bottom);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &padding_left,
                             &padding_right);
  }
  int64 expected_out_rows, expected_out_cols;
  // The call is guaranteed to succeed because we already validated the output
  // shape and padding above.
  TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
      dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
      row_dilation, row_stride, padding, &expected_out_rows, &padding_top,
      &padding_bottom));
  DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows);
  TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
      dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
      col_dilation, col_stride, padding, &expected_out_cols, &padding_left,
      &padding_right));
  DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);

  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  if (!use_cudnn) {
    ctx->SetStatus(errors::Unimplemented(
        "Conv2DBackpropInput for GPU is not currently supported "
        "without cudnn"));
    return;
  }

  // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than
  // the input depth, it's a depthwise convolution. More generally, if the
  // filter in-depth divides but is smaller than the input depth, it is a
  // grouped convolution.
  bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
  if (dims.spatial_dims[0].filter_size == 1 &&
      dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
      dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
      data_format == FORMAT_NHWC && (padding == VALID || padding == SAME)) {
    // 1x1 filter, so call cublas directly.
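    // With a 1x1 filter and unit strides in NHWC, backprop-to-input reduces to
    // a single matrix product: out_backprop viewed as an [m x k] matrix times
    // the transposed [n x k] filter matrix, with m = batch * rows * cols,
    // k = out_depth and n = in_depth. The GEMM call below expresses this
    // product in cuBLAS' column-major convention.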
    const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size;
    const uint64 k = dims.out_depth;
    const uint64 n = dims.in_depth;

    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                in_backprop->template flat<T>().size());

    auto transpose = se::blas::Transpose::kTranspose;
    auto no_transpose = se::blas::Transpose::kNoTranspose;

    bool blas_launch_status =
        stream
            ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                           a_ptr, k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  } else if (dims.spatial_dims[0].filter_size ==
                 dims.spatial_dims[0].input_size &&
             dims.spatial_dims[1].filter_size ==
                 dims.spatial_dims[1].input_size &&
             !is_grouped_convolution && padding == VALID &&
             data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, and we are not
    // using grouped convolution, so call cublas directly.
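    // Each output location is spatially 1x1 here, so out_backprop is a
    // [batch x out_depth] matrix and the HWIO filter collapses to a
    // [(rows * cols * in_depth) x out_depth] matrix; in_backprop is again
    // out_backprop times the transposed filter matrix.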
    const uint64 m = dims.batch_size;
    const uint64 k = dims.out_depth;
    const uint64 n = dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size * dims.in_depth;

    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                in_backprop->template flat<T>().size());

    auto transpose = se::blas::Transpose::kTranspose;
    auto no_transpose = se::blas::Transpose::kNoTranspose;

    bool blas_launch_status =
        stream
            ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                           a_ptr, k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  }

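  // The convolution descriptor below only takes a single (symmetric) padding
  // value per spatial dimension, so split the explicit padding into a common
  // symmetric part and a remainder; the remainder is handled by enlarging the
  // input here and cropping the result after the convolution.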
  const int64 common_padding_rows = std::min(padding_top, padding_bottom);
  const int64 common_padding_cols = std::min(padding_left, padding_right);
  TensorShape compatible_input_shape;
  if (padding_top != padding_bottom || padding_left != padding_right) {
    // Pad the input in the same way we did during the forward pass, so that
    // cuDNN or MIOpen receives the same input during the backward pass as it
    // did during the forward pass.
    const int64 padding_rows_diff = std::abs(padding_bottom - padding_top);
    const int64 padding_cols_diff = std::abs(padding_right - padding_left);
    const int64 new_in_rows =
        dims.spatial_dims[0].input_size + padding_rows_diff;
    const int64 new_in_cols =
        dims.spatial_dims[1].input_size + padding_cols_diff;
    compatible_input_shape = ShapeFromFormat(
        data_format, dims.batch_size, new_in_rows, new_in_cols, dims.in_depth);
  } else {
    compatible_input_shape = input_shape;
  }

  CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
      << "Negative row or col paddings: (" << common_padding_rows << ", "
      << common_padding_cols << ")";

  // Tensor Cores in NVIDIA Volta+ GPUs support efficient convolution with fp16
  // in the NHWC data layout. In all other configurations it is more efficient
  // to run the computation in the NCHW data format.
  const bool compute_in_nhwc =
      DataTypeToEnum<T>::value == DT_HALF && IsVoltaOrLater(*stream->parent());

  // We only do one directional conversion: NHWC->NCHW. We never convert in the
  // other direction. Grappler layout optimizer selects the preferred layout and
  // adds necessary annotations to the graph.
  const TensorFormat compute_data_format =
      (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
                                                      : FORMAT_NCHW;

  VLOG(3) << "Compute Conv2DBackpropInput with cuDNN:"
          << " data_format=" << ToString(data_format)
          << " compute_data_format=" << ToString(compute_data_format);

  constexpr auto kComputeInNHWC =
      std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
                      se::dnn::FilterLayout::kOutputYXInput);
  constexpr auto kComputeInNCHW =
      std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
                      se::dnn::FilterLayout::kOutputInputYX);

  se::dnn::DataLayout compute_data_layout;
  se::dnn::FilterLayout filter_layout;

  std::tie(compute_data_layout, filter_layout) =
      compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;

  se::dnn::BatchDescriptor input_desc;
  input_desc.set_count(dims.batch_size)
      .set_height(GetTensorDim(compatible_input_shape, data_format, 'H'))
      .set_width(GetTensorDim(compatible_input_shape, data_format, 'W'))
      .set_feature_map_count(dims.in_depth)
      .set_layout(compute_data_layout);
  se::dnn::BatchDescriptor output_desc;
  output_desc.set_count(dims.batch_size)
      .set_height(dims.spatial_dims[0].output_size)
      .set_width(dims.spatial_dims[1].output_size)
      .set_feature_map_count(dims.out_depth)
      .set_layout(compute_data_layout);
  se::dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
      .set_input_filter_width(dims.spatial_dims[1].filter_size)
      .set_input_feature_map_count(filter_shape.dim_size(2))
      .set_output_feature_map_count(filter_shape.dim_size(3))
      .set_layout(filter_layout);
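  // group_count = in_depth / filter in-depth: 1 for regular convolutions,
  // greater than 1 for grouped (including depthwise) convolutions.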
  se::dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
      .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
      .set_vertical_filter_stride(dims.spatial_dims[0].stride)
      .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
      .set_zero_padding_height(common_padding_rows)
      .set_zero_padding_width(common_padding_cols)
      .set_group_count(dims.in_depth / filter_shape.dim_size(2));

  // Tensorflow filter format: HWIO
  // cuDNN filter formats: (data format) -> (filter format)
  //   (1) NCHW -> OIHW
  //   (2) NHWC -> OHWI

  Tensor transformed_filter;
  const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status {
    VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
            << " to " << ToString(dst_format);

    TensorShape dst_shape =
        dst_format == FORMAT_OIHW
            ? TensorShape({filter.dim_size(3), filter.dim_size(2),
                           filter.dim_size(0), filter.dim_size(1)})
            : TensorShape({filter.dim_size(3), filter.dim_size(0),
                           filter.dim_size(1), filter.dim_size(2)});

    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
                                          &transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 4>()(
        ctx->eigen_device<GPUDevice>(), dst_format,
        To32Bit(filter.tensor<T, 4>()),
        To32Bit(transformed_filter.tensor<T, 4>()));

    return Status::OK();
  };

  if (compute_data_format == FORMAT_NCHW) {
    OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OIHW));
  } else if (compute_data_format == FORMAT_NHWC) {
    OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OHWI));
  } else {
    ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
                                           ToString(compute_data_format)));
    return;
  }

  Tensor transformed_out_backprop;
  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the `out_backprop` tensor from NHWC to NCHW.";
    TensorShape compute_shape = ShapeFromFormat(
        compute_data_format, dims.batch_size, dims.spatial_dims[0].output_size,
        dims.spatial_dims[1].output_size, dims.out_depth);
    if (dims.out_depth > 1) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, compute_shape,
                                        &transformed_out_backprop));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
          transformed_out_backprop.tensor<T, 4>());
    } else {
      // If depth <= 1, then just reshape.
      CHECK(transformed_out_backprop.CopyFrom(out_backprop, compute_shape));
    }
  } else {
    transformed_out_backprop = out_backprop;
  }

  Tensor pre_transformed_in_backprop;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(
               DataTypeToEnum<T>::value,
               ShapeFromFormat(
                   compute_data_format,
                   GetTensorDim(compatible_input_shape, data_format, 'N'),
                   GetTensorDim(compatible_input_shape, data_format, 'H'),
                   GetTensorDim(compatible_input_shape, data_format, 'W'),
                   GetTensorDim(compatible_input_shape, data_format, 'C')),
               &pre_transformed_in_backprop));

  auto out_backprop_ptr =
      AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                     transformed_out_backprop.template flat<T>().size());
  auto filter_ptr =
      AsDeviceMemory(transformed_filter.template flat<T>().data(),
                     transformed_filter.template flat<T>().size());
  auto in_backprop_ptr =
      AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                     pre_transformed_in_backprop.template flat<T>().size());

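  // The scratch limit below is resolved once per process (function-local
  // static); TF_CUDNN_WORKSPACE_LIMIT_IN_MB, if set, overrides the 4GB
  // default.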
  static int64 ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB by default
  );
  DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
  int device_id = stream->parent()->device_ordinal();
  DataType dtype = out_backprop.dtype();
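  // conv_parameters is the autotune cache key: if AutoTuneConvBwdData already
  // holds an entry for an identical configuration, the algorithm search below
  // is skipped.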
  ConvParameters conv_parameters = {
      dims.batch_size,                     // batch
      dims.in_depth,                       // in_depths
      {{input_desc.height(),               // in_rows
        input_desc.width()}},              // in_cols
      compute_data_format,                 // compute_data_format
      dims.out_depth,                      // out_depths
      {{dims.spatial_dims[0].filter_size,  // filter_rows
        dims.spatial_dims[1].filter_size,  // filter_cols
        filter_shape.dim_size(2)}},        // filter_depths
      {{dims.spatial_dims[0].dilation,     // dilation_rows
        dims.spatial_dims[1].dilation}},   // dilation_cols
      {{dims.spatial_dims[0].stride,       // stride_rows
        dims.spatial_dims[1].stride}},     // stride_cols
      {{common_padding_rows,               // padding_rows
        common_padding_cols}},             // padding_cols
      dtype,                               // tensor data type
      device_id,                           // device_id
      conv_desc.group_count()              // group_count
  };
#if TENSORFLOW_USE_ROCM
  // cudnn_use_autotune is applicable only to the CUDA flow. For ROCm/MIOpen,
  // we need to call GetMIOpenConvolveAlgorithms explicitly if we do not have
  // a cached algorithm_config for this conv_parameters.
  cudnn_use_autotune = true;
#endif
  AlgorithmConfig algorithm_config;
  if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
                                conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA

    se::TfAllocatorAdapter tf_allocator_adapter(ctx->device()->GetAllocator({}),
                                                stream);

    se::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
                                      se::GpuAsmOpts());

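    // Wrap the output buffer with redzones (guard regions) so that
    // out-of-bounds writes by a candidate algorithm are caught by the
    // CheckRedzones calls below and that algorithm is excluded from selection.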
    se::DeviceMemory<T> in_backprop_ptr_rz(
        WrapRedzoneBestEffort(&rz_allocator, in_backprop_ptr));

    std::vector<AlgorithmDesc> algorithms;
    CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
        conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
        &algorithms));
    std::vector<tensorflow::AutotuneResult> results;
    for (const auto& profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times for better
      // accuracy.
      DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                            ctx);
      se::RedzoneAllocator rz_scratch_allocator(
          stream, &tf_allocator_adapter, se::GpuAsmOpts(),
          /*memory_limit=*/ConvolveBackwardDataScratchSize);
      se::ScratchAllocator* allocator_used =
          !RedzoneCheckDisabled()
              ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
              : static_cast<se::ScratchAllocator*>(&scratch_allocator);
      ProfileResult profile_result;
      auto cudnn_launch_status = stream->ConvolveBackwardDataWithAlgorithm(
          filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
          input_desc, &in_backprop_ptr_rz, allocator_used,
          AlgorithmConfig(profile_algorithm), &profile_result);
      if (cudnn_launch_status.ok() && profile_result.is_valid()) {
        results.emplace_back();
        auto& result = results.back();
        result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
        result.mutable_conv()->set_tensor_ops_enabled(
            profile_algorithm.tensor_ops_enabled());
        result.set_scratch_bytes(
            !RedzoneCheckDisabled()
                ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones()
                : scratch_allocator.TotalByteSize());
        *result.mutable_run_time() = proto_utils::ToDurationProto(
            absl::Milliseconds(profile_result.elapsed_time_in_ms()));

        CheckRedzones(rz_scratch_allocator, &result);
        CheckRedzones(rz_allocator, &result);
      }
    }
#elif TENSORFLOW_USE_ROCM
    DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
    std::vector<ProfileResult> algorithms;
    OP_REQUIRES(
        ctx,
        stream->parent()->GetMIOpenConvolveAlgorithms(
            se::dnn::ConvolutionKind::BACKWARD_DATA,
            se::dnn::ToDataType<T>::value, stream, input_desc, in_backprop_ptr,
            filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
            &scratch_allocator, &algorithms),
        errors::Unknown(
            "Failed to get convolution algorithm. This is probably "
            "because MIOpen failed to initialize, so try looking to "
            "see if a warning log message was printed above."));

    std::vector<tensorflow::AutotuneResult> results;
    if (algorithms.size() == 1) {
      auto profile_result = algorithms[0];
      results.emplace_back();
      auto& result = results.back();
      result.mutable_conv()->set_algorithm(
          profile_result.algorithm().algo_id());
      result.mutable_conv()->set_tensor_ops_enabled(
          profile_result.algorithm().tensor_ops_enabled());

      result.set_scratch_bytes(profile_result.scratch_size());
      *result.mutable_run_time() = proto_utils::ToDurationProto(
          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
    } else {
      for (auto miopen_algorithm : algorithms) {
        auto profile_algorithm = miopen_algorithm.algorithm();
        ProfileResult profile_result;
        auto miopen_launch_status = stream->ConvolveBackwardDataWithAlgorithm(
            filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
            input_desc, &in_backprop_ptr, &scratch_allocator,
            AlgorithmConfig(profile_algorithm, miopen_algorithm.scratch_size()),
            &profile_result);

        if (miopen_launch_status.ok() && profile_result.is_valid()) {
          results.emplace_back();
          auto& result = results.back();
          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
          result.mutable_conv()->set_tensor_ops_enabled(
              profile_algorithm.tensor_ops_enabled());
          result.set_scratch_bytes(scratch_allocator.TotalByteSize());
          *result.mutable_run_time() = proto_utils::ToDurationProto(
              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
        }
      }
    }
#endif
    LogConvAutotuneResults(
        se::dnn::ConvolutionKind::BACKWARD_DATA, se::dnn::ToDataType<T>::value,
        in_backprop_ptr, filter_ptr, out_backprop_ptr, input_desc, filter_desc,
        output_desc, conv_desc, stream->parent(), results);
    OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
    AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
                                               algorithm_config);
  }
  auto cudnn_launch_status = stream->ConvolveBackwardDataWithAlgorithm(
      filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
      input_desc, &in_backprop_ptr, &scratch_allocator, algorithm_config,
      nullptr);

  if (!cudnn_launch_status.ok()) {
    ctx->SetStatus(cudnn_launch_status);
    return;
  }

  if (padding_top != padding_bottom || padding_left != padding_right) {
    Tensor in_backprop_remove_padding;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 ShapeFromFormat(compute_data_format,
                                 GetTensorDim(input_shape, data_format, 'N'),
                                 GetTensorDim(input_shape, data_format, 'H'),
                                 GetTensorDim(input_shape, data_format, 'W'),
                                 GetTensorDim(input_shape, data_format, 'C')),
                 &in_backprop_remove_padding));

    // Remove the padding that was added to the input shape above.
    const int64 input_pad_top = padding_top - common_padding_rows;
    const int64 input_pad_bottom = padding_bottom - common_padding_rows;
    const int64 input_pad_left = padding_left - common_padding_cols;
    const int64 input_pad_right = padding_right - common_padding_cols;
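    // Negative padding amounts make functor::PadInput crop instead of pad,
    // slicing off the rows/cols that were added above to keep the symmetric
    // padding compatible with the convolution descriptor.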
    functor::PadInput<GPUDevice, T, int, 4>()(
        ctx->template eigen_device<GPUDevice>(),
        To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                    .tensor<T, 4>()),
        {{static_cast<int>(-input_pad_top), static_cast<int>(-input_pad_left)}},
        {{static_cast<int>(-input_pad_bottom),
          static_cast<int>(-input_pad_right)}},
        To32Bit(in_backprop_remove_padding.tensor<T, 4>()), compute_data_format,
        T{});

    pre_transformed_in_backprop = in_backprop_remove_padding;
  }

  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the output tensor back from NCHW to NHWC.";
    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        ctx->eigen_device<GPUDevice>(),
        toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
        in_backprop->tensor<T, 4>());
  } else {
    *in_backprop = pre_transformed_in_backprop;
  }
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                             \
  template <>                                                           \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(               \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,         \
      typename TTypes<T, 4, int>::ConstTensor in,                       \
      typename TTypes<T, 4, int>::Tensor out);                          \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;         \
  template <>                                                           \
  void PadInput<GPUDevice, T, int, 4>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,   \
      const std::array<int, 2>& padding_left,                           \
      const std::array<int, 2>& padding_right,                          \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
      const T& padding_value);                                          \
  extern template struct PadInput<GPUDevice, T, int, 4>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC

template <>
void SpatialConvolutionBackwardInputFunc<GPUDevice, int32>::operator()(
    const GPUDevice&, typename TTypes<int32, 4>::Tensor,
    typename TTypes<int32, 4>::ConstTensor,
    typename TTypes<int32, 4>::ConstTensor, Eigen::DenseIndex,
    Eigen::DenseIndex, Eigen::DenseIndex, Eigen::DenseIndex);
extern template struct SpatialConvolutionBackwardInputFunc<GPUDevice, int32>;

template <>
void SpatialConvolutionBackwardInputWithExplicitPaddingFunc<
    GPUDevice, int32>::operator()(const GPUDevice&,
                                  typename TTypes<int32, 4>::Tensor,
                                  typename TTypes<int32, 4>::ConstTensor,
                                  typename TTypes<int32, 4>::ConstTensor,
                                  Eigen::DenseIndex, Eigen::DenseIndex,
                                  Eigen::DenseIndex, Eigen::DenseIndex,
                                  Eigen::DenseIndex, Eigen::DenseIndex,
                                  Eigen::DenseIndex, Eigen::DenseIndex);
extern template struct SpatialConvolutionBackwardInputWithExplicitPaddingFunc<
    GPUDevice, int32>;

}  // namespace functor

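// The input_sizes operand holds the shape of the original input (this op's
// output) and is read on the host, hence the HostMemory constraint.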
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<double>("T")
                            .HostMemory("input_sizes"),
                        Conv2DBackpropInputOp<GPUDevice, double>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("input_sizes"),
                        Conv2DBackpropInputOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .HostMemory("input_sizes"),
                        Conv2DBackpropInputOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input_sizes"),
                        Conv2DBackpropInputOp<GPUDevice, int32>);

// To be used inside depthwise_conv_grad_op.cc.
// TODO(reedwm): Move this and the definition to depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow