/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#include "tensorflow/core/kernels/conv_grad_input_ops.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/protobuf/autotuning.pb.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// To be used inside depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;

// GPU definitions.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// The slow version (but compiles for GPU)

// A dummy type to group forward backward data autotune results together.
struct ConvBackwardDataAutoTuneGroup {
  static string name() { return "ConvBwdData"; }
};
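// Singleton cache mapping each ConvParameters key to the AlgorithmConfig
// chosen by autotuning; the Find()/Insert() calls below ensure the algorithm
// search runs at most once per unique set of convolution parameters.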
typedef AutoTuneSingleton<ConvBackwardDataAutoTuneGroup, ConvParameters,
                          se::dnn::AlgorithmConfig>
    AutoTuneConvBwdData;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on GPU
// for int32 inputs.
template <>
struct LaunchConv2DBackpropInputOp<GPUDevice, int32> {
  void operator()(OpKernelContext* ctx, bool use_cudnn,
                  bool cudnn_use_autotune, const Tensor& out_backprop,
                  const Tensor& filter, int row_dilation, int col_dilation,
                  int row_stride, int col_stride, const Padding& padding,
                  const std::vector<int64>& explicit_paddings,
                  Tensor* in_backprop, TensorFormat data_format) {
    LaunchConv2DBackpropInputOpImpl<GPUDevice, int32> launcher;
    launcher(ctx, use_cudnn, cudnn_use_autotune, out_backprop, filter,
             row_dilation, col_dilation, row_stride, col_stride, padding,
             explicit_paddings, in_backprop, data_format);
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename T>
void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& out_backprop, const Tensor& filter, int row_dilation,
    int col_dilation, int row_stride, int col_stride, const Padding& padding,
    const std::vector<int64>& explicit_paddings, Tensor* in_backprop,
    TensorFormat data_format) {
  using se::dnn::AlgorithmConfig;
  using se::dnn::AlgorithmDesc;
  using se::dnn::ProfileResult;

  std::vector<int32> strides(4, 1);
  std::vector<int32> dilations(4, 1);
  auto input_h = GetTensorDimIndex(data_format, 'H');
  auto input_w = GetTensorDimIndex(data_format, 'W');
  strides[input_h] = row_stride;
  strides[input_w] = col_stride;
  dilations[input_h] = row_dilation;
  dilations[input_w] = col_dilation;
  TensorShape input_shape = in_backprop->shape();

  const TensorShape& filter_shape = filter.shape();
  ConvBackpropDimensions dims;
  OP_REQUIRES_OK(
      ctx, ConvBackpropComputeDimensionsV2(
               "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2, input_shape,
               filter_shape, out_backprop.shape(), dilations, strides, padding,
               explicit_paddings, data_format, &dims));

  int64 padding_top = -1, padding_bottom = -1;
  int64 padding_left = -1, padding_right = -1;
  if (padding == EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &padding_top,
                             &padding_bottom);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &padding_left,
                             &padding_right);
  }
  int64 expected_out_rows, expected_out_cols;
  // The function is guaranteed to succeed because we checked the output and
  // padding were valid earlier.
  TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
      dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
      row_dilation, row_stride, padding, &expected_out_rows, &padding_top,
      &padding_bottom));
  DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows);
  TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
      dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
      col_dilation, col_stride, padding, &expected_out_cols, &padding_left,
      &padding_right));
  DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);

  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  if (!use_cudnn) {
    ctx->SetStatus(errors::Unimplemented(
        "Conv2DBackpropInput for GPU is not currently supported "
        "without cudnn"));
    return;
  }

  // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than
  // the input depth, it is a depthwise convolution. More generally, if the
  // filter in-depth divides but is smaller than the input depth, it is a
  // grouped convolution.
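  // For example (hypothetical shapes): in_depth = 8 with a filter in-depth of
  // 2 gives group_count = 8 / 2 = 4 groups, while a filter in-depth of 1 is
  // the depthwise case with 8 groups.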
  bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
  if (dims.spatial_dims[0].filter_size == 1 &&
      dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
      dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
      data_format == FORMAT_NHWC && (padding == VALID || padding == SAME)) {
    // 1x1 filter, so call cublas directly.
    const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size;
    const uint64 k = dims.out_depth;
    const uint64 n = dims.in_depth;
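    // Viewed as matrices, out_backprop is [m, k] and the 1x1 HWIO filter is
    // [n, k], so the input gradient is the single GEMM
    //   in_backprop[m, n] = out_backprop[m, k] * filter[n, k]^T,
    // which the cuBLAS call below performs on the column-major views.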

    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                in_backprop->template flat<T>().size());

    auto transpose = se::blas::Transpose::kTranspose;
    auto no_transpose = se::blas::Transpose::kNoTranspose;

    bool blas_launch_status =
        stream
            ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                           a_ptr, k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  } else if (dims.spatial_dims[0].filter_size ==
                 dims.spatial_dims[0].input_size &&
             dims.spatial_dims[1].filter_size ==
                 dims.spatial_dims[1].input_size &&
             !is_grouped_convolution && padding == VALID &&
             data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, and we are not
    // using grouped convolution, so call cublas directly.
    const uint64 m = dims.batch_size;
    const uint64 k = dims.out_depth;
    const uint64 n = dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size * dims.in_depth;
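    // With a full-input filter and VALID padding, each image reduces to a
    // single 1x1xout_depth output, so the backprop is again a GEMM:
    //   in_backprop[m, n] = out_backprop[m, k] * filter[n, k]^T,
    // where the HWIO filter is flattened to [H * W * in_depth, out_depth].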

    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                in_backprop->template flat<T>().size());

    auto transpose = se::blas::Transpose::kTranspose;
    auto no_transpose = se::blas::Transpose::kNoTranspose;

    bool blas_launch_status =
        stream
            ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                           a_ptr, k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  }

  const int64 common_padding_rows = std::min(padding_top, padding_bottom);
  const int64 common_padding_cols = std::min(padding_left, padding_right);
  TensorShape compatible_input_shape;
  if (padding_top != padding_bottom || padding_left != padding_right) {
    // Pad the input in the same way we did during the forward pass, so that
    // cuDNN or MIOpen receives the same input during the backward pass as it
    // did during the forward pass.
    const int64 padding_rows_diff = std::abs(padding_bottom - padding_top);
    const int64 padding_cols_diff = std::abs(padding_right - padding_left);
    const int64 new_in_rows =
        dims.spatial_dims[0].input_size + padding_rows_diff;
    const int64 new_in_cols =
        dims.spatial_dims[1].input_size + padding_cols_diff;
    compatible_input_shape = ShapeFromFormat(
        data_format, dims.batch_size, new_in_rows, new_in_cols, dims.in_depth);
  } else {
    compatible_input_shape = input_shape;
  }
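  // For example (hypothetical padding values): with (padding_top,
  // padding_bottom) = (0, 1), common_padding_rows is 0, the compatible input
  // gains one extra row, and that row is sliced off again by the PadInput
  // call after the convolution below.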

  CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
      << "Negative row or col paddings: (" << common_padding_rows << ", "
      << common_padding_cols << ")";

  // The Tensor Core in NVIDIA Volta+ GPUs supports efficient convolution with
  // fp16 in NHWC data layout. In all other configurations it's more efficient
  // to run computation in NCHW data format.
  const bool compute_in_nhwc =
      DataTypeToEnum<T>::value == DT_HALF && IsVoltaOrLater(*stream->parent());

  // We only do one directional conversion: NHWC->NCHW. We never convert in
  // the other direction. Grappler layout optimizer selects the preferred
  // layout and adds necessary annotations to the graph.
  const TensorFormat compute_data_format =
      (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
                                                      : FORMAT_NCHW;

  VLOG(3) << "Compute Conv2DBackpropInput with cuDNN:"
          << " data_format=" << ToString(data_format)
          << " compute_data_format=" << ToString(compute_data_format);

  constexpr auto kComputeInNHWC =
      std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
                      se::dnn::FilterLayout::kOutputYXInput);
  constexpr auto kComputeInNCHW =
      std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
                      se::dnn::FilterLayout::kOutputInputYX);

  se::dnn::DataLayout compute_data_layout;
  se::dnn::FilterLayout filter_layout;

  std::tie(compute_data_layout, filter_layout) =
      compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;

  se::dnn::BatchDescriptor input_desc;
  input_desc.set_count(dims.batch_size)
      .set_height(GetTensorDim(compatible_input_shape, data_format, 'H'))
      .set_width(GetTensorDim(compatible_input_shape, data_format, 'W'))
      .set_feature_map_count(dims.in_depth)
      .set_layout(compute_data_layout);
  se::dnn::BatchDescriptor output_desc;
  output_desc.set_count(dims.batch_size)
      .set_height(dims.spatial_dims[0].output_size)
      .set_width(dims.spatial_dims[1].output_size)
      .set_feature_map_count(dims.out_depth)
      .set_layout(compute_data_layout);
  se::dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
      .set_input_filter_width(dims.spatial_dims[1].filter_size)
      .set_input_feature_map_count(filter_shape.dim_size(2))
      .set_output_feature_map_count(filter_shape.dim_size(3))
      .set_layout(filter_layout);
  se::dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
      .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
      .set_vertical_filter_stride(dims.spatial_dims[0].stride)
      .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
      .set_zero_padding_height(common_padding_rows)
      .set_zero_padding_width(common_padding_cols)
      .set_group_count(dims.in_depth / filter_shape.dim_size(2));

  // Tensorflow filter format: HWIO
  // cuDNN filter formats: (data format) -> (filter format)
  //   (1) NCHW -> OIHW
  //   (2) NHWC -> OHWI

  Tensor transformed_filter;
  const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status {
    VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
            << " to " << ToString(dst_format);

    TensorShape dst_shape =
        dst_format == FORMAT_OIHW
            ? TensorShape({filter.dim_size(3), filter.dim_size(2),
                           filter.dim_size(0), filter.dim_size(1)})
            : TensorShape({filter.dim_size(3), filter.dim_size(0),
                           filter.dim_size(1), filter.dim_size(2)});

    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
                                          &transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 4>()(
        ctx->eigen_device<GPUDevice>(), dst_format,
        To32Bit(filter.tensor<T, 4>()),
        To32Bit(transformed_filter.tensor<T, 4>()));

    return Status::OK();
  };

  if (compute_data_format == FORMAT_NCHW) {
    OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OIHW));
  } else if (compute_data_format == FORMAT_NHWC) {
    OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OHWI));
  } else {
    ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
                                           ToString(compute_data_format)));
    return;
  }

  Tensor transformed_out_backprop;
  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the `out_backprop` tensor from NHWC to NCHW.";
    TensorShape compute_shape = ShapeFromFormat(
        compute_data_format, dims.batch_size, dims.spatial_dims[0].output_size,
        dims.spatial_dims[1].output_size, dims.out_depth);
    if (dims.out_depth > 1) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value,
                                        compute_shape,
                                        &transformed_out_backprop));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
          transformed_out_backprop.tensor<T, 4>());
    } else {
      // If depth <= 1, then just reshape.
      CHECK(transformed_out_backprop.CopyFrom(out_backprop, compute_shape));
    }
  } else {
    transformed_out_backprop = out_backprop;
  }

  Tensor pre_transformed_in_backprop;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(
               DataTypeToEnum<T>::value,
               ShapeFromFormat(
                   compute_data_format,
                   GetTensorDim(compatible_input_shape, data_format, 'N'),
                   GetTensorDim(compatible_input_shape, data_format, 'H'),
                   GetTensorDim(compatible_input_shape, data_format, 'W'),
                   GetTensorDim(compatible_input_shape, data_format, 'C')),
               &pre_transformed_in_backprop));

  auto out_backprop_ptr =
      AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                     transformed_out_backprop.template flat<T>().size());
  auto filter_ptr =
      AsDeviceMemory(transformed_filter.template flat<T>().data(),
                     transformed_filter.template flat<T>().size());
  auto in_backprop_ptr =
      AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                     pre_transformed_in_backprop.template flat<T>().size());

  static int64 ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB by default
  );
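  // DnnScratchAllocator caps the temporary workspace handed to cuDNN/MIOpen
  // at ConvolveBackwardDataScratchSize bytes (4GB by default, overridable in
  // megabytes through the TF_CUDNN_WORKSPACE_LIMIT_IN_MB environment
  // variable).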
  DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
  int device_id = stream->parent()->device_ordinal();
  DataType dtype = out_backprop.dtype();
  ConvParameters conv_parameters = {
      dims.batch_size,                     // batch
      dims.in_depth,                       // in_depths
      {{input_desc.height(),               // in_rows
        input_desc.width()}},              // in_cols
      compute_data_format,                 // compute_data_format
      dims.out_depth,                      // out_depths
      {{dims.spatial_dims[0].filter_size,  // filter_rows
        dims.spatial_dims[1].filter_size,  // filter_cols
        filter_shape.dim_size(2)}},        // filter_depths
      {{dims.spatial_dims[0].dilation,     // dilation_rows
        dims.spatial_dims[1].dilation}},   // dilation_cols
      {{dims.spatial_dims[0].stride,       // stride_rows
        dims.spatial_dims[1].stride}},     // stride_cols
      {{common_padding_rows,               // padding_rows
        common_padding_cols}},             // padding_cols
      dtype,                               // tensor data type
      device_id,                           // device_id
      conv_desc.group_count()              // group_count
  };
#if TENSORFLOW_USE_ROCM
  // cudnn_use_autotune is applicable only to the CUDA flow.
  // For ROCm/MIOpen, we need to call GetMIOpenConvolveAlgorithms explicitly
  // if we do not have a cached algorithm_config for this conv_parameters.
  cudnn_use_autotune = true;
#endif
  AlgorithmConfig algorithm_config;
  if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
                                conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA

    se::TfAllocatorAdapter tf_allocator_adapter(
        ctx->device()->GetAllocator({}), stream);

    se::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
                                      se::GpuAsmOpts());

    se::DeviceMemory<T> in_backprop_ptr_rz(
        WrapRedzoneBestEffort(&rz_allocator, in_backprop_ptr));

    std::vector<AlgorithmDesc> algorithms;
    CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
        conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
        &algorithms));
    std::vector<tensorflow::AutotuneResult> results;
    for (const auto& profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times for better
      // accuracy.
      DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                            ctx);
      se::RedzoneAllocator rz_scratch_allocator(
          stream, &tf_allocator_adapter, se::GpuAsmOpts(),
          /*memory_limit=*/ConvolveBackwardDataScratchSize);
      se::ScratchAllocator* allocator_used =
          !RedzoneCheckDisabled()
              ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
              : static_cast<se::ScratchAllocator*>(&scratch_allocator);
      ProfileResult profile_result;
      auto cudnn_launch_status = stream->ConvolveBackwardDataWithAlgorithm(
          filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
          input_desc, &in_backprop_ptr_rz, allocator_used,
          AlgorithmConfig(profile_algorithm), &profile_result);
      if (cudnn_launch_status.ok() && profile_result.is_valid()) {
        results.emplace_back();
        auto& result = results.back();
        result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
        result.mutable_conv()->set_tensor_ops_enabled(
            profile_algorithm.tensor_ops_enabled());
        result.set_scratch_bytes(
            !RedzoneCheckDisabled()
                ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones()
                : scratch_allocator.TotalByteSize());
        *result.mutable_run_time() = proto_utils::ToDurationProto(
            absl::Milliseconds(profile_result.elapsed_time_in_ms()));

        CheckRedzones(rz_scratch_allocator, &result);
        CheckRedzones(rz_allocator, &result);
      }
    }
#elif TENSORFLOW_USE_ROCM
    DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                          ctx);
    std::vector<ProfileResult> algorithms;
    OP_REQUIRES(
        ctx,
        stream->parent()->GetMIOpenConvolveAlgorithms(
            se::dnn::ConvolutionKind::BACKWARD_DATA,
            se::dnn::ToDataType<T>::value, stream, input_desc, in_backprop_ptr,
            filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
            &scratch_allocator, &algorithms),
        errors::Unknown(
            "Failed to get convolution algorithm. This is probably "
            "because MIOpen failed to initialize, so try looking to "
            "see if a warning log message was printed above."));

    std::vector<tensorflow::AutotuneResult> results;
    if (algorithms.size() == 1) {
      auto profile_result = algorithms[0];
      results.emplace_back();
      auto& result = results.back();
      result.mutable_conv()->set_algorithm(
          profile_result.algorithm().algo_id());
      result.mutable_conv()->set_tensor_ops_enabled(
          profile_result.algorithm().tensor_ops_enabled());

      result.set_scratch_bytes(profile_result.scratch_size());
      *result.mutable_run_time() = proto_utils::ToDurationProto(
          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
    } else {
      for (auto miopen_algorithm : algorithms) {
        auto profile_algorithm = miopen_algorithm.algorithm();
        ProfileResult profile_result;
        auto miopen_launch_status = stream->ConvolveBackwardDataWithAlgorithm(
            filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
            input_desc, &in_backprop_ptr, &scratch_allocator,
            AlgorithmConfig(profile_algorithm,
                            miopen_algorithm.scratch_size()),
            &profile_result);

        if (miopen_launch_status.ok() && profile_result.is_valid()) {
          results.emplace_back();
          auto& result = results.back();
          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
          result.mutable_conv()->set_tensor_ops_enabled(
              profile_algorithm.tensor_ops_enabled());
          result.set_scratch_bytes(scratch_allocator.TotalByteSize());
          *result.mutable_run_time() = proto_utils::ToDurationProto(
              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
        }
      }
    }
#endif
    LogConvAutotuneResults(
        se::dnn::ConvolutionKind::BACKWARD_DATA, se::dnn::ToDataType<T>::value,
        in_backprop_ptr, filter_ptr, out_backprop_ptr, input_desc, filter_desc,
        output_desc, conv_desc, stream->parent(), results);
    OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
    AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
                                               algorithm_config);
  }
  auto cudnn_launch_status = stream->ConvolveBackwardDataWithAlgorithm(
      filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
      input_desc, &in_backprop_ptr, &scratch_allocator, algorithm_config,
      nullptr);

  if (!cudnn_launch_status.ok()) {
    ctx->SetStatus(cudnn_launch_status);
    return;
  }

  if (padding_top != padding_bottom || padding_left != padding_right) {
    Tensor in_backprop_remove_padding;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 ShapeFromFormat(compute_data_format,
                                 GetTensorDim(input_shape, data_format, 'N'),
                                 GetTensorDim(input_shape, data_format, 'H'),
                                 GetTensorDim(input_shape, data_format, 'W'),
                                 GetTensorDim(input_shape, data_format, 'C')),
                 &in_backprop_remove_padding));

    // Remove the padding that was added to the input shape above.
    const int64 input_pad_top = padding_top - common_padding_rows;
    const int64 input_pad_bottom = padding_bottom - common_padding_rows;
    const int64 input_pad_left = padding_left - common_padding_cols;
    const int64 input_pad_right = padding_right - common_padding_cols;
    functor::PadInput<GPUDevice, T, int, 4>()(
        ctx->template eigen_device<GPUDevice>(),
        To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                    .tensor<T, 4>()),
        {{static_cast<int>(-input_pad_top),
          static_cast<int>(-input_pad_left)}},
        {{static_cast<int>(-input_pad_bottom),
          static_cast<int>(-input_pad_right)}},
        To32Bit(in_backprop_remove_padding.tensor<T, 4>()),
        compute_data_format, T{});

    pre_transformed_in_backprop = in_backprop_remove_padding;
  }

  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the output tensor back from NCHW to NHWC.";
    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        ctx->eigen_device<GPUDevice>(),
        toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
        in_backprop->tensor<T, 4>());
  } else {
    *in_backprop = pre_transformed_in_backprop;
  }
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                             \
  template <>                                                           \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(               \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,         \
      typename TTypes<T, 4, int>::ConstTensor in,                       \
      typename TTypes<T, 4, int>::Tensor out);                          \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;         \
  template <>                                                           \
  void PadInput<GPUDevice, T, int, 4>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,   \
      const std::array<int, 2>& padding_left,                           \
      const std::array<int, 2>& padding_right,                          \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
      const T& padding_value);                                          \
  extern template struct PadInput<GPUDevice, T, int, 4>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC

template <>
void SpatialConvolutionBackwardInputFunc<GPUDevice, int32>::operator()(
    const GPUDevice&, typename TTypes<int32, 4>::Tensor,
    typename TTypes<int32, 4>::ConstTensor,
    typename TTypes<int32, 4>::ConstTensor, Eigen::DenseIndex,
    Eigen::DenseIndex, Eigen::DenseIndex, Eigen::DenseIndex);
extern template struct SpatialConvolutionBackwardInputFunc<GPUDevice, int32>;

template <>
void SpatialConvolutionBackwardInputWithExplicitPaddingFunc<
    GPUDevice, int32>::operator()(const GPUDevice&,
                                  typename TTypes<int32, 4>::Tensor,
                                  typename TTypes<int32, 4>::ConstTensor,
                                  typename TTypes<int32, 4>::ConstTensor,
                                  Eigen::DenseIndex, Eigen::DenseIndex,
                                  Eigen::DenseIndex, Eigen::DenseIndex,
                                  Eigen::DenseIndex, Eigen::DenseIndex,
                                  Eigen::DenseIndex, Eigen::DenseIndex);
extern template struct SpatialConvolutionBackwardInputWithExplicitPaddingFunc<
    GPUDevice, int32>;

}  // namespace functor

REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<double>("T")
                            .HostMemory("input_sizes"),
                        Conv2DBackpropInputOp<GPUDevice, double>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("input_sizes"),
                        Conv2DBackpropInputOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .HostMemory("input_sizes"),
                        Conv2DBackpropInputOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input_sizes"),
                        Conv2DBackpropInputOp<GPUDevice, int32>);

// To be used inside depthwise_conv_grad_op.cc.
// TODO(reedwm): Move this and the definition to depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow