/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/depthwise_conv_op.h"

#include <algorithm>
#include <cmath>
#include <type_traits>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif

#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

// In depthwise convolution, each input channel is convolved into
// depth_multiplier output channels, and the outputs are not reduced (summed)
// across channels as they are in regular convolution.
// However, the way filters are applied to the input is exactly the same as in
// regular convolution. Please refer to the regular convolution kernels for
// more details.
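//
// EX (illustrative shapes only): an NHWC input [batch, in_rows, in_cols, 8]
// convolved with a filter [3, 3, 8, 2] (in_depth = 8, depth_multiplier = 2)
// produces an output [batch, out_rows, out_cols, 16]; each of the 8 input
// channels independently yields 2 output channels.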

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Computes the vectorized product of 'input_buffer' and 'filter' and stores
// the result in 'output' at the location specified by 'out_r' and 'out_c'.
//
// EX:
//   in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4
//   Both 'input_buffer' and 'filter' are padded to register-width boundaries.
//
//   input_buffer [rows, cols, in_depth, depth_multiplier]
//     [a0, a0, a1, a1] [a2, a2, 0, 0] [b0, b0, b1, b1] [b2, b2, 0, 0]
//     [e0, e0, e1, e1] [e2, e2, 0, 0] [f0, f0, f1, f1] [f2, f2, 0, 0]
//
//   filter [rows, cols, in_depth, depth_multiplier]
//     [u0, v0, w0, x0] [y0, z0, 0, 0] [u1, v1, w1, x1] [y1, z1, 0, 0]
//     [u2, v2, w2, x2] [y2, z2, 0, 0] [u3, v3, w3, x3] [y3, z3, 0, 0]
//
//   First output register [in_depth, depth_multiplier]
//     [q0, q1, q2, q3] = ([a0, a0, a1, a1] x [u0, v0, w0, x0]) +
//                        ([b0, b0, b1, b1] x [u1, v1, w1, x1]) +
//                        ([e0, e0, e1, e1] x [u2, v2, w2, x2]) +
//                        ([f0, f0, f1, f1] x [u3, v3, w3, x3])
//
// TODO(andydavis) Experiment with processing multiple inputs per input buffer.
template <typename T>
struct DepthwiseConv2DKernel {
  static void Run(const DepthwiseArgs& args,
                  const int64 padded_filter_inner_dim_size, const int64 out_r,
                  const int64 out_c, const T* filter, const T* input_buffer,
                  T* output, TensorFormat data_format) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64 out_depth = args.out_depth;
    const int64 filter_spatial_size = args.filter_rows * args.filter_cols;
    const int64 output_scalar_size = out_depth % kPacketSize;
    const int64 output_vectorized_size =
        (out_depth / kPacketSize) * kPacketSize;
    const int64 base_output_index = (out_r * args.out_cols + out_c) * out_depth;
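    // EX (illustrative): with out_depth = 6 and kPacketSize = 4,
    // output_vectorized_size = 4 and output_scalar_size = 2, so depths [0, 4)
    // are handled by the vectorized loop below and depths [4, 6) by the scalar
    // tail block that follows it.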

    for (int i = 0; i < output_vectorized_size; i += kPacketSize) {
      // Reset accumulator.
      auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
      for (int j = 0; j < filter_spatial_size; ++j) {
        // Calculate index.
        const int64 index = i + j * padded_filter_inner_dim_size;
        // Load filter.
        // TODO(andydavis) Unroll 'out_c' loop in caller so we can load
        // multiple inputs here to amortize the cost of each filter block load.
        const auto filter_block =
            Eigen::internal::ploadu<Packet>(filter + index);
        // Load input.
        const auto data_block =
            Eigen::internal::ploadu<Packet>(input_buffer + index);
        // Vector multiply-add.
        vaccum =
            Eigen::internal::pmadd<Packet>(filter_block, data_block, vaccum);
      }
      // Store vector accumulator to output.
      Eigen::internal::pstoreu<T>(output + base_output_index + i, vaccum);
    }

    if (output_scalar_size > 0) {
      auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
      for (int j = 0; j < filter_spatial_size; ++j) {
        const int64 index =
            output_vectorized_size + j * padded_filter_inner_dim_size;
        const auto filter_block =
            Eigen::internal::ploadu<Packet>(filter + index);
        const auto data_block =
            Eigen::internal::ploadu<Packet>(input_buffer + index);
        vaccum =
            Eigen::internal::pmadd<Packet>(filter_block, data_block, vaccum);
      }
      // Load accumulator into an array and loop through output.
      T out_buf[kPacketSize];
      Eigen::internal::pstoreu<T>(out_buf, vaccum);
      const int64 last_output_index =
          base_output_index + output_vectorized_size;
      for (int j = 0; j < output_scalar_size; ++j) {
        output[last_output_index + j] = out_buf[j];
      }
    }
  }
};

// Computes the depthwise conv2d of 'input' by 'depthwise_filter' and stores
// the result in 'output'. This implementation trades the cost of copying small
// patches of the input for better data alignment, which enables vectorized
// load/store and multiply-add operations (see comments at DepthwiseInputCopyOp
// and DepthwiseConv2DKernel for details).
//
// TODO(andydavis) Evaluate the performance of processing multiple input
// patches in the inner loop.
// TODO(andydavis) Consider a zero-copy implementation for the case when
// 'in_depth' is a multiple of register width, and 'depth_multiplier' is one.
// TODO(andydavis) Evaluate the performance of alternative implementations.
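//
// EX (illustrative): with out_depth = 6 and a 4-wide packet,
// padded_filter_inner_dim_size = ceil(6 / 4) * 4 = 8, i.e. both the padded
// filter and the per-patch input buffer round the depth dimension up to a
// multiple of the vector register width.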
template <typename T>
struct LaunchDepthwiseConvOp<CPUDevice, T> {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;

  void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
                  const T* input, const T* depthwise_filter, T* output,
                  TensorFormat data_format) {
    OP_REQUIRES(
        ctx, data_format == FORMAT_NHWC,
        errors::Unimplemented(
            "Depthwise convolution on CPU is only supported for NHWC format"));
    static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));

    // Pad 'depthwise_filter' to vector register width (if needed).
    const bool pad_filter = (args.out_depth % kPacketSize) != 0;
    Tensor padded_filter;
    if (pad_filter) {
      // Allocate space for padded filter.
      const int64 filter_spatial_size = args.filter_rows * args.filter_cols;
      const int64 padded_filter_inner_dim_size =
          ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  TensorShape({filter_spatial_size,
                                               padded_filter_inner_dim_size}),
                                  &padded_filter));
      // Write out padded filter.
      functor::DepthwiseFilterPadOp<T>()(
          args, depthwise_filter, padded_filter.template flat<T>().data());
    }
    const T* filter_data =
        pad_filter ? padded_filter.template flat<T>().data() : depthwise_filter;

    // Computes one shard of depthwise conv2d output.
    auto shard = [&ctx, &args, &input, &filter_data, &output, data_format](
                     int64 start, int64 limit) {
      static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));
      const int64 input_image_size =
          args.in_rows * args.in_cols * args.in_depth;
      const int64 output_image_size =
          args.out_rows * args.out_cols * args.out_depth;
      const int64 filter_spatial_size = args.filter_rows * args.filter_cols;
      const int64 padded_filter_inner_dim_size =
          ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;

      // Allocate buffer for local input regions.
      Tensor input_buffer;
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  TensorShape({filter_spatial_size,
                                               padded_filter_inner_dim_size}),
                                  &input_buffer));
      T* input_buffer_data = input_buffer.template flat<T>().data();

      for (int64 i = start; i < limit; ++i) {
        const int64 b = i / args.out_rows;
        const int64 in_base = b * input_image_size;
        const int64 out_base = b * output_image_size;

        const int64 out_r = i % args.out_rows;

        for (int64 out_c = 0; out_c < args.out_cols; ++out_c) {
          // Populate 'input_buffer_data' with data from local input region.
          functor::DepthwiseInputCopyOp<T>()(args, padded_filter_inner_dim_size,
                                             out_r, out_c, input + in_base,
                                             input_buffer_data);

          // Process buffered input across all filters and store to output.
          DepthwiseConv2DKernel<T>::Run(
              args, padded_filter_inner_dim_size, out_r, out_c, filter_data,
              input_buffer_data, output + out_base, data_format);
        }
      }
    };

    const int64 total_shards = args.batch * args.out_rows;

    // Empirically tested to give reasonable performance boosts at batch size 1
    // without reducing throughput at batch size 32.
    const float kCostMultiplier = 2.5f;

    // TODO(andydavis): Estimate shard cost (in cycles) based on the number of
    // flops/loads/stores required to compute one shard.
    const int64 shard_cost = kCostMultiplier * args.out_cols * args.out_depth;
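    // EX (illustrative): for out_cols = 112 and out_depth = 64,
    // shard_cost = 2.5 * 112 * 64 = 17920 cost units per output-row shard.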

    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, total_shards,
          shard_cost, shard);
  }
};

// Extern template instantiated in conv_ops.cc.
extern template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
extern template struct LaunchConv2DOp<CPUDevice, float>;
extern template struct LaunchConv2DOp<CPUDevice, double>;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Extern template instantiated in conv_ops.cc.
extern template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
extern template struct LaunchConv2DOp<GPUDevice, float>;
extern template struct LaunchConv2DOp<GPUDevice, double>;

// Extern template instantiated in depthwise_conv_op_gpu.cc.
extern template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;

#endif

template <typename Device, typename T>
class DepthwiseConv2dNativeOp : public BinaryOp<T> {
 public:
  explicit DepthwiseConv2dNativeOp(OpKernelConstruction* context)
      : BinaryOp<T>(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    stride_ = GetTensorDim(strides_, data_format_, 'H');
    const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');

    OP_REQUIRES(context, stride_ == stride_w,
                errors::InvalidArgument(
                    "Current implementation only supports equal length "
                    "strides in the row and column dimensions."));
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));

    cudnn_use_autotune_ = CudnnUseAutotune();
    dtype_ = DataTypeToEnum<T>::value;
#if CUDNN_VERSION >= 8000
    // From the cuDNN 8.0 release notes: We’ve extended the fprop and dgrad
    // NHWC depthwise kernels to support more combinations (filter
    // sizes/strides) such as 5x5/1x1, 5x5/2x2, 7x7/1x1, 7x7/2x2 (in addition
    // to what we already have, 1x1/1x1, 3x3/1x1, 3x3/2x2), which provides
    // good performance. (https://docs.nvidia.com/deeplearning/sdk/cudnn-
    // release-notes/rel_8.html#rel_8)
    use_cudnn_grouped_conv_ =
        dtype_ == DT_HALF &&
        (data_format_ == FORMAT_NCHW ||
         (data_format_ == FORMAT_NHWC && stride_ == stride_w &&
          (stride_ == 1 || stride_ == 2)));
#elif CUDNN_VERSION >= 7603
    // Use cuDNN grouped conv only when input/output is NCHW and float16
    // (half). See the cuDNN 7.6.3 release notes.
    // (https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/
    // rel_763.html#rel_763)
    use_cudnn_grouped_conv_ = dtype_ == DT_HALF && data_format_ == FORMAT_NCHW;
#else
    use_cudnn_grouped_conv_ = false;
#endif
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, depth_multiplier ]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    // in_depth for input and filter must match.
    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is depth multiplier.
    const int32 depth_multiplier = filter.dim_size(3);

    // The output depth is in_depth * depth_multiplier.
    const int32 out_depth = in_depth * depth_multiplier;
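    // EX (illustrative): in_depth = 8 and depth_multiplier = 2 give
    // out_depth = 16.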

    const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_rows_raw, std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int32 input_rows = static_cast<int32>(input_rows_raw);
    const int32 filter_rows = filter.dim_size(0);

    const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_cols_raw, std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int32 input_cols = static_cast<int32>(input_cols_raw);
    const int32 filter_cols = filter.dim_size(1);

    // The first dimension for input is batch.
    const int32 batch = input.dim_size(0);

    int64 out_rows = 0, out_cols = 0, pad_top = 0, pad_bottom = 0, pad_left = 0,
          pad_right = 0;
    if (padding_ == Padding::EXPLICIT) {
      GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'H', &pad_top,
                               &pad_bottom);
      GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W', &pad_left,
                               &pad_right);
    }
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                input_rows, filter_rows, stride_, padding_,
                                &out_rows, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                input_cols, filter_cols, stride_, padding_,
                                &out_cols, &pad_left, &pad_right));
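    // EX (illustrative): for the window computation above with
    // input_rows = 112, filter_rows = 3, stride_ = 2:
    //   SAME  -> out_rows = ceil(112 / 2) = 56, pad_top = 0, pad_bottom = 1.
    //   VALID -> out_rows = ceil((112 - 3 + 1) / 2) = 55, no padding.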
    TensorShape out_shape =
        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
    OP_REQUIRES(
        context,
        (!std::is_same<Device, GPUDevice>::value ||
         FastBoundsCheck(out_shape.num_elements(),
                         std::numeric_limits<int32>::max())),
        errors::InvalidArgument("Output elements too large for GPU kernel"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }

    // TODO(csigg): Have autotune decide if native is faster than cuDNN.
    // If in_depth == 1, this operation is just a standard convolution.
    // Depthwise convolution is a special case of cuDNN's grouped convolution.
    bool use_cudnn = std::is_same<Device, GPUDevice>::value &&
                     (in_depth == 1 ||
                      (use_cudnn_grouped_conv_ &&
                       IsCudnnSupportedFilterSize(/*filter_rows=*/filter_rows,
                                                  /*filter_cols=*/filter_cols,
                                                  /*in_depth=*/in_depth,
                                                  /*out_depth=*/out_depth)));

    VLOG(2) << "DepthwiseConv2dNative: "
            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
            << filter_cols << ", " << in_depth << ", " << depth_multiplier
            << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
            << ", " << out_depth << "], stride = " << stride_
            << ", pad_top = " << pad_top << ", pad_left = " << pad_left
            << ", Use cuDNN: " << use_cudnn;

    if (use_cudnn) {
      // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
      //
      //                  | TensorFlow       | cuDNN
      // --------------------------------------------------------------------
      // filter_out_depth | depth_multiplier | depth_multiplier * group_count
      // filter_in_depth  | in_depth         | in_depth / group_count
      //
      // For depthwise convolution, we have group_count == in_depth.
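      //
      // EX (illustrative): a TF depthwise filter of shape [3, 3, 8, 2]
      // (in_depth = 8, depth_multiplier = 2) is reinterpreted as a cuDNN
      // grouped-convolution filter of shape [3, 3, 1, 16] with
      // group_count = 8.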
      int32 filter_in_depth = 1;
      TensorShape shape =
          TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
      Tensor reshaped_filter(/*type=*/dtype_);
      OP_REQUIRES(
          context, reshaped_filter.CopyFrom(filter, shape),
          errors::Internal(
              "Failed to reshape filter tensor for grouped convolution."));
      // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
      // conv is supported.
      launcher_(context, /*use_cudnn=*/true, cudnn_use_autotune_, input,
                reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
                stride_, stride_, padding_, explicit_paddings_, output,
                data_format_);
      return;
    }

    DepthwiseArgs args;
    args.batch = batch;
    args.in_rows = input_rows;
    args.in_cols = input_cols;
    args.in_depth = in_depth;
    args.filter_rows = filter_rows;
    args.filter_cols = filter_cols;
    args.depth_multiplier = depth_multiplier;
    args.stride = stride_;
    args.pad_rows = pad_top;
    args.pad_cols = pad_left;
    args.out_rows = out_rows;
    args.out_cols = out_cols;
    args.out_depth = out_depth;

    auto input_ptr = input.template flat<T>().data();
    auto filter_ptr = filter.template flat<T>().data();
    auto output_ptr = output->template flat<T>().data();
    LaunchDepthwiseConvOp<Device, T>()(context, args, input_ptr, filter_ptr,
                                       output_ptr, data_format_);
  }

 protected:
  bool use_cudnn_grouped_conv_;

 private:
  std::vector<int32> strides_;
  Padding padding_;
  std::vector<int64> explicit_paddings_;
  TensorFormat data_format_;

  int64 stride_;  // in height/width dimension.

  // For in_depth == 1 and grouped convolutions.
  LaunchConv2DOp<Device, T> launcher_;
  bool cudnn_use_autotune_;
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
};

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DepthwiseConv2dNativeOp<CPUDevice, T>)

TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
TF_CALL_double(REGISTER_CPU_KERNEL);
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DepthwiseConv2dNativeOp<GPUDevice, T>)

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);

#if CUDNN_VERSION >= 7000
template <typename T>
class DepthwiseConv2dGroupedConvOp
    : public DepthwiseConv2dNativeOp<GPUDevice, T> {
 public:
  DepthwiseConv2dGroupedConvOp(OpKernelConstruction* context)
      : DepthwiseConv2dNativeOp<GPUDevice, T>(context) {
    this->use_cudnn_grouped_conv_ = true;
  }
};

#define REGISTER_GROUPED_CONV_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")            \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .Label("cudnn_grouped_convolution"), \
                          DepthwiseConv2dGroupedConvOp<T>)

TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
#endif  // CUDNN_VERSION
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow