1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/nn_ops.cc.
17 
18 #define USE_EIGEN_TENSOR
19 #define EIGEN_USE_THREADS
20 
21 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
22 #define EIGEN_USE_GPU
23 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
24 
25 #include "tensorflow/core/kernels/conv_ops.h"
26 
27 #include <string.h>
28 
29 #include <atomic>
30 #include <map>
31 #include <vector>
32 
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/bounds_check.h"
35 #include "tensorflow/core/framework/kernel_shape_util.h"
36 #include "tensorflow/core/framework/numeric_op.h"
37 #include "tensorflow/core/framework/op_kernel.h"
38 #include "tensorflow/core/framework/register_types.h"
39 #include "tensorflow/core/framework/tensor.h"
40 #include "tensorflow/core/framework/tensor_shape.h"
41 #include "tensorflow/core/framework/tensor_slice.h"
42 #include "tensorflow/core/framework/types.h"
43 #include "tensorflow/core/kernels/conv_2d.h"
44 #include "tensorflow/core/kernels/deep_conv2d.h"
45 #include "tensorflow/core/kernels/ops_util.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/lib/gtl/array_slice.h"
48 #include "tensorflow/core/lib/strings/numbers.h"
49 #include "tensorflow/core/lib/strings/str_util.h"
50 #include "tensorflow/core/platform/logging.h"
51 #include "tensorflow/core/platform/macros.h"
52 #include "tensorflow/core/util/padding.h"
53 #include "tensorflow/core/util/tensor_format.h"
54 #include "tensorflow/core/util/use_cudnn.h"
55 
56 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
57 #include "tensorflow/core/kernels/xsmm_conv2d.h"
58 #endif
59 
60 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
61 #include "tensorflow/core/kernels/conv_ops_gpu.h"
62 #include "tensorflow/core/platform/stream_executor.h"
63 #include "tensorflow/core/protobuf/autotuning.pb.h"
64 #include "tensorflow/core/util/proto/proto_utils.h"
65 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
66 #if GOOGLE_CUDA
67 #include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
68 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
69 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
70 #endif  // GOOGLE_CUDA
71 
72 namespace tensorflow {
73 
// Device tag types used as the `Device` template argument of the launchers and
// kernels defined in this file.
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
76 
77 namespace {
78 template <typename Device, typename T>
79 struct LaunchGeneric {
operator ()tensorflow::__anon7a02fd2b0111::LaunchGeneric80   void operator()(OpKernelContext* ctx, const Tensor& input,
81                   const Tensor& filter, int row_stride, int col_stride,
82                   int row_dilation, int col_dilation, const Padding& padding,
83                   const std::vector<int64>& explicit_paddings, Tensor* output,
84                   TensorFormat data_format) {
85     CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only "
86                                          "supports NHWC tensor format for now.";
87     if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
88         col_stride == 1 && (padding == SAME || padding == VALID)) {
89       // For 1x1 kernel, the 2D convolution is reduced to matrix
90       // multiplication.
91       //
92       // TODO(vrv): We should be able to call SpatialConvolution
93       // and it will produce the same result, but doing so
94       // led to NaNs during training.  Using matmul instead for now.
95       int conv_width = 1;  // Width for the convolution step.
96       for (int i = 0; i < 3; ++i) {
97         conv_width *= output->dim_size(i);
98       }
99 
100       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
101       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
102       functor::MatMulConvFunctor<Device, T>()(
103           ctx->eigen_device<Device>(),
104           output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
105           input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
106           filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
107           dim_pair);
108     } else if (filter.dim_size(0) == input.dim_size(1) &&
109                filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
110                col_dilation == 1 && padding == VALID) {
111       // If the input data and filter have the same height/width,
112       // the 2D convolution is reduced to matrix multiplication.
113       const int k =  // Length of reduction dimension.
114           filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2);
115 
116       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
117       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
118       functor::MatMulConvFunctor<Device, T>()(
119           ctx->eigen_device<Device>(),
120           output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
121           input.shaped<T, 2>({input.dim_size(0), k}),
122           filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair);
123     } else {
124       if (padding == EXPLICIT) {
125         functor::SpatialConvolution<Device, T>()(
126             ctx->eigen_device<Device>(), output->tensor<T, 4>(),
127             input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
128             row_dilation, col_dilation, static_cast<int>(explicit_paddings[2]),
129             static_cast<int>(explicit_paddings[3]),
130             static_cast<int>(explicit_paddings[4]),
131             static_cast<int>(explicit_paddings[5]));
132       } else {
133         functor::SpatialConvolution<Device, T>()(
134             ctx->eigen_device<Device>(), output->tensor<T, 4>(),
135             input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride, col_stride,
136             row_dilation, col_dilation, BrainPadding2EigenPadding(padding));
137       }
138     }
139   }
140 };
141 }  // namespace
142 
143 template <typename T>
144 struct LaunchConv2DOp<CPUDevice, T> {
operator ()tensorflow::LaunchConv2DOp145   void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
146                   const Tensor& input, const Tensor& filter, int row_dilation,
147                   int col_dilation, int row_stride, int col_stride,
148                   const Padding& padding,
149                   const std::vector<int64>& explicit_paddings, Tensor* output,
150                   TensorFormat data_format) {
151     if (data_format != FORMAT_NHWC) {
152       ctx->SetStatus(errors::Unimplemented(
153           "The Conv2D op currently only supports the NHWC tensor format on the "
154           "CPU. The op was given the format: ",
155           ToString(data_format)));
156       return;
157     }
158     const int64 in_depth = GetTensorDim(input, data_format, 'C');
159     OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
160                 errors::Unimplemented(
161                     "The Conv2D op currently does not support grouped "
162                     "convolutions on the CPU. A grouped convolution was "
163                     "attempted to be run because the input depth of ",
164                     in_depth, " does not match the filter input depth of ",
165                     filter.dim_size(2)));
166 
167     for (int64 explicit_padding : explicit_paddings) {
168       if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
169         ctx->SetStatus(errors::InvalidArgument("filter too large"));
170         return;
171       }
172     }
173     LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
174                                   row_dilation, col_dilation, padding,
175                                   explicit_paddings, output, data_format);
176   }
177 };
178 
179 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
180 template <>
181 struct LaunchConv2DOp<GPUDevice, int32> {
operator ()tensorflow::LaunchConv2DOp182   void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
183                   const Tensor& input, const Tensor& filter, int row_dilation,
184                   int col_dilation, int row_stride, int col_stride,
185                   const Padding& padding,
186                   const std::vector<int64>& explicit_paddings, Tensor* output,
187                   TensorFormat data_format) {
188     if (data_format != FORMAT_NHWC) {
189       ctx->SetStatus(
190           errors::Unimplemented("The Conv2D op currently only supports the "
191                                 "NHWC tensor format for integer types. "
192                                 "The op was given the format: ",
193                                 ToString(data_format)));
194       return;
195     }
196     const int64 in_depth = GetTensorDim(input, data_format, 'C');
197     OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
198                 errors::Unimplemented(
199                     "The Conv2D op currently does not support grouped "
200                     "convolutions for integer types. A grouped convolution was "
201                     "attempted to be run because the input depth of ",
202                     in_depth, " does not match the filter input depth of ",
203                     filter.dim_size(2)));
204 
205     for (int64 explicit_padding : explicit_paddings) {
206       if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
207         ctx->SetStatus(errors::InvalidArgument("filter too large"));
208         return;
209       }
210     }
211     LaunchGeneric<GPUDevice, int32>()(
212         ctx, input, filter, row_stride, col_stride, row_dilation, col_dilation,
213         padding, explicit_paddings, output, data_format);
214   }
215 };
216 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
217 
218 template <typename Device, typename T>
219 class LaunchDeepConvOp {
220  public:
Run(OpKernelContext * ctx,const Tensor & input,const Tensor & filter,int batch,int input_rows,int input_cols,int in_depth,int filter_rows,int filter_cols,int pad_rows,int pad_cols,int out_rows,int,int,int,int,int,int,Tensor *,TensorFormat)221   static bool Run(OpKernelContext* ctx, const Tensor& input,
222                   const Tensor& filter, int batch, int input_rows,
223                   int input_cols, int in_depth, int filter_rows,
224                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
225                   int /*out_cols*/, int /*out_depth*/, int /*dilation_rows*/,
226                   int /*dilation_cols*/, int /*stride_rows*/,
227                   int /*stride_cols*/, Tensor* /*output*/,
228                   TensorFormat /*data_format*/) {
229     return false;
230   }
231 };
232 
233 // Conditionally launches DeepConv operation based on convolution parameters.
234 template <>
235 class LaunchDeepConvOp<CPUDevice, float> {
236  public:
Run(OpKernelContext * ctx,const Tensor & input,const Tensor & filter,int batch,int input_rows,int input_cols,int in_depth,int filter_rows,int filter_cols,int pad_rows,int pad_cols,int out_rows,int out_cols,int out_depth,int dilation_rows,int dilation_cols,int stride_rows,int stride_cols,Tensor * output,TensorFormat data_format)237   static bool Run(OpKernelContext* ctx, const Tensor& input,
238                   const Tensor& filter, int batch, int input_rows,
239                   int input_cols, int in_depth, int filter_rows,
240                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
241                   int out_cols, int out_depth, int dilation_rows,
242                   int dilation_cols, int stride_rows, int stride_cols,
243                   Tensor* output, TensorFormat data_format) {
244     if (data_format != FORMAT_NHWC || dilation_rows != 1 ||
245         dilation_cols != 1 ||
246         !CanUseDeepConv2D(stride_rows, stride_cols, filter_rows, filter_cols,
247                           in_depth, out_depth, out_rows, out_cols)) {
248       return false;
249     }
250 
251     Conv2DArgs args;
252     args.batch = batch;
253     args.in_rows = input_rows;
254     args.in_cols = input_cols;
255     args.in_depth = in_depth;
256     args.filter_rows = filter_rows;
257     args.filter_cols = filter_cols;
258     args.pad_rows = pad_rows;
259     args.pad_cols = pad_cols;
260     args.out_rows = out_rows;
261     args.out_cols = out_cols;
262     args.out_depth = out_depth;
263 
264     auto input_ptr = input.template flat<float>().data();
265     auto filter_ptr = filter.template flat<float>().data();
266     auto output_ptr = output->template flat<float>().data();
267 
268     functor::DeepConv2D<CPUDevice, float>()(ctx, args, input_ptr, filter_ptr,
269                                             output_ptr);
270     return true;
271   }
272 };
273 
274 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
275 template <typename Device, typename T>
276 class LaunchXsmmConvOp {
277  public:
Run(OpKernelContext * ctx,const Tensor & input,const Tensor & filter,int batch,int input_rows,int input_cols,int in_depth,int filter_rows,int filter_cols,int pad_rows,int pad_cols,int out_rows,int out_cols,int out_depth,int stride_rows,int stride_cols,int dilation_rows,int dilation_cols,Tensor * output,TensorFormat data_format)278   static bool Run(OpKernelContext* ctx, const Tensor& input,
279                   const Tensor& filter, int batch, int input_rows,
280                   int input_cols, int in_depth, int filter_rows,
281                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
282                   int out_cols, int out_depth, int stride_rows, int stride_cols,
283                   int dilation_rows, int dilation_cols, Tensor* output,
284                   TensorFormat data_format) {
285     return false;
286   }
287 };
288 
289 template <>
290 class LaunchXsmmConvOp<CPUDevice, float> {
291  public:
Run(OpKernelContext * ctx,const Tensor & input,const Tensor & filter,int batch,int input_rows,int input_cols,int in_depth,int filter_rows,int filter_cols,int pad_rows,int pad_cols,int out_rows,int out_cols,int out_depth,int dilation_rows,int dilation_cols,int stride_rows,int stride_cols,Tensor * output,TensorFormat data_format)292   static bool Run(OpKernelContext* ctx, const Tensor& input,
293                   const Tensor& filter, int batch, int input_rows,
294                   int input_cols, int in_depth, int filter_rows,
295                   int filter_cols, int pad_rows, int pad_cols, int out_rows,
296                   int out_cols, int out_depth, int dilation_rows,
297                   int dilation_cols, int stride_rows, int stride_cols,
298                   Tensor* output, TensorFormat data_format) {
299     auto num_threads =
300         ctx->device()->tensorflow_cpu_worker_threads()->num_threads;
301     // See libxsmm_dnn.h for this struct definition.
302     libxsmm_dnn_conv_desc desc;
303     desc.N = batch;
304     desc.C = in_depth;
305     desc.H = input_rows;
306     desc.W = input_cols;
307     desc.K = out_depth;
308     desc.R = filter_rows;
309     desc.S = filter_cols;
310     desc.u = stride_rows;
311     desc.v = stride_cols;
312     desc.pad_h = pad_rows;
313     desc.pad_w = pad_cols;
314     desc.pad_h_in = 0;
315     desc.pad_w_in = 0;
316     desc.pad_h_out = 0;
317     desc.pad_w_out = 0;
318     desc.threads = num_threads;
319     desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
320     desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
321     desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;
322     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
323     desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
324     desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
325     desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
326     if (dilation_rows != 1 || dilation_cols != 1 ||
327         !CanUseXsmmConv2D(desc, data_format)) {
328       return false;
329     }
330 
331     auto input_ptr = input.template flat<float>().data();
332     auto filter_ptr = filter.template flat<float>().data();
333     auto output_ptr = output->template flat<float>().data();
334 
335     bool success = functor::XsmmFwdConv2D<CPUDevice, float>()(
336         ctx, desc, input_ptr, filter_ptr, output_ptr);
337     return success;
338   }
339 };
340 #endif
341 
// Makes the enclosing function return STATUS when EXP evaluates to false.
// STATUS is only evaluated on that failure path. The do/while(false) wrapper
// lets the macro be used as a single statement followed by a semicolon.
#define TF_REQUIRES(EXP, STATUS)                \
  do {                                          \
    if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \
  } while (false)
346 
InitConv2DParameters(const OpKernelConstruction * context,Conv2DParameters * params)347 Status InitConv2DParameters(const OpKernelConstruction* context,
348                             Conv2DParameters* params) {
349   TF_RETURN_IF_ERROR(context->GetAttr("dilations", &params->dilations));
350   TF_RETURN_IF_ERROR(context->GetAttr("strides", &params->strides));
351   TF_RETURN_IF_ERROR(context->GetAttr("padding", &params->padding));
352   if (context->HasAttr("explicit_paddings")) {
353     TF_RETURN_IF_ERROR(
354         context->GetAttr("explicit_paddings", &params->explicit_paddings));
355   }
356   string data_format_string;
357   TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string));
358   TF_REQUIRES(FormatFromString(data_format_string, &params->data_format),
359               errors::InvalidArgument("Invalid data format"));
360 
361   const auto& strides = params->strides;
362   const auto& dilations = params->dilations;
363   const auto& data_format = params->data_format;
364 
365   TF_REQUIRES(dilations.size() == 4,
366               errors::InvalidArgument("Sliding window dilations field must "
367                                       "specify 4 dimensions"));
368   TF_REQUIRES(strides.size() == 4,
369               errors::InvalidArgument("Sliding window strides field must "
370                                       "specify 4 dimensions"));
371   const int64 stride_n = GetTensorDim(strides, data_format, 'N');
372   const int64 stride_c = GetTensorDim(strides, data_format, 'C');
373   const int64 stride_h = GetTensorDim(strides, data_format, 'H');
374   const int64 stride_w = GetTensorDim(strides, data_format, 'W');
375   TF_REQUIRES(
376       stride_n == 1 && stride_c == 1,
377       errors::Unimplemented("Current implementation does not yet support "
378                             "strides in the batch and depth dimensions."));
379   TF_REQUIRES(stride_h > 0 && stride_w > 0,
380               errors::InvalidArgument(
381                   "Row and column strides should be larger than 0."));
382 
383   const int64 dilation_n = GetTensorDim(dilations, data_format, 'N');
384   const int64 dilation_c = GetTensorDim(dilations, data_format, 'C');
385   const int64 dilation_h = GetTensorDim(dilations, data_format, 'H');
386   const int64 dilation_w = GetTensorDim(dilations, data_format, 'W');
387   TF_REQUIRES(
388       dilation_n == 1 && dilation_c == 1,
389       errors::Unimplemented("Current implementation does not yet support "
390                             "dilations in the batch and depth dimensions."));
391   TF_REQUIRES(
392       dilation_h > 0 && dilation_w > 0,
393       errors::InvalidArgument("Dilated rates should be larger than 0."));
394 
395   TF_RETURN_IF_ERROR(CheckValidPadding(params->padding,
396                                        params->explicit_paddings,
397                                        /*num_dims=*/4, data_format));
398 
399   return Status::OK();
400 }
401 
ComputeConv2DDimension(const Conv2DParameters & params,const Tensor & input,const Tensor & filter,Conv2DDimensions * dimensions)402 Status ComputeConv2DDimension(const Conv2DParameters& params,
403                               const Tensor& input, const Tensor& filter,
404                               Conv2DDimensions* dimensions) {
405   // Check that 2D convolution input and filter have exactly 4 dimensions.
406   TF_REQUIRES(input.dims() == 4,
407               errors::InvalidArgument("input must be 4-dimensional",
408                                       input.shape().DebugString()));
409   TF_REQUIRES(filter.dims() == 4,
410               errors::InvalidArgument("filter must be 4-dimensional: ",
411                                       filter.shape().DebugString()));
412   for (int i = 0; i < 3; i++) {
413     TF_REQUIRES(
414         FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
415         errors::InvalidArgument("filter too large"));
416   }
417 
418   // The last dimension for input is in_depth. Check that it is the same as the
419   // filter's in_depth or it is evenly divisible by filter's in_depth.
420   const int64 in_depth_raw = GetTensorDim(input, params.data_format, 'C');
421   const int64 patch_depth_raw = filter.dim_size(2);
422   TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()),
423               errors::InvalidArgument("Input depth too large"));
424   TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()),
425               errors::InvalidArgument("Patch depth too large"));
426   const int in_depth = static_cast<int>(in_depth_raw);
427   const int patch_depth = static_cast<int>(patch_depth_raw);
428   TF_REQUIRES(in_depth % patch_depth == 0,
429               errors::InvalidArgument(
430                   "input depth must be evenly divisible by filter depth: ",
431                   in_depth, " vs ", patch_depth));
432 
433   // The last dimension for filter is out_depth.
434   const int out_depth = static_cast<int>(filter.dim_size(3));
435 
436   // The second dimension for input is rows/height.
437   // The first dimension for filter is rows/height.
438   const int64 input_rows_raw = GetTensorDim(input, params.data_format, 'H');
439   TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
440               errors::InvalidArgument("Input rows too large"));
441   const int input_rows = static_cast<int>(input_rows_raw);
442   const int filter_rows = static_cast<int>(filter.dim_size(0));
443 
444   // The third dimension for input is columns/width.
445   // The second dimension for filter is columns/width.
446   const int64 input_cols_raw = GetTensorDim(input, params.data_format, 'W');
447   TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
448               errors::InvalidArgument("Input cols too large"));
449   const int input_cols = static_cast<int>(input_cols_raw);
450   const int filter_cols = static_cast<int>(filter.dim_size(1));
451 
452   // The first dimension for input is batch.
453   const int64 batch_raw = GetTensorDim(input, params.data_format, 'N');
454   TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
455               errors::InvalidArgument("batch is too large"));
456   const int batch = static_cast<int>(batch_raw);
457 
458   // Take the stride and dilation from the second and third dimensions only (we
459   // do not support striding or dilation on the batch or depth dimension).
460   const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H');
461   const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W');
462   const int dilation_rows =
463       GetTensorDim(params.dilations, params.data_format, 'H');
464   const int dilation_cols =
465       GetTensorDim(params.dilations, params.data_format, 'W');
466 
467   int64 pad_rows_before, pad_rows_after, pad_cols_before, pad_cols_after;
468   if (params.padding == Padding::EXPLICIT) {
469     GetExplicitPaddingForDim(params.explicit_paddings, params.data_format, 'H',
470                              &pad_rows_before, &pad_rows_after);
471     GetExplicitPaddingForDim(params.explicit_paddings, params.data_format, 'W',
472                              &pad_cols_before, &pad_cols_after);
473   }
474 
475   // Compute windowed output sizes for rows and columns.
476   int64 out_rows = 0, out_cols = 0;
477   TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
478       input_rows, filter_rows, dilation_rows, stride_rows, params.padding,
479       &out_rows, &pad_rows_before, &pad_rows_after));
480   TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
481       input_cols, filter_cols, dilation_cols, stride_cols, params.padding,
482       &out_cols, &pad_cols_before, &pad_cols_after));
483 
484   dimensions->batch = batch;
485   dimensions->input_rows = input_rows;
486   dimensions->input_cols = input_cols;
487   dimensions->in_depth = in_depth;
488   dimensions->filter_rows = filter_rows;
489   dimensions->filter_cols = filter_cols;
490   dimensions->patch_depth = patch_depth;
491   dimensions->out_depth = out_depth;
492   dimensions->stride_rows = stride_rows;
493   dimensions->stride_cols = stride_cols;
494   dimensions->dilation_rows = dilation_rows;
495   dimensions->dilation_cols = dilation_cols;
496   dimensions->out_rows = out_rows;
497   dimensions->out_cols = out_cols;
498   dimensions->pad_rows_before = pad_rows_before;
499   dimensions->pad_rows_after = pad_rows_after;
500   dimensions->pad_cols_before = pad_cols_before;
501   dimensions->pad_cols_after = pad_cols_after;
502 
503   return Status::OK();
504 }
505 
506 #undef TF_REQUIRES
507 
508 template <typename Device, typename T>
509 class Conv2DOp : public BinaryOp<T> {
510  public:
Conv2DOp(OpKernelConstruction * context)511   explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
512     OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
513 
514     OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
515     cudnn_use_autotune_ = CudnnUseAutotune();
516   }
517 
Compute(OpKernelContext * context)518   void Compute(OpKernelContext* context) override {
519     // Input tensor is of the following dimensions:
520     // [ batch, in_rows, in_cols, in_depth ]
521     const Tensor& input = context->input(0);
522 
523     // Input filter is of the following dimensions:
524     // [ filter_rows, filter_cols, in_depth, out_depth]
525     const Tensor& filter = context->input(1);
526 
527     Conv2DDimensions dimensions;
528     OP_REQUIRES_OK(context,
529                    ComputeConv2DDimension(params_, input, filter, &dimensions));
530 
531     TensorShape out_shape = ShapeFromFormat(
532         params_.data_format, dimensions.batch, dimensions.out_rows,
533         dimensions.out_cols, dimensions.out_depth);
534 
535     // Output tensor is of the following dimensions:
536     // [ in_batch, out_rows, out_cols, out_depth ]
537     Tensor* output = nullptr;
538     OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
539 
540     VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth
541             << ", patch_depth = " << dimensions.patch_depth
542             << ", input_cols = " << dimensions.input_cols
543             << ", filter_cols = " << dimensions.filter_cols
544             << ", input_rows = " << dimensions.input_rows
545             << ", filter_rows = " << dimensions.filter_rows
546             << ", stride_rows = " << dimensions.stride_rows
547             << ", stride_cols = " << dimensions.stride_cols
548             << ", dilation_rows = " << dimensions.dilation_rows
549             << ", dilation_cols = " << dimensions.dilation_cols
550             << ", out_depth = " << dimensions.out_depth;
551 
552     // If there is nothing to compute, return.
553     if (out_shape.num_elements() == 0) {
554       return;
555     }
556 
557 #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
558     if (params_.padding != EXPLICIT &&
559         LaunchXsmmConvOp<Device, T>::Run(
560             context, input, filter, dimensions.batch, dimensions.input_rows,
561             dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
562             dimensions.filter_cols, dimensions.pad_rows_before,
563             dimensions.pad_cols_before, dimensions.out_rows,
564             dimensions.out_cols, dimensions.out_depth, dimensions.dilation_rows,
565             dimensions.dilation_cols, dimensions.stride_rows,
566             dimensions.stride_cols, output, params_.data_format)) {
567       return;
568     }
569 #endif
570 
571     if (params_.padding != EXPLICIT &&
572         LaunchDeepConvOp<Device, T>::Run(
573             context, input, filter, dimensions.batch, dimensions.input_rows,
574             dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
575             dimensions.filter_cols, dimensions.pad_rows_before,
576             dimensions.pad_cols_before, dimensions.out_rows,
577             dimensions.out_cols, dimensions.out_depth, dimensions.dilation_rows,
578             dimensions.dilation_cols, dimensions.stride_rows,
579             dimensions.stride_cols, output, params_.data_format)) {
580       return;
581     }
582 
583     launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
584               dimensions.dilation_rows, dimensions.dilation_cols,
585               dimensions.stride_rows, dimensions.stride_cols, params_.padding,
586               params_.explicit_paddings, output, params_.data_format);
587   }
588 
589  private:
590   Conv2DParameters params_;
591   bool use_cudnn_;
592   bool cudnn_use_autotune_;
593 
594   LaunchConv2DOp<Device, T> launcher_;
595 
596   TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
597 };
598 
// Registers Conv2DOp<CPUDevice, T> as the CPU kernel of the "Conv2D" op for
// element type T.
#define REGISTER_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv2DOp<CPUDevice, T>);

// If we're using the alternative GEMM-based implementation of Conv2D for the
// CPU implementation, don't register this EigenTensor-based version.
#if !defined(USE_GEMM_FOR_CONV)
TF_CALL_half(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
#endif  // USE_GEMM_FOR_CONV

// Explicit instantiations of the CPU launcher for the floating-point types.
// To be used inside depthwise_conv_op.cc.
template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DOp<CPUDevice, float>;
template struct LaunchConv2DOp<CPUDevice, double>;
617 
618 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
619 
GetDnnWorkspaceLimit(const string & envvar_in_mb,int64 default_value_in_bytes)620 int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
621                            int64 default_value_in_bytes) {
622   const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
623   if (workspace_limit_in_mb_str != nullptr &&
624       strcmp(workspace_limit_in_mb_str, "") != 0) {
625     int64 scratch_limit_in_mb = -1;
626     if (strings::safe_strto64(workspace_limit_in_mb_str,
627                               &scratch_limit_in_mb)) {
628       return scratch_limit_in_mb * (1 << 20);
629     } else {
630       LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
631                    << workspace_limit_in_mb_str;
632     }
633   }
634   return default_value_in_bytes;
635 }
636 
637 // A dummy type to group forward convolution autotune results together.
638 struct ConvAutoTuneGroup {
nametensorflow::ConvAutoTuneGroup639   static string name() { return "Conv"; }
640 };
641 typedef AutoTuneSingleton<ConvAutoTuneGroup, ConvParameters,
642                           se::dnn::AlgorithmConfig>
643     AutoTuneConv;
644 
// Launches a 2-D forward convolution on the GPU.
//
// Overview of the control flow:
//   1. Fails with Internal if no GPU stream is available and with
//      Unimplemented if `use_cudnn` is false.
//   2. Two special cases are dispatched to a single BLAS GEMM and return
//      early: a 1x1 filter with unit strides/dilations, and a filter whose
//      spatial size equals the input's spatial size with VALID padding (both
//      only for NHWC, non-grouped convolutions).
//   3. Otherwise the padding is computed; if it is asymmetric the input is
//      copied into an explicitly padded temporary so that the DNN library can
//      be given a symmetric padding.
//   4. If the compute layout is NCHW but the data is NHWC, the input is
//      converted to NCHW; the filter is always transformed out of HWIO.
//   5. When autotuning is enabled and no cached config exists for these
//      parameters, candidate algorithms are profiled and the best one is
//      cached in AutoTuneConv.
//   6. The convolution is run, and the result is converted back to NHWC if it
//      was computed in NCHW.
//
// Args:
//   ctx: kernel context; used for the stream, temp allocations and to report
//     errors via SetStatus / OP_REQUIRES.
//   use_cudnn: must be true; the non-cuDNN path is not implemented.
//   cudnn_use_autotune: whether to autotune. Forced to true in ROCm builds.
//   input_param: input tensor laid out according to `data_format`.
//   filter: filter tensor indexed as [rows, cols, in_depth, out_depth].
//   row_dilation, col_dilation, row_stride, col_stride: convolution geometry.
//   padding: padding scheme; `explicit_paddings` is only read when this is
//     EXPLICIT.
//   explicit_paddings: explicit padding amounts (see `padding`).
//   output: destination tensor. Its dimensions are read here, so it must
//     already have its final shape.
//   data_format: layout of `input_param` and `output`.
template <typename T>
void LaunchConv2DOp<GPUDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& input_param, const Tensor& filter, int row_dilation,
    int col_dilation, int row_stride, int col_stride, const Padding& padding,
    const std::vector<int64>& explicit_paddings, Tensor* output,
    TensorFormat data_format) {
  using se::dnn::AlgorithmConfig;
  using se::dnn::AlgorithmDesc;
  using se::dnn::ProfileResult;
  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  if (!use_cudnn) {
    ctx->SetStatus(
        errors::Unimplemented("Conv2D for GPU is not currently supported "
                              "without cudnn"));
    return;
  }

  // `input` starts as an alias of `input_param` and may later be replaced by a
  // padded and/or layout-converted temporary. `in_rows` / `in_cols` are
  // non-const because explicit padding below can enlarge them.
  Tensor input = input_param;
  const int64 in_batch = GetTensorDim(input, data_format, 'N');
  int64 in_rows = GetTensorDim(input, data_format, 'H');
  int64 in_cols = GetTensorDim(input, data_format, 'W');
  const int64 in_depths = GetTensorDim(input, data_format, 'C');
  const int64 patch_rows = filter.dim_size(0);
  const int64 patch_cols = filter.dim_size(1);
  const int64 patch_depths = filter.dim_size(2);

  // If the filter in-depth (patch_depths) is 1 and smaller than the input
  // depth, it's a depthwise convolution. More generally, if the filter in-depth
  // divides but is smaller than the input depth, it is a grouped convolution.
  bool is_grouped_convolution = patch_depths != in_depths;
  if (patch_rows == 1 && patch_cols == 1 && !is_grouped_convolution &&
      row_dilation == 1 && col_dilation == 1 && row_stride == 1 &&
      col_stride == 1 && data_format == FORMAT_NHWC &&
      (padding == VALID || padding == SAME)) {
    // 1x1 filter, so call cublas directly.
    // GEMM sizes: m = batch * rows * cols (one row per input pixel),
    // k = input depth, n = filter.dim_size(3) (output depth).
    const uint64 m = in_batch * in_rows * in_cols;
    const uint64 k = patch_depths;
    const uint64 n = filter.dim_size(3);

    auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                output->template flat<T>().size());

    // NOTE(review): the filter is passed as the first GEMM operand and the
    // sizes as (n, m, k), presumably to express a row-major product through a
    // column-major BLAS interface -- confirm against ThenBlasGemm's contract.
    auto no_transpose = se::blas::Transpose::kNoTranspose;
    bool blas_launch_status =
        stream
            ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n,
                           a_ptr, k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      // The message says "SGEMM" for every element type T.
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  } else if (patch_rows == in_rows && patch_cols == in_cols &&
             !is_grouped_convolution && row_dilation == 1 &&
             col_dilation == 1 && padding == VALID &&
             data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, so call cublas
    // directly.
    // GEMM sizes: m = batch, k = rows * cols * depth of one filter patch,
    // n = filter.dim_size(3) (output depth). Strides are not checked here.
    const uint64 m = in_batch;
    const uint64 k = patch_rows * patch_cols * patch_depths;
    const uint64 n = filter.dim_size(3);

    auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                output->template flat<T>().size());

    auto no_transpose = se::blas::Transpose::kNoTranspose;
    bool blas_launch_status =
        stream
            ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n,
                           a_ptr, k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  }

#if GOOGLE_CUDA
  // Tensor Core (NVIDIA Volta+ GPUs) supports efficient convolution with fp16
  // in NHWC data layout. In all other configurations it's more efficient to
  // run computation in NCHW data format.
  const bool compute_in_nhwc =
      DataTypeToEnum<T>::value == DT_HALF && IsVoltaOrLater(*stream->parent());
#else
  // fast NHWC implementation is a CUDA only feature
  const bool compute_in_nhwc = false;
#endif

  // We only do one directional conversion: NHWC->NCHW. We never convert in the
  // other direction. Grappler layout optimizer selects preferred layout and
  // adds necessary annotations to the graph.
  // TODO(ezhulenev): Convert in other direction for fp16?
  const TensorFormat compute_data_format =
      (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
                                                      : FORMAT_NCHW;

  VLOG(3) << "Compute Conv2D with cuDNN:"
          << " data_format=" << ToString(data_format)
          << " compute_data_format=" << ToString(compute_data_format);

  const int64 out_batch = GetTensorDim(*output, data_format, 'N');
  const int64 out_rows = GetTensorDim(*output, data_format, 'H');
  const int64 out_cols = GetTensorDim(*output, data_format, 'W');
  const int64 out_depths = GetTensorDim(*output, data_format, 'C');
  // Paddings start at -1 and are filled in either from `explicit_paddings`
  // (EXPLICIT) or by GetWindowedOutputSizeVerboseV2 below.
  int64 padding_top = -1, padding_bottom = -1;
  int64 padding_left = -1, padding_right = -1;
  if (padding == EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &padding_top,
                             &padding_bottom);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &padding_left,
                             &padding_right);
  }
  int64 out_rows_check, out_cols_check;
  Status status = GetWindowedOutputSizeVerboseV2(
      in_rows, patch_rows, row_dilation, row_stride, padding, &out_rows_check,
      &padding_top, &padding_bottom);
  // The status is guaranteed to be OK because we checked the output and padding
  // was valid earlier.
  TF_CHECK_OK(status);
  DCHECK_EQ(out_rows, out_rows_check);
  status = GetWindowedOutputSizeVerboseV2(in_cols, patch_cols, col_dilation,
                                          col_stride, padding, &out_cols_check,
                                          &padding_left, &padding_right);
  TF_CHECK_OK(status);
  DCHECK_EQ(out_cols, out_cols_check);

  // The symmetric part of the padding is handed to the DNN library; any excess
  // on one side is materialized into the input tensor below.
  const int64 common_padding_rows = std::min(padding_top, padding_bottom);
  const int64 common_padding_cols = std::min(padding_left, padding_right);
  if (padding_top != padding_bottom || padding_left != padding_right) {
    // cuDNN only supports padding the same amount on the left and right sides,
    // and on the top and bottom sides. So we manually create a new padded
    // input tensor such that we can pass it to cuDNN.
    VLOG(4) << "Pad input tensor:"
            << " padding_top=" << padding_top
            << " padding_bottom=" << padding_bottom
            << " padding_left=" << padding_left
            << " padding_right=" << padding_right;

    // TODO(reedwm): In some cases, we can avoid an allocation even if the two
    // padding sides are different. For example, if the input is 2x2, the filter
    // is 1x1, the stride is 2, and the padding is (1, 0, 1, 0), the result is
    // equivalent to as if the padding is (1, 1, 1, 1). Changing the padding in
    // such a way would allow us to avoid the allocation.
    Tensor transformed_input;
    const int64 padding_rows_diff = std::abs(padding_bottom - padding_top);
    const int64 padding_cols_diff = std::abs(padding_right - padding_left);
    const int64 new_in_rows = in_rows + padding_rows_diff;
    const int64 new_in_cols = in_cols + padding_cols_diff;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            ShapeFromFormat(data_format, in_batch, new_in_rows,
                                            new_in_cols, in_depths),
                            &transformed_input));

    // Per-side excess over the common (symmetric) padding; on each axis one of
    // the two sides is zero.
    const int64 input_pad_top = padding_top - common_padding_rows;
    const int64 input_pad_bottom = padding_bottom - common_padding_rows;
    const int64 input_pad_left = padding_left - common_padding_cols;
    const int64 input_pad_right = padding_right - common_padding_cols;
    // PadInput takes int paddings, so make sure the int64 values fit.
    bool in_bounds =
        FastBoundsCheck(input_pad_top, std::numeric_limits<int>::max()) &&
        FastBoundsCheck(input_pad_bottom, std::numeric_limits<int>::max()) &&
        FastBoundsCheck(input_pad_left, std::numeric_limits<int>::max()) &&
        FastBoundsCheck(input_pad_right, std::numeric_limits<int>::max());
    if (!in_bounds) {
      ctx->SetStatus(errors::InvalidArgument("Padding is too large."));
      return;
    }
    // Pad with T{} (value-initialized T) as the fill value.
    functor::PadInput<GPUDevice, T, int, 4>()(
        ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
        {{static_cast<int>(input_pad_top), static_cast<int>(input_pad_left)}},
        {{static_cast<int>(input_pad_bottom),
          static_cast<int>(input_pad_right)}},
        To32Bit(transformed_input.tensor<T, 4>()), data_format, T{});

    input = transformed_input;
    in_rows = new_in_rows;
    in_cols = new_in_cols;
  }

  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the input tensor from NHWC to NCHW.";

    TensorShape nchw_shape =
        ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
    if (in_depths > 1) {
      Tensor transformed_input;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             nchw_shape, &transformed_input));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(input).tensor<T, 4>(),
          transformed_input.tensor<T, 4>());
      input = transformed_input;
    } else {
      // If depth <= 1, then just reshape.
      CHECK(input.CopyFrom(input, nchw_shape));
    }
  } else {
    CHECK(data_format == compute_data_format)  // Crash OK
        << "Illegal data and compute format pair:"
        << " data_format=" << ToString(data_format)
        << " compute_data_format=" << ToString(compute_data_format);
  }

  CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
      << "Negative row or col paddings: (" << common_padding_rows << ", "
      << common_padding_cols << ")";

  // (data layout, filter layout) pairs for the two supported compute formats.
  constexpr auto kComputeInNHWC =
      std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
                      se::dnn::FilterLayout::kOutputYXInput);
  constexpr auto kComputeInNCHW =
      std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
                      se::dnn::FilterLayout::kOutputInputYX);

  se::dnn::DataLayout compute_data_layout;
  se::dnn::FilterLayout filter_layout;

  std::tie(compute_data_layout, filter_layout) =
      compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;

  // Descriptors handed to the DNN library. The input descriptor uses the
  // (possibly enlarged) in_rows / in_cols of the padded input.
  se::dnn::BatchDescriptor input_desc;
  input_desc.set_count(in_batch)
      .set_feature_map_count(in_depths)
      .set_height(in_rows)
      .set_width(in_cols)
      .set_layout(compute_data_layout);
  se::dnn::BatchDescriptor output_desc;
  output_desc.set_count(out_batch)
      .set_height(out_rows)
      .set_width(out_cols)
      .set_feature_map_count(out_depths)
      .set_layout(compute_data_layout);
  se::dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(patch_rows)
      .set_input_filter_width(patch_cols)
      .set_input_feature_map_count(patch_depths)
      .set_output_feature_map_count(filter.dim_size(3))
      .set_layout(filter_layout);
  se::dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_dilation_rate(row_dilation)
      .set_horizontal_dilation_rate(col_dilation)
      .set_vertical_filter_stride(row_stride)
      .set_horizontal_filter_stride(col_stride)
      .set_zero_padding_height(common_padding_rows)
      .set_zero_padding_width(common_padding_cols)
      .set_group_count(in_depths / patch_depths);

  Tensor transformed_filter;

  // Allocates `transformed_filter` with the shape implied by `dst_format`
  // (OIHW, otherwise O-H-W-I order) and fills it from the HWIO `filter`.
  const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status {
    VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
            << " to " << ToString(dst_format);

    TensorShape dst_shape =
        dst_format == FORMAT_OIHW
            ? TensorShape({filter.dim_size(3), filter.dim_size(2),
                           filter.dim_size(0), filter.dim_size(1)})
            : TensorShape({filter.dim_size(3), filter.dim_size(0),
                           filter.dim_size(1), filter.dim_size(2)});

    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
                                          &transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 4>()(
        ctx->eigen_device<GPUDevice>(), dst_format,
        To32Bit(filter.tensor<T, 4>()),
        To32Bit(transformed_filter.tensor<T, 4>()));

    return Status::OK();
  };

  if (compute_data_format == FORMAT_NCHW) {
    OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OIHW));
  } else if (compute_data_format == FORMAT_NHWC) {
    OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OHWI));
  } else {
    ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
                                           ToString(compute_data_format)));
    return;
  }

  // When the compute layout differs from the data layout the convolution
  // writes into a temporary; otherwise `transformed_output` is a copy of the
  // `*output` Tensor object.
  Tensor transformed_output;
  if (data_format != compute_data_format) {
    VLOG(4) << "Allocate temporary memory for output in compute data format";
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                ShapeFromFormat(compute_data_format, out_batch,
                                                out_rows, out_cols, out_depths),
                                &transformed_output));
  } else {
    transformed_output = *output;
  }

  auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
  auto filter_ptr =
      AsDeviceMemory(transformed_filter.template flat<T>().data(),
                     transformed_filter.template flat<T>().size());
  auto output_ptr =
      AsDeviceMemory(transformed_output.template flat<T>().data(),
                     transformed_output.template flat<T>().size());

  // Function-local static: the environment variable is read once, on the first
  // call (per instantiation of this template), not on every launch.
  static int64 ConvolveScratchSize = GetDnnWorkspaceLimit(
      // default value is in bytes despite the name of the environment variable
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
  );

  // Key used to look up / store the autotuned algorithm for this convolution.
  int device_id = stream->parent()->device_ordinal();
  DataType dtype = input.dtype();
  ConvParameters conv_parameters = {in_batch,             // batch
                                    in_depths,            // in_depths
                                    {{in_rows,            // in_rows
                                      in_cols}},          // in_cols
                                    compute_data_format,  // compute_data_format
                                    out_depths,           // out_depths
                                    {{patch_rows,         // filter_rows
                                      patch_cols,         // filter_cols
                                      patch_depths}},     // filter_depths
                                    {{row_dilation,       // dilation_rows
                                      col_dilation}},     // dilation_cols
                                    {{row_stride,         // stride_rows
                                      col_stride}},       // stride_cols
                                    {{common_padding_rows,    // padding_rows
                                      common_padding_cols}},  // padding_cols
                                    dtype,                    // tensor datatype
                                    device_id,                // device_id
                                    conv_desc.group_count()};
  AlgorithmConfig algorithm_config;
#if TENSORFLOW_USE_ROCM
  // cudnn_use_autotune is applicable only to the CUDA flow.
  // For ROCm/MIOpen, we need to call GetMIOpenConvolveAlgorithms explicitly
  // if we do not have a cached algorithm_config for this conv_parameters.
  cudnn_use_autotune = true;
#endif
  if (cudnn_use_autotune &&
      !AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA
    std::vector<AlgorithmDesc> algorithms;
    OP_REQUIRES(
        ctx,
        stream->parent()->GetConvolveAlgorithms(
            conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
                stream->parent()),
            &algorithms),
        errors::Unknown("Failed to get convolution algorithm. This is probably "
                        "because cuDNN failed to initialize, so try looking to "
                        "see if a warning log message was printed above."));

    // Profiling writes into `output_tensor`, obtained from
    // WrapRedzoneBestEffort, rather than directly into `output_ptr`.
    se::TfAllocatorAdapter tf_allocator_adapter(ctx->device()->GetAllocator({}),
                                                stream);
    se::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
                                      se::GpuAsmOpts());
    se::DeviceMemory<T> output_tensor(
        WrapRedzoneBestEffort(&rz_allocator, output_ptr));

    std::vector<tensorflow::AutotuneResult> results;
    for (const auto& profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times to better
      // accuracy.
      // Fresh scratch allocators per candidate; which one is used depends on
      // RedzoneCheckDisabled().
      se::RedzoneAllocator rz_scratch_allocator(
          stream, &tf_allocator_adapter, se::GpuAsmOpts(),
          /*memory_limit=*/ConvolveScratchSize);
      DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
      se::ScratchAllocator* allocator_used =
          !RedzoneCheckDisabled()
              ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
              : static_cast<se::ScratchAllocator*>(&scratch_allocator);

      ProfileResult profile_result;
      auto cudnn_launch_status = stream->ConvolveWithAlgorithm(
          input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
          output_desc, &output_tensor, allocator_used,
          AlgorithmConfig(profile_algorithm), &profile_result);
      // Candidates that fail to launch or yield an invalid profile are
      // silently skipped.
      if (cudnn_launch_status.ok() && profile_result.is_valid()) {
        results.emplace_back();
        auto& result = results.back();
        result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
        result.mutable_conv()->set_tensor_ops_enabled(
            profile_algorithm.tensor_ops_enabled());

        result.set_scratch_bytes(
            !RedzoneCheckDisabled()
                ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones()
                : scratch_allocator.TotalByteSize());
        *result.mutable_run_time() = proto_utils::ToDurationProto(
            absl::Milliseconds(profile_result.elapsed_time_in_ms()));

        CheckRedzones(rz_scratch_allocator, &result);
        CheckRedzones(rz_allocator, &result);
      }
    }

#elif TENSORFLOW_USE_ROCM
    DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);

    std::vector<ProfileResult> algorithms;
    OP_REQUIRES(
        ctx,
        stream->parent()->GetMIOpenConvolveAlgorithms(
            se::dnn::ConvolutionKind::FORWARD, se::dnn::ToDataType<T>::value,
            stream, input_desc, input_ptr, filter_desc, filter_ptr, output_desc,
            output_ptr, conv_desc, &scratch_allocator, &algorithms),
        errors::Unknown(
            "Failed to get convolution algorithm. This is probably "
            "because MIOpen failed to initialize, so try looking to "
            "see if a warning log message was printed above."));
    se::DeviceMemory<T> output_tensor = output_ptr;

    std::vector<tensorflow::AutotuneResult> results;
    if (algorithms.size() == 1) {
      // A single candidate: record its reported scratch size and time without
      // running the convolution again.
      auto profile_result = algorithms[0];
      results.emplace_back();
      auto& result = results.back();
      result.mutable_conv()->set_algorithm(
          profile_result.algorithm().algo_id());
      result.mutable_conv()->set_tensor_ops_enabled(
          profile_result.algorithm().tensor_ops_enabled());

      result.set_scratch_bytes(profile_result.scratch_size());
      *result.mutable_run_time() = proto_utils::ToDurationProto(
          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
    } else {
      // Several candidates: run each one to profile it.
      for (auto miopen_algorithm : algorithms) {
        auto profile_algorithm = miopen_algorithm.algorithm();
        ProfileResult profile_result;
        auto miopen_launch_status = stream->ConvolveWithAlgorithm(
            input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
            output_desc, &output_ptr, &scratch_allocator,
            AlgorithmConfig(profile_algorithm, miopen_algorithm.scratch_size()),
            &profile_result);
        if (miopen_launch_status.ok() && profile_result.is_valid()) {
          results.emplace_back();
          auto& result = results.back();
          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
          result.mutable_conv()->set_tensor_ops_enabled(
              profile_algorithm.tensor_ops_enabled());

          result.set_scratch_bytes(scratch_allocator.TotalByteSize());
          *result.mutable_run_time() = proto_utils::ToDurationProto(
              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
        }
      }
    }
#endif
    // Log the candidates, pick the best one and cache it for these parameters.
    LogConvAutotuneResults(se::dnn::ConvolutionKind::FORWARD,
                           se::dnn::ToDataType<T>::value, input_ptr, filter_ptr,
                           output_tensor, input_desc, filter_desc, output_desc,
                           conv_desc, stream->parent(), results);
    OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
    AutoTuneConv::GetInstance()->Insert(conv_parameters, algorithm_config);
  }

  // NOTE(review): algorithm() is dereferenced unconditionally here. When
  // autotuning is disabled `algorithm_config` is still default-constructed --
  // confirm that algorithm() is always populated in that case, otherwise
  // enabling VLOG(4) could dereference an empty value.
  VLOG(4) << "Convolution Algorithm: "
          << algorithm_config.algorithm()->algo_id();
  VLOG(4) << "tensor_ops_enabled: "
          << algorithm_config.algorithm()->tensor_ops_enabled();

  // The actual (non-profiling) convolution; the final nullptr means no profile
  // result is requested.
  DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
  auto cudnn_launch_status = stream->ConvolveWithAlgorithm(
      input_desc, input_ptr, filter_desc, filter_ptr, conv_desc, output_desc,
      &output_ptr, &scratch_allocator, algorithm_config, nullptr);

  // NOTE(review): on failure the status is recorded but execution falls
  // through to the layout conversion below -- confirm this is intended.
  if (!cudnn_launch_status.ok()) {
    ctx->SetStatus(cudnn_launch_status);
  }

  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the output tensor back from NCHW to NHWC.";
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        ctx->eigen_device<GPUDevice>(),
        const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
        output->tensor<T, 4>());
  }
}
1132 
// Forward declarations of the functor specializations for GPU.
//
// For each element type T, DECLARE_GPU_SPEC(T) declares explicit
// specializations of operator() for SpatialConvolution (both the
// Eigen::PaddingType overload and the explicit-padding overload),
// MatMulConvFunctor, TransformFilter and PadInput on GPUDevice, each followed
// by an `extern template` declaration so the struct is not implicitly
// instantiated in this translation unit.
// NOTE(review): the matching definitions are not in this section; presumably
// they live in a separately compiled GPU source file -- confirm before
// changing any signature here.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                 \
  template <>                                                               \
  void SpatialConvolution<GPUDevice, T>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, 4>::Tensor output,             \
      typename TTypes<T, 4>::ConstTensor input,                             \
      typename TTypes<T, 4>::ConstTensor filter, int row_stride,            \
      int col_stride, int row_dilation, int col_dilation,                   \
      const Eigen::PaddingType& padding,                                    \
      const Eigen::NoOpOutputKernel& output_kernel);                        \
  template <>                                                               \
  void SpatialConvolution<GPUDevice, T>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, 4>::Tensor output,             \
      typename TTypes<T, 4>::ConstTensor input,                             \
      typename TTypes<T, 4>::ConstTensor filter, int row_stride,            \
      int col_stride, int row_dilation, int col_dilation, int padding_top,  \
      int padding_bottom, int padding_left, int padding_right,              \
      const Eigen::NoOpOutputKernel& output_kernel);                        \
  extern template struct SpatialConvolution<GPUDevice, T>;                  \
  template <>                                                               \
  void MatMulConvFunctor<GPUDevice, T>::operator()(                         \
      const GPUDevice& d, typename TTypes<T, 2>::Tensor out,                \
      typename TTypes<T, 2>::ConstTensor in0,                               \
      typename TTypes<T, 2>::ConstTensor in1,                               \
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair, \
      const Eigen::NoOpOutputKernel& output_kernel);                        \
  extern template struct MatMulConvFunctor<GPUDevice, T>;                   \
  template <>                                                               \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(                   \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,             \
      typename TTypes<T, 4, int>::ConstTensor in,                           \
      typename TTypes<T, 4, int>::Tensor out);                              \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;             \
  template <>                                                               \
  void PadInput<GPUDevice, T, int, 4>::operator()(                          \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,       \
      const std::array<int, 2>& padding_left,                               \
      const std::array<int, 2>& padding_right,                              \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format,     \
      const T& padding_value);                                              \
  extern template struct PadInput<GPUDevice, T, int, 4>

// The macro body ends without a semicolon; each use below supplies it.
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC

}  // namespace functor
1183 
1184 // Registration of the GPU implementations.
1185 REGISTER_KERNEL_BUILDER(
1186     Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
1187     Conv2DOp<GPUDevice, Eigen::half>);
1188 REGISTER_KERNEL_BUILDER(
1189     Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
1190     Conv2DOp<GPUDevice, float>);
1191 REGISTER_KERNEL_BUILDER(
1192     Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<double>("T"),
1193     Conv2DOp<GPUDevice, double>);
1194 REGISTER_KERNEL_BUILDER(
1195     Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<int32>("T"),
1196     Conv2DOp<GPUDevice, int32>);
1197 
// To be used inside depthwise_conv_op.cc.
// Explicit instantiation definitions of the GPU launcher for float, half and
// double, so that code for these three specializations is emitted in this
// translation unit. Note that int32 is registered above but is not
// instantiated here.
template struct LaunchConv2DOp<GPUDevice, float>;
template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DOp<GPUDevice, double>;
1202 
1203 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1204 
1205 }  // namespace tensorflow
1206