/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "google/protobuf/duration.pb.h"
#include "absl/time/time.h"
#include "cuda/include/cudnn.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/protobuf/conv_autotuning.pb.h"
#include "tensorflow/core/util/activation_mode.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

namespace {
typedef Eigen::GpuDevice GPUDevice;

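// RawType maps the kernel's element type to the raw type used when building
// se::DeviceMemory views further below: qint8 buffers are handed to cudnn as
// plain int8.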
template <typename T>
struct RawType {
  using type = T;
};

template <>
struct RawType<qint8> {
  using type = int8;
};

// Template struct to convert int8x4 to int32.
// (For NCHW_VECT_C with element type int8, we can consider it to be
// an NCHW layout with element type int32 for operations like padding.)
template <typename T>
struct Int8x4ToInt32 {
  // By default, do not change T.
  using type = T;
};

template <>
struct Int8x4ToInt32<int8> {
  using type = int32;
};
}  // namespace

// T is the element type of the conv_input, filter and side_input tensors.
// BiasType is the element type of the bias tensor, which can be different.
// ScaleType is the type used for conv_input_scale and side_input_scale.
template <typename Device, typename T, typename BiasType, typename ScaleType>
class FusedConv2DBiasActivationOp : public OpKernel {
 public:
  enum InputIndexes {
    kConvInput = 0,
    kFilter,
    kBias,
    kSideInput,
    kConvInputScale,
    kSideInputScale,
    kNumInputs
  };

  explicit FusedConv2DBiasActivationOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format_str, filter_format_str;
    CHECK_EQ(kNumInputs, context->num_inputs());
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context,
                   context->GetAttr("filter_format", &filter_format_str));
    OP_REQUIRES(context,
                FilterFormatFromString(filter_format_str, &filter_format_),
                errors::InvalidArgument("Invalid filter format"));

    std::vector<int32> strides;
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides));
    OP_REQUIRES(context, strides.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));

    stride_rows_ = GetTensorDim(strides, data_format_, 'H');
    stride_cols_ = GetTensorDim(strides, data_format_, 'W');
    OP_REQUIRES(
        context,
        (GetTensorDim(strides, data_format_, 'N') == 1 &&
         GetTensorDim(strides, data_format_, 'C') == 1),
        errors::Unimplemented("Convolutional strides are not supported in "
                              "the batch and depth dimensions."));

    // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
    constexpr bool is_int8x4 = std::is_same<T, qint8>::value;

    // Note: Only NCHW_VECT_C format is supported for int8.
    // This is because it is expected to be the fastest, and our previous tests
    // found cudnn 6 does not fully support the other formats for int8 mode.
    OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
                errors::InvalidArgument(
                    "qint8 should be used with data_format NCHW_VECT_C."));

    OP_REQUIRES(context, (is_int8x4 == (filter_format_ == FORMAT_OIHW_VECT_I)),
                errors::InvalidArgument(
                    "qint8 should be used with filter_format OIHW_VECT_I."));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_type_));
    eigen_padding_type_ = BrainPadding2EigenPadding(padding_type_);
    string activation_mode_str;
    OP_REQUIRES_OK(context,
                   context->GetAttr("activation_mode", &activation_mode_str));
    OP_REQUIRES_OK(context, GetActivationModeFromString(activation_mode_str,
                                                        &activation_mode_));
    OP_REQUIRES(context,
                activation_mode_ == ActivationMode::RELU ||
                    activation_mode_ == ActivationMode::NONE,
                errors::InvalidArgument(
                    "Current implementation only supports RELU or NONE "
                    "as the activation function."));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  Status CheckShape(const Tensor& tensor, const string& tensor_name) {
    const int num_dims = tensor.dims();
    for (int i = 0; i < num_dims; i++) {
      if (!FastBoundsCheck(tensor.dim_size(i),
                           std::numeric_limits<int32>::max())) {
        return errors::InvalidArgument(tensor_name, " dimension ", i,
                                       " too large");
      }
    }
    // If there is a 5th dimension it is the VECT_C or VECT_I dimension.
    if (num_dims == 5 && tensor.dim_size(4) != 4) {
      return errors::InvalidArgument("The last dimension of ", tensor_name,
                                     " must be of size 4 for qint8.");
    }
    return Status::OK();
  }

  void Compute(OpKernelContext* context) override {
    // The conv_input tensor is one of the following formats:
    // NHWC, NCHW, NCHW_VECT_C.
    const Tensor& conv_input = context->input(kConvInput);
    OP_REQUIRES_OK(context, CheckShape(conv_input, "conv_input"));

    // The filter tensor is one of the following formats:
    // HWIO, OIHW, OIHW_VECT_I.
    const Tensor& filter = context->input(kFilter);
    OP_REQUIRES_OK(context, CheckShape(filter, "filter"));

    // Input bias is a 1-D tensor, with size matching output depth.
    const Tensor& bias = context->input(kBias);
    OP_REQUIRES_OK(context, CheckShape(bias, "bias"));

    const Tensor& conv_input_scale_tensor = context->input(kConvInputScale);
    const Tensor& side_input_scale_tensor = context->input(kSideInputScale);

    auto conv_input_scale = *reinterpret_cast<const ScaleType*>(
        conv_input_scale_tensor.tensor_data().data());
    auto side_input_scale = *reinterpret_cast<const ScaleType*>(
        side_input_scale_tensor.tensor_data().data());

    // If side_input_scale != 0, then side_input is not ignored and
    // has the same type and dimensions as the output.
    const Tensor& side_input = context->input(kSideInput);
    if (side_input_scale != 0) {
      OP_REQUIRES_OK(context, CheckShape(side_input, "side_input"));
    }

    // TODO(pauldonnelly): Switch to a more efficient mechanism to access
    // dimension indexes and per-dimension attributes.
    const int32 filter_rows = GetFilterDim(filter, filter_format_, 'H');
    const int32 filter_cols = GetFilterDim(filter, filter_format_, 'W');
    const int32 output_depth = GetFilterDim(filter, filter_format_, 'O');

    const int32 batch_size = GetTensorDim(conv_input, data_format_, 'N');
    const int32 conv_input_rows = GetTensorDim(conv_input, data_format_, 'H');
    const int32 conv_input_cols = GetTensorDim(conv_input, data_format_, 'W');

    int64 output_rows = 0, output_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_rows, filter_rows,
                                                  stride_rows_, padding_type_,
                                                  &output_rows, &pad_rows));
    OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_cols, filter_cols,
                                                  stride_cols_, padding_type_,
                                                  &output_cols, &pad_cols));
    // Initialize the output tensor shape according to data_format_.
    TensorShape output_shape = ShapeFromFormat(
        data_format_, batch_size, output_rows, output_cols, output_depth);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    VLOG(2) << "FusedConv2DBiasActivation: conv_input_cols = "
            << conv_input_cols << ", conv_input_rows = " << conv_input_rows
            << ", filter_cols = " << filter_cols
            << ", filter_rows = " << filter_rows
            << ", stride_cols = " << stride_cols_
            << ", stride_rows = " << stride_rows_
            << ", output_depth = " << output_depth
            << ", output_cols = " << output_cols
            << ", output_rows = " << output_rows
            << ", output_shape.num_elements = " << output_shape.num_elements();

    // If there is nothing to compute, return.
    if (output_shape.num_elements() == 0) {
      return;
    }

    launcher_.launch(context, cudnn_use_autotune_, conv_input, conv_input_scale,
                     filter, stride_rows_, stride_cols_, eigen_padding_type_,
                     side_input, side_input_scale, bias, activation_mode_,
                     data_format_, filter_format_, output);
  }

 private:
  int32 stride_rows_, stride_cols_;
  Padding padding_type_;
  Eigen::PaddingType eigen_padding_type_;
  ActivationMode activation_mode_;
  TensorFormat data_format_;
  FilterTensorFormat filter_format_;
  LaunchFusedConv2DBiasActivationOp<Device, T, BiasType, ScaleType> launcher_;
  bool cudnn_use_autotune_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DBiasActivationOp);
};

#if GOOGLE_CUDA
namespace dnn = se::dnn;

// Several functions are copied over from tensorflow/core/kernels/gpu_utils,
// since this file may be compiled down to a tf_custom_op_library .so file,
// which can't depend on basic dependencies like tensorflow/core:lib. Instead,
// the code has to depend on whatever is the same in libtensorflow_framework.so.
//
// In theory, we could lift the dependencies of gpu_utils by turning it into a
// template library that provides duck typing, but duplication is the lesser
// of two evils here.
namespace internal {
namespace {

tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
  tensorflow::CudnnVersion cudnn_version;
  if (auto* dnn = stream_executor->AsDnn()) {
    se::port::StatusOr<se::dnn::VersionInfo> version_or = dnn->GetVersion();
    if (version_or.ok()) {
      const auto& version = version_or.ValueOrDie();
      cudnn_version.set_major(version.major_version());
      cudnn_version.set_minor(version.minor_version());
      cudnn_version.set_patch(version.patch());
    }
  }
  return cudnn_version;
}

// Converts an absl::Duration to a google::protobuf::Duration.
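// For example, absl::Milliseconds(1500) becomes {seconds: 1, nanos: 500000000}:
// the first IDivDuration call below stores the whole-second quotient and leaves
// the 500 ms remainder in 'duration', which the second call converts to nanos.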
inline google::protobuf::Duration ToDurationProto(absl::Duration duration) {
  google::protobuf::Duration proto;
  proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration));
  proto.set_nanos(
      absl::IDivDuration(duration, absl::Nanoseconds(1), &duration));
  return proto;
}

// Converts a google::protobuf::Duration to an absl::Duration.
inline absl::Duration FromDurationProto(google::protobuf::Duration proto) {
  return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos());
}

tensorflow::ComputeCapability GetComputeCapability(
    se::StreamExecutor* stream_executor) {
  tensorflow::ComputeCapability cc;
  int cc_major, cc_minor;
  stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                  &cc_minor);
  cc.set_major(cc_major);
  cc.set_minor(cc_minor);
  return cc;
}

void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
                                 const Tensor& filter, const Tensor& output,
                                 const Tensor& bias, const Tensor* side_input,
                                 se::StreamExecutor* stream_exec,
                                 absl::Span<const AutotuneResult> results) {
  AutotuningLog log;
  ConvNodeDef instr;
  *instr.mutable_conv() = node;
  input.shape().AsProto(instr.mutable_input()->mutable_tensor_shape());
  instr.mutable_input()->set_dtype(input.dtype());
  filter.shape().AsProto(instr.mutable_filter()->mutable_tensor_shape());
  instr.mutable_filter()->set_dtype(filter.dtype());
  output.shape().AsProto(instr.mutable_output()->mutable_tensor_shape());
  instr.mutable_output()->set_dtype(output.dtype());
  bias.shape().AsProto(instr.mutable_bias()->mutable_tensor_shape());
  instr.mutable_bias()->set_dtype(bias.dtype());
  if (side_input) {
    side_input->shape().AsProto(
        instr.mutable_side_input()->mutable_tensor_shape());
    instr.mutable_side_input()->set_dtype(side_input->dtype());
  }
  log.mutable_instr()->PackFrom(std::move(instr));
  *log.mutable_cudnn_version() = internal::GetCudnnVersion(stream_exec);
  *log.mutable_compute_capability() =
      internal::GetComputeCapability(stream_exec);
  for (const auto& result : results) {
    *log.add_results() = result;
  }
  Logger::Singleton()->LogProto(log);
}

Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
                              se::dnn::AlgorithmConfig* algo) {
  // The "!lhs.has_success()" terms below make successful results sort first:
  // they need a smaller key for "min_element" to prefer them.
  const AutotuneResult* best_result = std::min_element(
      results.begin(), results.end(),
      [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
        return std::make_tuple(
                   !lhs.has_success(),
                   internal::FromDurationProto(lhs.success().run_time())) <
               std::make_tuple(
                   !rhs.has_success(),
                   internal::FromDurationProto(rhs.success().run_time()));
      });

  const AutotuneResult* best_result_no_scratch = std::min_element(
      results.begin(), results.end(),
      [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
        return std::make_tuple(
                   !lhs.has_success(), lhs.success().scratch_bytes(),
                   internal::FromDurationProto(lhs.success().run_time())) <
               std::make_tuple(
                   !rhs.has_success(), rhs.success().scratch_bytes(),
                   internal::FromDurationProto(rhs.success().run_time()));
      });

  if (best_result == results.end() || !best_result->has_success()) {
    return errors::NotFound("No algorithm worked!");
  }
  algo->set_algorithm({best_result->conv().algorithm(),
                       best_result->conv().tensor_ops_enabled()});
  if (best_result_no_scratch != results.end() &&
      best_result_no_scratch->has_success() &&
      best_result_no_scratch->success().scratch_bytes() == 0) {
    algo->set_algorithm_no_scratch(
        {best_result_no_scratch->conv().algorithm(),
         best_result_no_scratch->conv().tensor_ops_enabled()});
  }
  return Status::OK();
}

}  // namespace
}  // namespace internal

// A dummy type to group forward convolution autotune results together.
struct ConvBiasActivationAutoTuneGroup {
  static string name() { return "ConvBiasActivation"; }
};
typedef AutoTuneSingleton<ConvBiasActivationAutoTuneGroup, FusedConvParameters,
                          dnn::AlgorithmConfig>
    AutoTuneConvBiasActivation;
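// The singleton above caches one dnn::AlgorithmConfig per FusedConvParameters
// for the lifetime of the process, so each parameter combination is autotuned
// at most once (see the Find()/Insert() calls in launch() below).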

// Allocates 'transformed_tensor' and transforms 'nhwc_tensor' into it
// using the specified 'batch_size', 'rows', 'cols', and 'depth' dimensions.
template <typename T, size_t NDIMS>
Status TransformNHWCToNCHW(OpKernelContext* ctx, const Tensor& nhwc_tensor,
                           int batch_size, int rows, int cols, int depth,
                           Tensor* transformed_tensor, const Tensor** result) {
  TensorShape nchw_shape =
      ShapeFromFormat(FORMAT_NCHW, batch_size, rows, cols, depth);
  if (depth > 1) {
    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                          transformed_tensor));
    functor::NHWCToNCHW<GPUDevice, T, NDIMS>()(
        ctx->eigen_device<GPUDevice>(), nhwc_tensor.tensor<T, NDIMS>(),
        transformed_tensor->tensor<T, NDIMS>());
  } else {
    // If depth <= 1, then just reshape.
    CHECK(transformed_tensor->CopyFrom(nhwc_tensor, nchw_shape));
  }
  *result = transformed_tensor;
  return Status::OK();
}

// Adjusts padding so cudnn supports it. Sets `adjusted_padding` to be the
// adjusted padding, and `extra_padding_before` and `extra_padding_after` to be
// the extra padding that FusedConv needs to apply before calling cudnn.
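// For example, on the common non-int8x4 path, padding = 3 yields
// adjusted_padding = 2 (cudnn pads one row/col symmetrically on each side)
// with extra_padding_before = 0 and extra_padding_after = 1, which the caller
// applies itself via PadInput before invoking cudnn.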
void AdjustPaddingForCudnn(int padding, bool is_int8x4, int filter_size,
                           int* adjusted_padding, int* extra_padding_before,
                           int* extra_padding_after) {
#if CUDNN_VERSION < 7000
  if (is_int8x4 && filter_size >= 6) {
    // TODO(b/70795525): Remove after NVIDIA fixes this bug with int8 fused
    // convolution. It is unclear whether cuDNN 7 still has the bug, so this
    // workaround is only enabled for cuDNN 6 or older.
    *adjusted_padding = 0;
    *extra_padding_before = padding / 2;
    *extra_padding_after = padding - *extra_padding_before;
    return;
  }
#endif
  *adjusted_padding = padding / 2 * 2;
  *extra_padding_before = 0;
  *extra_padding_after = padding % 2;
}

template <typename T, typename BiasType, typename ScaleType>
void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
    launch(OpKernelContext* ctx, bool cudnn_use_autotune,
           const Tensor& conv_input_param, ScaleType conv_input_scale,
           const Tensor& filter_param, int32 row_stride, int32 col_stride,
           const Eigen::PaddingType& padding, const Tensor& side_input_param,
           ScaleType side_input_scale, const Tensor& bias,
           ActivationMode activation_mode, TensorFormat data_format,
           FilterTensorFormat filter_format, Tensor* output_param) {
  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  // TODO(yangzihao): refactor all the complicated/duplicated code in regular
  // conv ops to a shared conv utility.

  // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
  constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
  constexpr int rank = is_int8x4 ? 5 : 4;
  constexpr int vect = is_int8x4 ? 4 : 1;
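  // For int8x4, tensors carry a trailing dimension of length 4 (the packed
  // channel vector), hence rank 5 and a channel multiplier (vect) of 4;
  // otherwise they are plain 4-D tensors and vect is 1.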

  if (is_int8x4) {
    int cc_major, cc_minor;
    stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                     &cc_minor);
    OP_REQUIRES(
        ctx, ((cc_major == 6 && cc_minor >= 1) || cc_major > 6),
        errors::Unimplemented(
            "FusedConv2DBiasActivation for int8 is only supported on GPUs with "
            "compute capability 6.1 or later."));
  }

  const int batch_size = GetTensorDim(conv_input_param, data_format, 'N');
  int conv_input_rows = GetTensorDim(conv_input_param, data_format, 'H');
  int conv_input_cols = GetTensorDim(conv_input_param, data_format, 'W');

  const int conv_input_depth =
      GetTensorDim(conv_input_param, data_format, 'C') * vect;
  const int output_rows = GetTensorDim(*output_param, data_format, 'H');
  const int output_cols = GetTensorDim(*output_param, data_format, 'W');
  const int output_depth = GetFilterDim(filter_param, filter_format, 'O');
  const int filter_rows = GetFilterDim(filter_param, filter_format, 'H');
  const int filter_cols = GetFilterDim(filter_param, filter_format, 'W');
  int padding_rows = 0;
  int padding_cols = 0;
  const Tensor* conv_input = &conv_input_param;

  Tensor maybe_padded_conv_input;
  if (padding == Eigen::PADDING_SAME) {
    // Total padding on rows and cols is
    //   Pr = (R' - 1) * S + Kr - R
    //   Pc = (C' - 1) * S + Kc - C
    // where (R', C') are output dimensions, (R, C) are input dimensions, S
    // is stride, (Kr, Kc) are filter dimensions.
    // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
    // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
    // we pad more on the right and bottom than on the top and left.
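    // For example, with R = 4, Kr = 2, S = 1 and SAME padding: R' = 4, so
    // Pr = (4 - 1) * 1 + 2 - 4 = 1, i.e. no padding on one side and one row
    // of padding on the other.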
    padding_rows = std::max<int>(
        0, (output_rows - 1) * row_stride + filter_rows - conv_input_rows);
    padding_cols = std::max<int>(
        0, (output_cols - 1) * col_stride + filter_cols - conv_input_cols);
    int extra_top_padding = 0;
    int extra_bottom_padding = 0;
    int extra_left_padding = 0;
    int extra_right_padding = 0;
    AdjustPaddingForCudnn(padding_rows, is_int8x4, filter_rows, &padding_rows,
                          &extra_top_padding, &extra_bottom_padding);
    AdjustPaddingForCudnn(padding_cols, is_int8x4, filter_cols, &padding_cols,
                          &extra_left_padding, &extra_right_padding);
    if (extra_top_padding != 0 || extra_bottom_padding != 0 ||
        extra_left_padding != 0 || extra_right_padding != 0) {
      Tensor transformed_input;
      const int new_conv_input_rows =
          conv_input_rows + extra_top_padding + extra_bottom_padding;
      const int new_conv_input_cols =
          conv_input_cols + extra_left_padding + extra_right_padding;

      using VectT = typename Int8x4ToInt32<typename RawType<T>::type>::type;
      auto pad_data_format = is_int8x4 ? FORMAT_NCHW : data_format;

      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(
                   DataTypeToEnum<T>::value,
                   ShapeFromFormat(data_format, batch_size, new_conv_input_rows,
                                   new_conv_input_cols, conv_input_depth),
                   &maybe_padded_conv_input));

      auto conv_input_eigen_tensor =
          To32Bit(conv_input_param.reinterpret_last_dimension<VectT, 4>());
      auto padded_conv_input_eigen_tensor = To32Bit(
          maybe_padded_conv_input.reinterpret_last_dimension<VectT, 4>());

      functor::PadInput<GPUDevice, VectT, int, 4>()(
          ctx->eigen_device<GPUDevice>(), conv_input_eigen_tensor,
          {{extra_top_padding, extra_left_padding}},
          {{extra_bottom_padding, extra_right_padding}},
          padded_conv_input_eigen_tensor, pad_data_format);

      conv_input = &maybe_padded_conv_input;
      conv_input_rows = new_conv_input_rows;
      conv_input_cols = new_conv_input_cols;
    }
  }

  Tensor maybe_transformed_conv_input, maybe_transformed_side_input;
  Tensor maybe_transformed_output;
  const Tensor* side_input = &side_input_param;
  Tensor* output = output_param;

  // NOTE: Here and elsewhere, checking 'is_int8x4' may look unnecessary
  // and inefficient, but it is actually both a time and code size optimization,
  // since 'is_int8x4' is a constexpr determined by the template parameter.
  if (!is_int8x4 && data_format == FORMAT_NHWC) {
    OP_REQUIRES_OK(ctx, (TransformNHWCToNCHW<T, rank>(
                            ctx, *conv_input, batch_size, conv_input_rows,
                            conv_input_cols, conv_input_depth,
                            &maybe_transformed_conv_input, &conv_input)));
    if (side_input_scale != 0) {
      OP_REQUIRES_OK(
          ctx, (TransformNHWCToNCHW<T, rank>(
                   ctx, side_input_param, batch_size, output_rows, output_cols,
                   output_depth, &maybe_transformed_side_input, &side_input)));
    }
    if (output_depth > 1) {
      // Allocate a tensor for the NCHW output of the kernel and point output
      // to it. Afterwards, we will transform it to NHWC while copying back to
      // 'output_param'.
      TensorShape nchw_shape = ShapeFromFormat(
          FORMAT_NCHW, batch_size, output_rows, output_cols, output_depth);
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                        &maybe_transformed_output));
      output = &maybe_transformed_output;
    }
  }

  constexpr auto data_layout = is_int8x4 ? dnn::DataLayout::kBatchDepthYX4
                                         : dnn::DataLayout::kBatchDepthYX;
  constexpr auto filter_layout = is_int8x4 ? dnn::FilterLayout::kOutputInputYX4
                                           : dnn::FilterLayout::kOutputInputYX;
  constexpr auto compute_data_format =
      is_int8x4 ? FORMAT_NCHW_VECT_C : FORMAT_NCHW;

  dnn::BatchDescriptor conv_input_desc;
  conv_input_desc.set_count(batch_size)
      .set_feature_map_count(conv_input_depth)
      .set_height(conv_input_rows)
      .set_width(conv_input_cols)
      .set_layout(data_layout);
  dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(filter_rows)
      .set_input_filter_width(filter_cols)
      .set_input_feature_map_count(conv_input_depth)
      .set_output_feature_map_count(output_depth)
      .set_layout(filter_layout);
  dnn::BatchDescriptor side_input_desc;
  side_input_desc.set_count(batch_size)
      .set_height(output_rows)
      .set_width(output_cols)
      .set_feature_map_count(output_depth)
      .set_layout(data_layout);
  dnn::BatchDescriptor bias_desc;
  bias_desc.set_count(1)
      .set_height(1)
      .set_width(1)
      .set_feature_map_count(output_depth)
      .set_layout(dnn::DataLayout::kBatchDepthYX);
  dnn::BatchDescriptor output_desc;
  output_desc.set_count(batch_size)
      .set_height(output_rows)
      .set_width(output_cols)
      .set_feature_map_count(output_depth)
      .set_layout(data_layout);
  dnn::ConvolutionDescriptor conv_desc;
  CHECK_EQ(0, padding_rows % 2);
  CHECK_EQ(0, padding_cols % 2);
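  // AdjustPaddingForCudnn rounded the remaining padding down to an even value
  // (any odd leftover row/col was already applied above via PadInput), so the
  // division by 2 below yields exact, symmetric zero padding for cudnn.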
  conv_desc.set_vertical_filter_stride(row_stride)
      .set_horizontal_filter_stride(col_stride)
      .set_zero_padding_height(padding_rows / 2)
      .set_zero_padding_width(padding_cols / 2);

  Tensor maybe_transformed_filter;
  const Tensor* filter = &filter_param;
  // For qint8, we have already checked filter is OIHW_VECT_I in the
  // constructor, but we need to test for is_int8x4 so the if block doesn't
  // generate code for qint8.
  if (!is_int8x4 && filter_format == FORMAT_HWIO) {
    // Shuffle filter tensor from HWIO to OIHW:
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            ShapeFromFilterFormat(
                                FORMAT_OIHW, filter_param.shape(), FORMAT_HWIO),
                            &maybe_transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 4>()(
        ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
        To32Bit(filter_param.tensor<T, 4>()),
        To32Bit(maybe_transformed_filter.tensor<T, 4>()));
    filter = &maybe_transformed_filter;
  }

  auto conv_input_ptr =
      AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
                         conv_input->template flat<T>().data()),
                     conv_input->template flat<T>().size());
  auto filter_ptr =
      AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
                         filter->template flat<T>().data()),
                     filter->template flat<T>().size());
  auto side_input_ptr =
      AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
                         side_input->template flat<T>().data()),
                     side_input->template flat<T>().size());
  auto output_ptr =
      AsDeviceMemory(reinterpret_cast<const typename RawType<T>::type*>(
                         output->template flat<T>().data()),
                     output->template flat<T>().size());
  auto bias_ptr = AsDeviceMemory(bias.template flat<BiasType>().data(),
                                 bias.template flat<BiasType>().size());

  static int64 ConvolveScratchSize = GetDnnWorkspaceLimit(
      // default value is in bytes despite the name of the environment variable
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
  );

  int device_id = stream->parent()->device_ordinal();
  FusedConvParameters fused_conv_parameters = {
      batch_size,
      conv_input_depth,
      {{conv_input_rows, conv_input_cols}},
      compute_data_format,
      output_depth,
      {{filter_rows, filter_cols}},
      // TODO(yangzihao): Add support for arbitrary dilations for fused conv.
      {{1, 1}},  // dilation_rows, dilation_cols
      {{row_stride, col_stride}},
      {{padding_rows, padding_cols}},
      conv_input->dtype(),
      device_id,
      (side_input_scale != 0),
      activation_mode,
  };

  dnn::ActivationMode dnn_activation_mode;
  switch (activation_mode) {
    case ActivationMode::NONE:
      dnn_activation_mode = dnn::ActivationMode::kNone;
      break;
    case ActivationMode::RELU:
      dnn_activation_mode = dnn::ActivationMode::kRelu;
      break;
    default:
      LOG(FATAL) << "Activation mode " << activation_mode << " not supported";
  }

  dnn::AlgorithmConfig algorithm_config;
  if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find(
                                fused_conv_parameters, &algorithm_config)) {
    std::vector<dnn::AlgorithmDesc> algorithms;
    CHECK(stream->parent()->GetConvolveAlgorithms(
        fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
            stream->parent()),
        &algorithms));
    if (activation_mode == ActivationMode::NONE) {
      // Only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM is supported for
      // identity activation; other algorithms seem to quietly apply Relu.
      // See
      // https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBiasActivationForward
      algorithms.erase(
          std::remove_if(
              algorithms.begin(), algorithms.end(),
              [](dnn::AlgorithmDesc alg) {
                return alg.algo_id() !=
                       CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
              }),
          algorithms.end());
    }
    std::vector<tensorflow::AutotuneResult> results;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times for better
      // accuracy.
      DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
      dnn::ProfileResult profile_result;
      bool cudnn_launch_status =
          stream
              ->ThenFusedConvolveWithAlgorithm(
                  conv_input_desc, conv_input_ptr, conv_input_scale,
                  filter_desc, filter_ptr, conv_desc, side_input_ptr,
                  side_input_scale, bias_desc, bias_ptr, dnn_activation_mode,
                  output_desc, &output_ptr, &scratch_allocator,
                  dnn::AlgorithmConfig(profile_algorithm), &profile_result)
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          results.emplace_back();
          auto& result = results.back();
          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
          result.mutable_conv()->set_tensor_ops_enabled(
              profile_algorithm.tensor_ops_enabled());
          result.mutable_success()->set_scratch_bytes(
              scratch_allocator.TotalByteSize());
          *result.mutable_success()->mutable_run_time() =
              internal::ToDurationProto(
                  absl::Milliseconds(profile_result.elapsed_time_in_ms()));
        }
      }
    }
    internal::LogFusedConvAutotuneResults(ctx->op_kernel().def(), *conv_input,
                                          *filter, *output, bias, side_input,
                                          stream->parent(), results);
    OP_REQUIRES_OK(
        ctx, internal::BestCudnnConvAlgorithm(results, &algorithm_config));
    AutoTuneConvBiasActivation::GetInstance()->Insert(fused_conv_parameters,
                                                      algorithm_config);
  }

  DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
  bool cudnn_launch_status =
      stream
          ->ThenFusedConvolveWithAlgorithm(
              conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc,
              filter_ptr, conv_desc, side_input_ptr, side_input_scale,
              bias_desc, bias_ptr, dnn_activation_mode, output_desc,
              &output_ptr, &scratch_allocator, algorithm_config,
              /*output_profile_result=*/nullptr)
          .ok();

  if (!cudnn_launch_status) {
    ctx->SetStatus(errors::Internal("cuDNN launch failure : conv_input shape(",
                                    conv_input->shape().DebugString(),
                                    ") filter shape(",
                                    filter->shape().DebugString(), ")"));
  }

  // Convert the output tensor back from NCHW to NHWC if necessary.
  if (!is_int8x4 && (data_format == FORMAT_NHWC) && (output_depth > 1)) {
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        ctx->eigen_device<GPUDevice>(),
        const_cast<const Tensor*>(output)->tensor<T, 4>(),
        output_param->tensor<T, 4>());
  }
}

// Forward declarations of the functor specializations for GPU used above.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                               \
  template <>                                                             \
  void PadInput<GPUDevice, T, int, 4>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,     \
      const std::array<int, 2>& padding_left,                             \
      const std::array<int, 2>& padding_right,                            \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format);  \
  extern template struct PadInput<GPUDevice, T, int, 4>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.

REGISTER_KERNEL_BUILDER(
    Name("FusedConv2DBiasActivation")
        .Device(DEVICE_GPU)
        .TypeConstraint<float>("T")
        .TypeConstraint<float>("Tbias")
        .HostMemory("conv_input_scale")
        .HostMemory("side_input_scale"),
    FusedConv2DBiasActivationOp<GPUDevice, float, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedConv2DBiasActivation")
        .Device(DEVICE_GPU)
        .TypeConstraint<qint8>("T")
        .TypeConstraint<float>("Tbias")
        .HostMemory("conv_input_scale")
        .HostMemory("side_input_scale"),
    FusedConv2DBiasActivationOp<GPUDevice, qint8, float, float>);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow