/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/depthwise_conv_op.h"

#include <algorithm>
#include <cmath>
#include <type_traits>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "third_party/gpus/cudnn/cudnn.h"
#endif

#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

// In depthwise convolution, one input channel is convolved into
// depth_multiplier output channels, and the outputs are not reduced across
// channels the way a regular convolution reduces them. However, the way
// filters are applied to inputs is exactly the same as in regular
// convolution. Please refer to the regular convolution kernels for more
// details.

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Computes the vectorized product of 'input_buffer' and 'filter' and stores
// the result in 'output' at the location specified by 'out_r' and 'out_c'.
//
// EX:
//   in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4
//   Both 'input_buffer' and 'filter' are padded to register-width boundaries.
//
//   input_buffer [rows, cols, in_depth, depth_multiplier]
//     [a0, a0, a1, a1] [a2, a2, 0, 0] [b0, b0, b1, b1] [b2, b2, 0, 0]
//     [e0, e0, e1, e1] [e2, e2, 0, 0] [f0, f0, f1, f1] [f2, f2, 0, 0]
//
//   filter [rows, cols, in_depth, depth_multiplier]
//     [u0, v0, w0, x0] [y0, z0, 0, 0] [u1, v1, w1, x1] [y1, z1, 0, 0]
//     [u2, v2, w2, x2] [y2, z2, 0, 0] [u3, v3, w3, x3] [y3, z3, 0, 0]
//
//   First output register [in_depth, depth_multiplier]
//     [q0, q1, q2, q3] = ([a0, a0, a1, a1] x [u0, v0, w0, x0]) +
//                        ([b0, b0, b1, b1] x [u1, v1, w1, x1]) +
//                        ([e0, e0, e1, e1] x [u2, v2, w2, x2]) +
//                        ([f0, f0, f1, f1] x [u3, v3, w3, x3])
//
// TODO(andydavis) Experiment with processing multiple inputs per input buffer.
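// Note: kPacketSize below is sizeof(Packet) / sizeof(T), i.e. the number of
// T values held in one vector register (e.g. typically 4 floats with SSE or
// 8 with AVX; the exact width depends on how Eigen's packet_traits maps to
// the build target).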
template <typename T>
struct DepthwiseConv2DKernel {
  static void Run(const DepthwiseArgs& args,
                  const int64 padded_filter_inner_dim_size, const int64 out_r,
                  const int64 out_c, const T* filter, const T* input_buffer,
                  T* output, TensorFormat data_format) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64 out_depth = args.out_depth;
    const int64 filter_spatial_size = args.filter_rows * args.filter_cols;
    const int64 output_scalar_size = out_depth % kPacketSize;
    const int64 output_vectorized_size =
        (out_depth / kPacketSize) * kPacketSize;
    const int64 base_output_index = (out_r * args.out_cols + out_c) * out_depth;

    for (int i = 0; i < output_vectorized_size; i += kPacketSize) {
      // Reset accumulator.
      auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
      for (int j = 0; j < filter_spatial_size; ++j) {
        // Calculate index.
        const int64 index = i + j * padded_filter_inner_dim_size;
        // Load filter.
        // TODO(andydavis) Unroll 'out_c' loop in caller so we can load
        // multiple inputs here to amortize the cost of each filter block load.
        const auto filter_block =
            Eigen::internal::ploadu<Packet>(filter + index);
        // Load input.
        const auto data_block =
            Eigen::internal::ploadu<Packet>(input_buffer + index);
        // Vector multiply-add.
        vaccum =
            Eigen::internal::pmadd<Packet>(filter_block, data_block, vaccum);
      }
      // Store vector accumulator to output.
      Eigen::internal::pstoreu<T>(output + base_output_index + i, vaccum);
    }

    if (output_scalar_size > 0) {
      auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
      for (int j = 0; j < filter_spatial_size; ++j) {
        const int64 index =
            output_vectorized_size + j * padded_filter_inner_dim_size;
        const auto filter_block =
            Eigen::internal::ploadu<Packet>(filter + index);
        const auto data_block =
            Eigen::internal::ploadu<Packet>(input_buffer + index);
        vaccum =
            Eigen::internal::pmadd<Packet>(filter_block, data_block, vaccum);
      }
      // Load accumulator into an array and loop through output.
      T out_buf[kPacketSize];
      Eigen::internal::pstoreu<T>(out_buf, vaccum);
      const int64 last_output_index =
          base_output_index + output_vectorized_size;
      for (int j = 0; j < output_scalar_size; ++j) {
        output[last_output_index + j] = out_buf[j];
      }
    }
  }
};

// Computes the depthwise conv2d of 'input' by 'depthwise_filter' and stores
// the result in 'output'. This implementation trades off copying small patches
// of the input to achieve better data alignment, which enables vectorized
// load/store and multiply-add operations (see comments at InputBufferCopyOp
// and DepthwiseConv2DKernel for details).
//
// TODO(andydavis) Evaluate the performance of processing multiple input
// patches in the inner loop.
// TODO(andydavis) Consider a zero-copy implementation for the case when
// 'in_depth' is a multiple of register width, and 'depth_multiplier' is one.
// TODO(andydavis) Evaluate the performance of alternative implementations.
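// The CPU launcher below pads the filter to the vector register width once,
// then shards the work across threads with one shard per (batch, out_row)
// pair; each shard copies the relevant input patch into an aligned buffer
// and runs DepthwiseConv2DKernel on it.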
template <typename T>
struct LaunchDepthwiseConvOp<CPUDevice, T> {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;

  void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
                  const T* input, const T* depthwise_filter, T* output,
                  TensorFormat data_format) {
    OP_REQUIRES(
        ctx, data_format == FORMAT_NHWC,
        errors::Unimplemented(
            "Depthwise convolution on CPU is only supported for NHWC format"));
    static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));

    // Pad 'depthwise_filter' to vector register width (if needed).
    const bool pad_filter = (args.out_depth % kPacketSize) != 0;
    Tensor padded_filter;
    if (pad_filter) {
      // Allocate space for padded filter.
      const int64 filter_spatial_size = args.filter_rows * args.filter_cols;
      const int64 padded_filter_inner_dim_size =
          ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  TensorShape({filter_spatial_size,
                                               padded_filter_inner_dim_size}),
                                  &padded_filter));
      // Write out padded filter.
      functor::DepthwiseFilterPadOp<T>()(
          args, depthwise_filter, padded_filter.template flat<T>().data());
    }
    const T* filter_data =
        pad_filter ? padded_filter.template flat<T>().data() : depthwise_filter;

    // Computes one shard of depthwise conv2d output.
    auto shard = [&ctx, &args, &input, &filter_data, &output, data_format](
                     int64 start, int64 limit) {
      static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));
      const int64 input_image_size =
          args.in_rows * args.in_cols * args.in_depth;
      const int64 output_image_size =
          args.out_rows * args.out_cols * args.out_depth;
      const int64 filter_spatial_size = args.filter_rows * args.filter_cols;
      const int64 padded_filter_inner_dim_size =
          ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;

      // Allocate buffer for local input regions.
      Tensor input_buffer;
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  TensorShape({filter_spatial_size,
                                               padded_filter_inner_dim_size}),
                                  &input_buffer));
      T* input_buffer_data = input_buffer.template flat<T>().data();

      for (int64 i = start; i < limit; ++i) {
        const int64 b = i / args.out_rows;
        const int64 in_base = b * input_image_size;
        const int64 out_base = b * output_image_size;

        const int64 out_r = i % args.out_rows;

        for (int64 out_c = 0; out_c < args.out_cols; ++out_c) {
          // Populate 'input_buffer_data' with data from local input region.
          functor::DepthwiseInputCopyOp<T>()(args, padded_filter_inner_dim_size,
                                             out_r, out_c, input + in_base,
                                             input_buffer_data);

          // Process buffered input across all filters and store to output.
          DepthwiseConv2DKernel<T>::Run(
              args, padded_filter_inner_dim_size, out_r, out_c, filter_data,
              input_buffer_data, output + out_base, data_format);
        }
      }
    };

    const int64 total_shards = args.batch * args.out_rows;

    // Empirically tested to give reasonable performance boosts at batch size 1
    // without reducing throughput at batch size 32.
    const float kCostMultiplier = 2.5f;

    // TODO(andydavis): Estimate shard cost (in cycles) based on the number of
    // flops/loads/stores required to compute one shard.
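    // For now, approximate the cost as proportional to the number of output
    // values produced by one shard (a single output row: out_cols * out_depth
    // elements), scaled by kCostMultiplier.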
    const int64 shard_cost = kCostMultiplier * args.out_cols * args.out_depth;

    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, total_shards,
          shard_cost, shard);
  }
};

// Extern template instantiated in conv_ops.cc.
extern template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
extern template struct LaunchConv2DOp<CPUDevice, float>;
extern template struct LaunchConv2DOp<CPUDevice, double>;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Extern template instantiated in conv_ops.cc.
extern template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
extern template struct LaunchConv2DOp<GPUDevice, float>;
extern template struct LaunchConv2DOp<GPUDevice, double>;

// Extern template instantiated in depthwise_conv_op_gpu.cc.
extern template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;

#endif

template <typename Device, typename T>
class DepthwiseConv2dNativeOp : public BinaryOp<T> {
 public:
  explicit DepthwiseConv2dNativeOp(OpKernelConstruction* context)
      : BinaryOp<T>(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    stride_ = GetTensorDim(strides_, data_format_, 'H');
    const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');

    OP_REQUIRES(context, stride_ == stride_w,
                errors::InvalidArgument(
                    "Current implementation only supports equal length "
                    "strides in the row and column dimensions."));
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));

    cudnn_use_autotune_ = CudnnUseAutotune();
    dtype_ = DataTypeToEnum<T>::value;
#if CUDNN_VERSION >= 8000
    // From the cuDNN 8.0 release notes: the fprop and dgrad NHWC depthwise
    // kernels were extended to support more combinations (filter
    // sizes/strides) such as 5x5/1x1, 5x5/2x2, 7x7/1x1, 7x7/2x2 (in addition
    // to the existing 1x1/1x1, 3x3/1x1, 3x3/2x2), which provides good
    // performance.
    // (https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/rel_8.html#rel_8)
    use_cudnn_grouped_conv_ =
        dtype_ == DT_HALF &&
        (data_format_ == FORMAT_NCHW ||
         (data_format_ == FORMAT_NHWC && stride_ == stride_w &&
          (stride_ == 1 || stride_ == 2)));
#elif CUDNN_VERSION >= 7603
    // Use cuDNN grouped conv only when the input/output is NCHW and the type
    // is float16 (half). See the cuDNN 7.6.3 release notes.
    // (https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/rel_763.html#rel_763)
    use_cudnn_grouped_conv_ = dtype_ == DT_HALF && data_format_ == FORMAT_NCHW;
#else
    use_cudnn_grouped_conv_ = false;
#endif
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, depth_multiplier ]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    // in_depth for input and filter must match.
    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is the depth multiplier.
    const int32 depth_multiplier = filter.dim_size(3);

    // The output depth is input depth x depth multiplier.
    const int32 out_depth = in_depth * depth_multiplier;

    const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_rows_raw, std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int32 input_rows = static_cast<int32>(input_rows_raw);
    const int32 filter_rows = filter.dim_size(0);

    const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_cols_raw, std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int32 input_cols = static_cast<int32>(input_cols_raw);
    const int32 filter_cols = filter.dim_size(1);

    // The first dimension for input is batch.
    const int32 batch = input.dim_size(0);

    int64 out_rows = 0, out_cols = 0, pad_top = 0, pad_bottom = 0, pad_left = 0,
          pad_right = 0;
    if (padding_ == Padding::EXPLICIT) {
      GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'H', &pad_top,
                               &pad_bottom);
      GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W', &pad_left,
                               &pad_right);
    }
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                input_rows, filter_rows, stride_, padding_,
                                &out_rows, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                input_cols, filter_cols, stride_, padding_,
                                &out_cols, &pad_left, &pad_right));
    TensorShape out_shape =
        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
    OP_REQUIRES(
        context,
        (!std::is_same<Device, GPUDevice>::value ||
         FastBoundsCheck(out_shape.num_elements(),
                         std::numeric_limits<int32>::max())),
        errors::InvalidArgument("Output elements too large for GPU kernel"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }

    // TODO(csigg): Have autotune decide if native is faster than cuDNN.
    // If in_depth==1, this operation is just a standard convolution.
    // Depthwise convolution is a special case of cuDNN's grouped convolution.
    bool use_cudnn = std::is_same<Device, GPUDevice>::value &&
                     (in_depth == 1 ||
                      (use_cudnn_grouped_conv_ &&
                       IsCudnnSupportedFilterSize(/*filter_rows=*/filter_rows,
                                                  /*filter_cols=*/filter_cols,
                                                  /*in_depth=*/in_depth,
                                                  /*out_depth=*/out_depth)));

    VLOG(2) << "DepthwiseConv2dNative: "
            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
            << filter_cols << ", " << in_depth << ", " << depth_multiplier
            << "]; Output: [" << batch << ", " << out_rows << ", " << out_cols
            << ", " << out_depth << "], stride = " << stride_
            << ", pad_top = " << pad_top << ", pad_left = " << pad_left
            << ", Use cuDNN: " << use_cudnn;

    if (use_cudnn) {
      // Reshape from TF depthwise filter to cuDNN grouped convolution filter:
      //
      //                  | TensorFlow       | cuDNN
      // --------------------------------------------------------------------
      // filter_out_depth | depth_multiplier | depth_multiplier * group_count
      // filter_in_depth  | in_depth         | in_depth / group_count
      //
      // For depthwise convolution, we have group_count == in_depth.
      int32 filter_in_depth = 1;
      TensorShape shape =
          TensorShape{filter_rows, filter_cols, filter_in_depth, out_depth};
      Tensor reshaped_filter(/*type=*/dtype_);
      OP_REQUIRES(
          context, reshaped_filter.CopyFrom(filter, shape),
          errors::Internal(
              "Failed to reshape filter tensor for grouped convolution."));
      // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
      // conv is supported.
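      // Example (illustrative values): for in_depth = 3 and
      // depth_multiplier = 2, the 'reshaped_filter' passed below has shape
      // [filter_rows, filter_cols, 1, 6], i.e. group_count = in_depth = 3
      // groups of depth_multiplier = 2 output channels each.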
      launcher_(context, /*use_cudnn=*/true, cudnn_use_autotune_, input,
                reshaped_filter, /*row_dilation=*/1, /*col_dilation=*/1,
                stride_, stride_, padding_, explicit_paddings_, output,
                data_format_);
      return;
    }

    DepthwiseArgs args;
    args.batch = batch;
    args.in_rows = input_rows;
    args.in_cols = input_cols;
    args.in_depth = in_depth;
    args.filter_rows = filter_rows;
    args.filter_cols = filter_cols;
    args.depth_multiplier = depth_multiplier;
    args.stride = stride_;
    args.pad_rows = pad_top;
    args.pad_cols = pad_left;
    args.out_rows = out_rows;
    args.out_cols = out_cols;
    args.out_depth = out_depth;

    auto input_ptr = input.template flat<T>().data();
    auto filter_ptr = filter.template flat<T>().data();
    auto output_ptr = output->template flat<T>().data();
    LaunchDepthwiseConvOp<Device, T>()(context, args, input_ptr, filter_ptr,
                                       output_ptr, data_format_);
  }

 protected:
  bool use_cudnn_grouped_conv_;

 private:
  std::vector<int32> strides_;
  Padding padding_;
  std::vector<int64> explicit_paddings_;
  TensorFormat data_format_;

  int64 stride_;  // in height/width dimension.

  // For in_depth == 1 and grouped convolutions.
  LaunchConv2DOp<Device, T> launcher_;
  bool cudnn_use_autotune_;
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
};

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DepthwiseConv2dNativeOp<CPUDevice, T>)

TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
TF_CALL_double(REGISTER_CPU_KERNEL);
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DepthwiseConv2dNativeOp<GPUDevice, T>)

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);

#if CUDNN_VERSION >= 7000
template <typename T>
class DepthwiseConv2dGroupedConvOp
    : public DepthwiseConv2dNativeOp<GPUDevice, T> {
 public:
  DepthwiseConv2dGroupedConvOp(OpKernelConstruction* context)
      : DepthwiseConv2dNativeOp<GPUDevice, T>(context) {
    this->use_cudnn_grouped_conv_ = true;
  }
};

#define REGISTER_GROUPED_CONV_KERNEL(T)                            \
  REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")            \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .Label("cudnn_grouped_convolution"), \
                          DepthwiseConv2dGroupedConvOp<T>)

TF_CALL_half(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_float(REGISTER_GROUPED_CONV_KERNEL);
TF_CALL_double(REGISTER_GROUPED_CONV_KERNEL);
#endif  // CUDNN_VERSION
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow