1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/cuda/cuda_dnn.h"
17
18 #include <functional>
19 #include <memory>
20 #include <utility>
21
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "third_party/eigen3/Eigen/Core"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/platform/tensor_float_32_utils.h"
27 #include "tensorflow/core/util/env_var.h"
28 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
29 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
30 #include "tensorflow/stream_executor/cuda/cuda_driver.h"
31 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
32 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
33 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
34 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
35 #include "tensorflow/stream_executor/cuda/cudnn_version.h"
36 #include "tensorflow/stream_executor/dnn.h"
37 #include "tensorflow/stream_executor/lib/env.h"
38 #include "tensorflow/stream_executor/lib/error.h"
39 #include "tensorflow/stream_executor/lib/initialize.h"
40 #include "tensorflow/stream_executor/lib/mathutil.h"
41 #include "tensorflow/stream_executor/lib/threadpool.h"
42 #include "tensorflow/stream_executor/platform/logging.h"
43 #include "tensorflow/stream_executor/plugin_registry.h"
44 #include "tensorflow/stream_executor/scratch_allocator.h"
45 #include "tensorflow/stream_executor/stream.h"
46 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
47 // clang-format off
48 #include "third_party/gpus/cudnn/cudnn.h"
49 #include "absl/strings/string_view.h"
50 // clang-format on
51
52 #pragma clang diagnostic push
53
54 // Make sure that Eigen::half forward declaration in dnn.h matches the
55 // declaration in Eigen.
56 #pragma clang diagnostic warning "-Wmismatched-tags"
57
58 namespace stream_executor {
59 namespace gpu {
60
61 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
62
63 namespace {
64
65 static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher");
66
67 // Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS.
68 #define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS)
69
70 // If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current
71 // function with a non-successful port::Status.
72 #define RETURN_IF_CUDNN_ERROR(expr) \
73 do { \
74 cudnnStatus_t _status = expr; \
75 if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) { \
76 std::ostringstream oss; \
77 oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \
78 << "): '" << #expr << "'"; \
79 return port::Status(port::error::UNKNOWN, oss.str().c_str()); \
80 } \
81 } while (false)
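
// As an illustration (hypothetical call site, not code from this file), any
// cudnnStatus_t-returning call can be wrapped inside a function that returns
// port::Status:
//
//   port::Status SetCudnnStream(cudnnHandle_t handle, CUstream stream) {
//     RETURN_IF_CUDNN_ERROR(cudnnSetStream(handle, stream));
//     return port::Status::OK();
//   }
//
// On failure the macro returns a port::Status whose message contains the
// cuDNN status string, the file and line, and the expression text.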
82
// Converts (via narrowing) a WideT value to a NarrowT value, and checks that
// the value is unchanged by the conversion.
85 template <typename WideT, typename NarrowT>
NarrowT CheckedNarrowing(const WideT& wide) {
87 NarrowT narrow = wide;
88 CHECK_EQ(narrow, wide)
89 << "checked narrowing failed; values not equal post-conversion";
90 return narrow;
91 }
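
// For example (illustrative only): CheckedNarrowing<int64, int>(42) returns
// 42, whereas narrowing a value that does not fit in the destination type
// trips the CHECK above.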
92
std::string ToString(cudnnStatus_t status) {
94 switch (status) {
95 case CUDNN_STATUS_SUCCESS:
96 return "CUDNN_STATUS_SUCCESS";
97 case CUDNN_STATUS_NOT_INITIALIZED:
98 return "CUDNN_STATUS_NOT_INITIALIZED";
99 case CUDNN_STATUS_ALLOC_FAILED:
100 return "CUDNN_STATUS_ALLOC_FAILED";
101 case CUDNN_STATUS_BAD_PARAM:
102 return "CUDNN_STATUS_BAD_PARAM";
103 case CUDNN_STATUS_INTERNAL_ERROR:
104 return "CUDNN_STATUS_INTERNAL_ERROR";
105 case CUDNN_STATUS_INVALID_VALUE:
106 return "CUDNN_STATUS_INVALID_VALUE";
107 case CUDNN_STATUS_ARCH_MISMATCH:
108 return "CUDNN_STATUS_ARCH_MISMATCH";
109 case CUDNN_STATUS_MAPPING_ERROR:
110 return "CUDNN_STATUS_MAPPING_ERROR";
111 case CUDNN_STATUS_EXECUTION_FAILED:
112 return "CUDNN_STATUS_EXECUTION_FAILED";
113 case CUDNN_STATUS_NOT_SUPPORTED:
114 return "CUDNN_STATUS_NOT_SUPPORTED";
115 case CUDNN_STATUS_LICENSE_ERROR:
116 return "CUDNN_STATUS_LICENSE_ERROR";
117 case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
118 return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
119 case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
120 return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
121 case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
122 return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
123 default:
124 return absl::StrCat("<unknown cudnn status: ", static_cast<int>(status),
125 ">");
126 }
127 }
128
129 // RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
130 //
131 // See CudnnAccess::GetHandle() for details.
132 class CudnnHandle {
133 public:
  // Takes ownership of the executor context and of the lock that guards
  // access to cuDNN through the handle.
  CudnnHandle(gpu::ScopedActivateExecutorContext context,
137 std::unique_ptr<absl::MutexLock> lock, cudnnHandle_t handle)
138 : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
139
  // Returns the cuDNN handle. Pass it directly to cuDNN APIs; don't keep a
  // copy.
  cudnnHandle_t handle() const { return handle_; }
143
144 private:
145 gpu::ScopedActivateExecutorContext context_;
146 std::unique_ptr<absl::MutexLock> lock_;
147 cudnnHandle_t handle_; // Not owned.
148 };
149
150 } // namespace
151
152 // Wraps a cuDNN handle and provides access to it through CudnnHandle
153 // instances, which also locks a mutex, acquires the CUDA context, and sets
154 // the stream that cuDNN should use to enqueue any work.
155 //
156 // Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
157 class CudnnAccess {
158 public:
159 // Takes ownership of the handle.
  explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
161
  ~CudnnAccess() {
163 absl::MutexLock lock(&mutex_);
164 cudnnDestroy(handle_);
165 }
166
167 // Creates a CudnnHandle instance for stream.
168 //
169 // cuDNN API calls using the same handle instance need to be serialized
170 // across threads. This is guaranteed by CudnnHandle instances locking the
171 // mutex owned by this class.
172 //
173 // Most cuDNN APIs taking a handle perform work on a CUDA stream. The
174 // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
175 // to use the provided stream.
176 //
177 // The stream argument may be null, which translates to the legacy default
178 // stream. See
179 // https://docs.nvidia.com/cuda/cuda-driver-api/stream-sync-behavior.html.
  // The legacy default stream synchronizes with all other streams, so it is a
  // bad idea (performance-wise) to call any cuDNN APIs that enqueue work in
  // that stream.
  CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) {
184 auto lock = absl::make_unique<absl::MutexLock>(&mutex_);
185 mutex_.AssertHeld();
186 gpu::ScopedActivateExecutorContext context(executor);
187 CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy;
188 const auto status = cudnnSetStream(handle_, cu_stream);
189 CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
190 return CudnnHandle(std::move(context), std::move(lock), handle_);
191 }
192
193 private:
194 // Guards the enqueueing of cuDNN operations via the handle_ below.
195 absl::Mutex mutex_;
196
197 // cuDNN library handle.
198 cudnnHandle_t handle_ TF_GUARDED_BY(mutex_); // Owned.
199 };
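
// A minimal sketch of the intended usage of CudnnAccess (variable names here
// are illustrative, not part of this file):
//
//   CudnnAccess access(raw_handle);  // takes ownership of raw_handle
//   CudnnHandle cudnn = access.GetHandle(executor, stream);
//   size_t size = 0;
//   RETURN_IF_CUDNN_ERROR(cudnnDropoutGetStatesSize(cudnn.handle(), &size));
//
// The mutex stays locked for the lifetime of 'cudnn', which serializes cuDNN
// calls made through this handle across threads.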
200
201 namespace {
202
203 // A helper function to return the internal compute type for
204 // RNNs in cudnn.
205 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
206
cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
208 cudnnConvolutionFwdAlgo_t algo =
209 cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
210 switch (algo) {
211 case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
212 case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
213 case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
214 case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
215 case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
216 case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
217 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
218 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
219 return algo;
220 default:
221 LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
222 << algorithm.algo_id();
223 }
224 }
225
cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
227 dnn::AlgorithmDesc algorithm) {
228 cudnnConvolutionBwdDataAlgo_t algo =
229 cudnnConvolutionBwdDataAlgo_t(algorithm.algo_id());
230 switch (algo) {
231 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
232 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
233 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
234 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
235 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
236 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
237 return algo;
238 default:
239 LOG(FATAL)
240 << "Unsupported Cudnn convolution backward algorithm for data: "
241 << algorithm.algo_id();
242 }
243 }
244
cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
246 dnn::AlgorithmDesc algorithm) {
247 cudnnConvolutionBwdFilterAlgo_t algo =
248 cudnnConvolutionBwdFilterAlgo_t(algorithm.algo_id());
249 switch (algo) {
250 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
251 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
252 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
253 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
254 // Based on cudnn.h, the following is not implemented.
255 // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
256 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
257 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
258 return algo;
259 default:
260 LOG(FATAL)
261 << "Unsupported Cudnn convolution backward algorithm for filter: "
262 << algorithm.algo_id();
263 }
264 }
265
port::StatusOr<int> GetCudnnProperty(libraryPropertyType type) {
267 int value;
268 RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));
269 return value;
270 }
271
cudnnRNNAlgo_t ToCudnnRNNAlgo(absl::optional<dnn::AlgorithmDesc> algorithm) {
273 if (!algorithm.has_value()) {
274 return CUDNN_RNN_ALGO_STANDARD;
275 }
276 cudnnRNNAlgo_t algo = static_cast<cudnnRNNAlgo_t>(algorithm->algo_id());
277 switch (algo) {
278 case CUDNN_RNN_ALGO_STANDARD:
279 case CUDNN_RNN_ALGO_PERSIST_STATIC:
280 case CUDNN_RNN_ALGO_PERSIST_DYNAMIC:
281 return algo;
282 default:
283 LOG(FATAL) << "Unsupported Cudnn RNN algorithm: " << algorithm->algo_id();
284 }
285 }
286
port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
288 SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION));
289 SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION));
290 SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL));
291 return port::Status::OK();
292 }
293
294 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
void PreloadCudnnLibrary(cudnnStatus_t (*version_check_fn)(),
296 absl::string_view sub_library) {
297 cudnnStatus_t status = version_check_fn();
298 if (status != CUDNN_STATUS_SUCCESS) {
299 VLOG(1) << "Could not pre-initialize cuDNN sub-library " << sub_library
300 << ". Error: " << cudnnGetErrorString(status) << ".";
301 }
302 }
303 #endif
304
305 } // namespace
306
CudnnSupport::CudnnSupport(GpuExecutor* parent) : parent_(parent) {}
308
port::Status CudnnSupport::Init() {
310 ScopedActivateExecutorContext context(parent_);
311 cudnnHandle_t cudnn_handle = nullptr;
312 const auto status = cudnnCreate(&cudnn_handle);
313 if (status == CUDNN_STATUS_SUCCESS) {
314 CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
315
316 CudnnVersion loaded_version;
317 TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
318 if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
319 const std::string error = absl::StrCat(
320 "Loaded runtime CuDNN library: ", loaded_version.ToString(),
321 " but source was compiled with: ", source_version.ToString(),
322 ". CuDNN library needs to have matching major version and equal or "
323 "higher minor version. If using a binary install, upgrade your CuDNN "
324 "library. If building from sources, make sure the library loaded at "
325 "runtime is compatible with the version specified during compile "
326 "configuration.");
327 LOG(ERROR) << error;
328 cudnnDestroy(cudnn_handle);
329 return port::Status(port::error::INTERNAL, error);
330 }
331
332 cudnn_.reset(new CudnnAccess(cudnn_handle));
333
334 LOG(INFO) << "Loaded cuDNN version " << cudnnGetVersion();
335 return port::Status::OK();
336 }
337
338 CHECK_EQ(cudnn_handle, nullptr);
339 LOG(ERROR) << "Could not create cudnn handle: " << ToString(status);
340 if (status == CUDNN_STATUS_NOT_INITIALIZED) {
341 auto result = gpu::Diagnostician::FindKernelDriverVersion();
342 if (!result.ok()) {
343 LOG(ERROR) << "Error retrieving driver version: "
344 << cuda::DriverVersionStatusToString(result);
345 } else {
346 const auto& version = result.ValueOrDie();
347 LOG(ERROR) << "Possibly insufficient driver version: "
348 << cuda::DriverVersionToString(version);
349 }
350 }
351
352 return port::Status(port::error::INTERNAL,
353 absl::StrCat("cudnn library could not create a handle: ",
354 ToString(status)));
355 }
356
357 port::StatusOr<perftools::gputools::dnn::VersionInfo>
CudnnSupport::GetVersion() {
359 CudnnVersion version;
360 TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version));
361 return perftools::gputools::dnn::VersionInfo(
362 version.major_version, version.minor_version, version.patch_level);
363 }
364
365 namespace {
366
367 // Deleter functors for cuDNN types that need to be deleted.
368 struct TensorDescriptorDeleter {
  void operator()(cudnnTensorDescriptor_t descriptor) const {
370 CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor));
371 }
372 };
373 struct RNNDataDescriptorDeleter {
  void operator()(cudnnRNNDataDescriptor_t descriptor) const {
375 CHECK_CUDNN_OK(cudnnDestroyRNNDataDescriptor(descriptor));
376 }
377 };
378 struct FilterDescriptorDeleter {
  void operator()(cudnnFilterDescriptor_t descriptor) const {
380 CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor));
381 }
382 };
383 struct ConvolutionDescriptorDeleter {
  void operator()(cudnnConvolutionDescriptor_t descriptor) const {
385 CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor));
386 }
387 };
388 struct PoolingDescriptorDeleter {
  void operator()(cudnnPoolingDescriptor_t descriptor) const {
390 CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor));
391 }
392 };
393 struct LrnDescriptorDeleter {
  void operator()(cudnnLRNDescriptor_t descriptor) const {
395 CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor));
396 }
397 };
398
399 struct ActivationDescriptorDeleter {
  void operator()(cudnnActivationDescriptor_t descriptor) const {
401 CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor));
402 }
403 };
404 struct DropoutDescriptorDeleter {
  void operator()(cudnnDropoutDescriptor_t descriptor) const {
406 CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor));
407 }
408 };
409 struct RnnDescriptorDeleter {
  void operator()(cudnnRNNDescriptor_t descriptor) const {
411 CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor));
412 }
413 };
414 struct PersistentRnnPlanDeleter {
  void operator()(cudnnPersistentRNNPlan_t plan) const {
416 CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
417 }
418 };
419 #if CUDNN_VERSION >= 7603
420 struct CtcLossDescriptorDeleter {
  void operator()(cudnnCTCLossDescriptor_t descriptor) const {
422 CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor));
423 }
424 };
425 #endif
426
427 // RAII wrappers for cuDNN types.
428 using TensorDescriptor =
429 std::unique_ptr<cudnnTensorStruct, TensorDescriptorDeleter>;
430 using RNNDataDescriptor =
431 std::unique_ptr<cudnnRNNDataStruct, RNNDataDescriptorDeleter>;
432 using FilterDescriptor =
433 std::unique_ptr<cudnnFilterStruct, FilterDescriptorDeleter>;
434 using ConvolutionDescriptor =
435 std::unique_ptr<cudnnConvolutionStruct, ConvolutionDescriptorDeleter>;
436 using PoolingDescriptor =
437 std::unique_ptr<cudnnPoolingStruct, PoolingDescriptorDeleter>;
438 using LrnDescriptor = std::unique_ptr<cudnnLRNStruct, LrnDescriptorDeleter>;
439 using ActivationDescriptor =
440 std::unique_ptr<cudnnActivationStruct, ActivationDescriptorDeleter>;
441 using DropoutDescriptor =
442 std::unique_ptr<cudnnDropoutStruct, DropoutDescriptorDeleter>;
443 using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
444 using PersistentRnnPlan =
445 std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
446 #if CUDNN_VERSION >= 7603
447 using CtcLossDescriptor =
448 std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
449 #endif
450
451 // Factory methods for cuDNN types.
TensorDescriptor CreateTensorDescriptor() {
453 cudnnTensorDescriptor_t result;
454 CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
455 return TensorDescriptor(result);
456 }
RNNDataDescriptor CreateRNNDataDescriptor() {
458 cudnnRNNDataDescriptor_t result;
459 CHECK_CUDNN_OK(cudnnCreateRNNDataDescriptor(&result));
460 return RNNDataDescriptor(result);
461 }
FilterDescriptor CreateFilterDescriptor() {
463 cudnnFilterDescriptor_t result;
464 CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result));
465 return FilterDescriptor(result);
466 }
ConvolutionDescriptor CreateConvolutionDescriptor() {
468 cudnnConvolutionDescriptor_t result;
469 CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result));
470 return ConvolutionDescriptor(result);
471 }
PoolingDescriptor CreatePoolingDescriptor() {
473 cudnnPoolingDescriptor_t result;
474 CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result));
475 return PoolingDescriptor(result);
476 }
LrnDescriptor CreateLrnDescriptor() {
478 cudnnLRNDescriptor_t result;
479 CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result));
480 return LrnDescriptor(result);
481 }
ActivationDescriptor CreateActivationDescriptor() {
483 cudnnActivationDescriptor_t result;
484 CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result));
485 return ActivationDescriptor(result);
486 }
DropoutDescriptor CreateDropoutDescriptor() {
488 cudnnDropoutDescriptor_t result;
489 CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result));
490 return DropoutDescriptor(result);
491 }
RnnDescriptor CreateRnnDescriptor() {
493 cudnnRNNDescriptor_t result;
494 CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
495 return RnnDescriptor(result);
496 }
497 #if CUDNN_VERSION >= 7603
CtcLossDescriptor CreateCtcLossDescriptor() {
499 cudnnCTCLossDescriptor_t result;
500 CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&result));
501 return CtcLossDescriptor(result);
502 }
503 #endif
504
port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
506 cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
507 cudnnPersistentRNNPlan_t result;
508 RETURN_IF_CUDNN_ERROR(
509 cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result));
510 return port::StatusOr<PersistentRnnPlan>(PersistentRnnPlan(result));
511 }
512
513 // Turns a BatchDescriptor structure into a cudnn tensor handle within a
514 // scope.
515 class CudnnTensorDescriptor {
516 public:
  CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
518 cudnnDataType_t elem_type)
519 : handle_(CreateTensorDescriptor()) {
520 switch (batch_descriptor.layout()) {
521 case dnn::DataLayout::kBatchYXDepth:
522 case dnn::DataLayout::kBatchDepthYX: {
523 const int nd = batch_descriptor.ndims() + 2;
524 // cuDNN requires the strides and dims to be ordered as BDYX.
525 std::vector<int64> strides64 =
526 batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
527 std::vector<int64> dims64 =
528 batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
529
530 // cuDNN requires arrays of ints.
531 std::vector<int> strides(nd);
532 std::vector<int> dims(nd);
533 std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
534 &CheckedNarrowing<int64, int>);
535 std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
536 &CheckedNarrowing<int64, int>);
537 CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd,
538 dims.data(), strides.data()))
539 << "batch_descriptor: " << batch_descriptor.ToString();
540 } break;
541 case dnn::DataLayout::kBatchDepthYX4: {
542 CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(
543 handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type,
544 batch_descriptor.count(), batch_descriptor.feature_map_count(),
545 batch_descriptor.height(), batch_descriptor.width()))
546 << "batch_descriptor: " << batch_descriptor.ToString();
547 } break;
548 default:
549 LOG(FATAL) << "Unsupported tensor format "
550 << DataLayoutString(batch_descriptor.layout());
551 break;
552 }
553 }
554
  cudnnTensorDescriptor_t handle() const { return handle_.get(); }
556
557 private:
558 TensorDescriptor handle_;
559
560 SE_DISALLOW_COPY_AND_ASSIGN(CudnnTensorDescriptor);
561 };
562
563 // Turns a FilterDescriptor structure into a cudnn filter handle within a
564 // scope.
565 class CudnnFilterDescriptor {
566 public:
  CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
568 cudnnDataType_t elem_type)
569 : handle_(CreateFilterDescriptor()) {
570 // TODO(b/23032134): Even if the filter layout is not supported,
571 // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
572 // it does not take layout as an input. Maybe force cuDNN by giving wrong
573 // inputs intentionally?
574 cudnnTensorFormat_t format;
575 switch (filter_descriptor.layout()) {
576 case dnn::FilterLayout::kOutputInputYX:
577 format = CUDNN_TENSOR_NCHW;
578 break;
579 case dnn::FilterLayout::kOutputYXInput:
580 format = CUDNN_TENSOR_NHWC;
581 break;
582 case dnn::FilterLayout::kOutputInputYX4:
583 format = CUDNN_TENSOR_NCHW_VECT_C;
584 break;
585 default:
586 LOG(FATAL) << "Unsupported filter format "
587 << FilterLayoutString(filter_descriptor.layout());
588 break;
589 }
590
591 std::vector<int> dims(2 + filter_descriptor.ndims());
592 dims[0] = filter_descriptor.output_feature_map_count();
593 dims[1] = filter_descriptor.input_feature_map_count();
594 absl::Span<const int64> spatial_dims =
595 filter_descriptor.input_filter_dims();
596 std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
597
598 CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format,
599 dims.size(), dims.data()));
600 }
601
  cudnnFilterDescriptor_t handle() const { return handle_.get(); }
603
604 private:
605 FilterDescriptor handle_; // Owned.
606
607 SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
608 };
609
610 // A helper function to decide whether to use
611 // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
// some tasks because an optimized path may be selected for the
// CUDNN_DATA_FLOAT and CUDNN_DATA_HALF data types on compute capability 6.0
// or higher. The reason we set it to false by default is that this mode may
// use a scaled atomic integer reduction that can cause numerical overflow for
// certain input data ranges.
617 // TODO(yangzihao): Use autotune to choose between this mode and
618 // CUDNN_BATCHNORM_SPATIAL mode.
bool BatchnormSpatialPersistentEnabled() {
620 static bool is_enabled = [] {
621 bool is_enabled = false;
622 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
623 "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
624 /*default_val=*/false, &is_enabled));
625 return is_enabled;
626 }();
627 return is_enabled;
628 }
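
// For example, launching the process with
// TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT=1 in its environment opts in to
// the persistent batchnorm mode described above.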
629
630 // The following function allows deterministic ops to be implemented relatively
631 // quickly using environment variables. It is intended to be temporary. The
632 // longer-term intention is to enable deterministic ops via tf.config and
633 // appropriate plumbing. See the discussion on PR 34951 for more information:
634 // https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316
635 // This function and associated comment are replicated in the following three
636 // places:
637 // 1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
638 // 2. tensorflow/core/kernels/gpu_utils.cc
639 // 3. tensorflow/stream_executor/cuda/cuda_dnn.cc
640 // When implementing the plumbing, you should also search for the use of
641 // TF_DETERMINISTIC_OPS on its own.
642 // TODO(duncanriach): move to an API that uses tf.config and implement the first
643 // phase of plumbing.
bool RequireCudnnDeterminism() {
645 static bool require_cudnn_determinism = [] {
646 bool deterministic_ops = false;
647 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
648 /*default_val=*/false,
649 &deterministic_ops));
650 bool cudnn_deterministic = false;
651 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
652 /*default_val=*/false,
653 &cudnn_deterministic));
654 return deterministic_ops || cudnn_deterministic;
655 }();
656 return require_cudnn_determinism;
657 }
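
// Setting either TF_DETERMINISTIC_OPS=1 or TF_CUDNN_DETERMINISTIC=1 in the
// environment makes this return true; max pooling, for instance, then uses
// CUDNN_POOLING_MAX_DETERMINISTIC further below.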
658
659 // A helper function to decide whether to force the default conv algorithm.
bool ConvUseDefaultAlgorithm() {
661 static bool use_default = [] {
662 bool use_default = false;
663 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_USE_DEFAULT_CONV_ALGO",
664 /*default_val=*/false,
665 &use_default));
666 return use_default;
667 }();
668 return use_default;
669 }
670
std::tuple<int, int> GetCcMajorMinor(Stream* stream) {
672 int cc_major, cc_minor;
673 stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
674 &cc_minor);
675 return std::make_tuple(cc_major, cc_minor);
676 }
677
678 // Turns a ConvolutionDescriptor structure into a cudnn convolution handle
679 // within a scope.
680 class CudnnConvolutionDescriptor {
681 public:
  CudnnConvolutionDescriptor(
683 const dnn::ConvolutionDescriptor& convolution_descriptor,
684 cudnnDataType_t data_type)
685 : handle_(CreateConvolutionDescriptor()) {
686 absl::Span<const int64> strides64 = convolution_descriptor.strides();
687 absl::Span<const int64> padding64 = convolution_descriptor.padding();
688 absl::Span<const int64> dilations64 = convolution_descriptor.dilations();
689 CHECK_NE(convolution_descriptor.pad_alignment(),
690 dnn::PadAlignment::kTensorFlowPadding)
691 << "TensorFlow padding alignment is not supported.";
692
693 // cuDNN requires arrays of ints.
694 std::vector<int> strides(convolution_descriptor.ndims());
695 std::vector<int> padding(convolution_descriptor.ndims());
696 std::vector<int> dilations(convolution_descriptor.ndims());
697 std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
698 &CheckedNarrowing<int64, int>);
699 std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
700 &CheckedNarrowing<int64, int>);
701 // TODO(yangzihao): Test with negative dilation to make sure that cudnn
702 // doesn't crash.
703 std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
704 &CheckedNarrowing<int64, int>);
705
706 CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor(
707 handle_.get(), convolution_descriptor.ndims(), padding.data(),
708 strides.data(), dilations.data(),
709 convolution_descriptor.convolution_not_crosscorr()
710 ? CUDNN_CONVOLUTION
711 : CUDNN_CROSS_CORRELATION,
712 data_type));
713
714 #if CUDNN_MAJOR >= 7
715 VLOG(2) << "Requesting grouped convolution: "
716 << convolution_descriptor.group_count();
717 CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount(
718 handle_.get(), convolution_descriptor.group_count()));
719 #else
720 CHECK_EQ(convolution_descriptor.group_count(), 1)
721 << "Requested grouped convolution for cuDNN version < 7";
722 #endif
723 }
724
  void set_use_tensor_op_math(bool use_tensor_op_math) {
726 cudnnMathType_t math_type =
727 #if CUDNN_VERSION >= 8000
728 (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH);
729 #else
730 (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
731 #endif
732 CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
733 }
734
  cudnnConvolutionDescriptor_t handle() const { return handle_.get(); }
736
737 private:
738 ConvolutionDescriptor handle_; // Owned.
739
740 SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
741 };
742
// A helper function to query whether a CudnnConvolutionDescriptor has
// tensor_op_math set.
static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
746 cudnnMathType_t math_type;
747 CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type));
748 #if CUDNN_VERSION >= 8000
749 return math_type != CUDNN_FMA_MATH;
750 #else
751 return math_type == CUDNN_TENSOR_OP_MATH;
752 #endif
753 }
754
static bool TensorOpMathAvailable(int cc_major) { return cc_major >= 7; }
756
static bool IsTensorMathEnabled(Stream* stream, dnn::DataType input_type) {
758 int cc_major, cc_minor;
759 std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
760 if (!TensorOpMathAvailable(cc_major)) {
761 return false;
762 }
763 if (input_type == dnn::DataType::kFloat) {
764 #if CUDNN_VERSION < 8000
765 return false;
766 #else
767 if (!tensorflow::tensor_float_32_execution_enabled()) {
768 return false;
769 }
770 #endif
771 }
772 return true;
773 }
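
// In summary: half-precision inputs may use tensor op math on compute
// capability 7.0 and above, while float inputs additionally require
// cuDNN 8+ and tensorflow::tensor_float_32_execution_enabled() to be true.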
774
775 // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
776 // within a scope.
777 class CudnnPoolingDescriptor {
778 public:
  explicit CudnnPoolingDescriptor(
780 const dnn::PoolingDescriptor& pooling_descriptor)
781 : handle_(CreatePoolingDescriptor()) {
782 absl::Span<const int64> strides64 = pooling_descriptor.strides();
783 absl::Span<const int64> padding64 = pooling_descriptor.padding();
784 absl::Span<const int64> shape64 = pooling_descriptor.window();
785
786 const int nd = pooling_descriptor.ndims();
787 std::vector<int> shape(nd);
788 std::vector<int> padding(nd);
789 std::vector<int> strides(nd);
790 std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
791 &CheckedNarrowing<int64, int>);
792 std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
793 &CheckedNarrowing<int64, int>);
794 std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
795 &CheckedNarrowing<int64, int>);
796 bool propagate_nans = pooling_descriptor.propagate_nans();
797 const auto cudnn_max_pooling_mode = RequireCudnnDeterminism()
798 ? CUDNN_POOLING_MAX_DETERMINISTIC
799 : CUDNN_POOLING_MAX;
800 CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
801 handle_.get(),
802 (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
803 ? cudnn_max_pooling_mode
804 : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
805 propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
806 shape.data(), padding.data(), strides.data()));
807 }
808
  cudnnPoolingDescriptor_t handle() const { return handle_.get(); }
810
811 private:
812 PoolingDescriptor handle_; // Owned.
813
814 SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
815 };
816
817 // Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
818 class CudnnNormalizeDescriptor {
819 public:
  explicit CudnnNormalizeDescriptor(
821 const dnn::NormalizeDescriptor& normalize_descriptor)
822 : handle_(CreateLrnDescriptor()) {
823 // The range specifies that the indices in the closed range
824 // [i - range, i + range] should be included in the normalization for index
825 // i. The lrnN value is the total number of elements in the range, so
826 // lrnN = 2*range + 1.
827 unsigned lrnN = 2 * normalize_descriptor.range() + 1;
828
829 // Note that SE defines the normalization operation as
830 //
831 // U_i = V_i / ((bias + alpha * (sum_j V_j^2)) ^ beta)
832 //
833 // but cuDNN defines it as
834 //
835 // U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
836 //
837 // i.e. there is a factor of n difference between the meaning of the alphas
838 // in the two contexts. The cuDNN alpha is n times the SE alpha.
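    // For example (illustrative numbers): with range = 2 we get lrnN = 5, so
    // an SE alpha of 0.001 corresponds to a cuDNN alpha of 5 * 0.001 = 0.005.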
839 double lrnAlpha = lrnN * normalize_descriptor.alpha();
840
841 double lrnBeta = normalize_descriptor.beta();
842 double lrnK = normalize_descriptor.bias();
843 CHECK_CUDNN_OK(
844 cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK));
845 }
846
  cudnnLRNDescriptor_t handle() const { return handle_.get(); }
848
849 private:
850 LrnDescriptor handle_; // Owned.
851
852 SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
853 };
854
// Turns an ActivationDescriptor structure into a cudnn activation
856 // descriptor handle within a scope.
857 class CudnnActivationDescriptor {
858 public:
  CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
860 cudnnNanPropagation_t nan_propagation,
861 double value_max)
862 : handle_(CreateActivationDescriptor()) {
863 double relu_ceiling = 0.0;
864 cudnnActivationMode_t mode;
865 switch (activation_mode) {
866 case dnn::ActivationMode::kNone:
867 mode = CUDNN_ACTIVATION_IDENTITY;
868 break;
869 case dnn::ActivationMode::kRelu6:
870 relu_ceiling = 6.0;
871 mode = CUDNN_ACTIVATION_CLIPPED_RELU;
872 break;
873 case dnn::ActivationMode::kReluX:
874 relu_ceiling = value_max;
875 mode = CUDNN_ACTIVATION_CLIPPED_RELU;
876 break;
877 case dnn::ActivationMode::kRelu:
878 mode = CUDNN_ACTIVATION_RELU;
879 break;
880 case dnn::ActivationMode::kSigmoid:
881 mode = CUDNN_ACTIVATION_SIGMOID;
882 break;
883 case dnn::ActivationMode::kTanh:
884 mode = CUDNN_ACTIVATION_TANH;
885 break;
886 default:
887 LOG(FATAL) << "unrecognized activation mode: "
888 << static_cast<int>(activation_mode);
889 }
890
891 CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode,
892 nan_propagation, relu_ceiling));
893 }
894
  cudnnActivationDescriptor_t handle() const { return handle_.get(); }
896
897 private:
898 ActivationDescriptor handle_; // Owned.
899
900 SE_DISALLOW_COPY_AND_ASSIGN(CudnnActivationDescriptor);
901 };
902
cudnnDataType_t ToCudnnDataType(
904 dnn::DataType data_type,
905 dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
906 switch (data_type) {
907 case dnn::DataType::kFloat:
908 return CUDNN_DATA_FLOAT;
909 case dnn::DataType::kDouble:
910 return CUDNN_DATA_DOUBLE;
911 case dnn::DataType::kHalf:
912 return CUDNN_DATA_HALF;
913 case dnn::DataType::kInt8:
914 return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4
915 : CUDNN_DATA_INT8;
916 case dnn::DataType::kInt32:
917 return CUDNN_DATA_INT32;
918 default:
919 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
920 }
921 }
922
cudnnDataType_t ToCudnnDataType(dnn::DataType data_type,
924 dnn::FilterLayout filter_layout) {
925 if (data_type == dnn::DataType::kInt8 &&
926 filter_layout == dnn::FilterLayout::kOutputInputYX4) {
927 return CUDNN_DATA_INT8x4;
928 }
929 return ToCudnnDataType(data_type);
930 }
931
932 template <typename T>
cudnnDataType_t GetCudnnDataType(
934 dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
935 return ToCudnnDataType(dnn::ToDataType<T>::value, data_layout);
936 }
937
cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
939 switch (input_mode) {
940 case dnn::RnnInputMode::kRnnLinearSkip:
941 case dnn::RnnInputMode::kRnnSkipInput:
942 return static_cast<cudnnRNNInputMode_t>(input_mode);
943 default:
944 LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
945 }
946 }
947
cudnnDirectionMode_t ToCudnnRnnDirectionMode(
949 dnn::RnnDirectionMode direction_mode) {
950 switch (direction_mode) {
951 case dnn::RnnDirectionMode::kRnnUnidirectional:
952 case dnn::RnnDirectionMode::kRnnBidirectional:
953 return static_cast<cudnnDirectionMode_t>(direction_mode);
954 default:
955 LOG(FATAL) << "Invalid RNN direction mode: "
956 << static_cast<int>(direction_mode);
957 }
958 }
959
cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) {
961 switch (rnn_mode) {
962 case dnn::RnnMode::kRnnRelu:
963 case dnn::RnnMode::kRnnTanh:
964 case dnn::RnnMode::kRnnLstm:
965 case dnn::RnnMode::kRnnGru:
966 return static_cast<cudnnRNNMode_t>(rnn_mode);
967 default:
968 LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
969 }
970 }
971
int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
973 switch (data_type) {
974 case CUDNN_DATA_FLOAT:
975 return sizeof(float);
976 case CUDNN_DATA_DOUBLE:
977 return sizeof(double);
978 case CUDNN_DATA_HALF:
979 return sizeof(Eigen::half);
980 default:
981 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
982 }
983 }
984
985 class CudnnDropoutDescriptor {
  explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
987 : handle_(std::move(handle)) {}
988
989 public:
990 CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;
991
  static port::StatusOr<CudnnDropoutDescriptor> Create(
993 const CudnnHandle& cudnn, float dropout, uint64 seed,
994 ScratchAllocator* state_allocator) {
995 DropoutDescriptor handle = CreateDropoutDescriptor();
996
997 if (dropout == 0.0f) {
998 // Return 'empty' dropout descriptor.
999 return CudnnDropoutDescriptor(std::move(handle));
1000 }
1001
1002 DeviceMemory<uint8> state_memory;
1003 if (state_allocator) {
1004 size_t state_sizes_in_bytes = 0;
1005 RETURN_IF_CUDNN_ERROR(
1006 cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
1007 SE_ASSIGN_OR_RETURN(state_memory,
1008 state_allocator->AllocateBytes(state_sizes_in_bytes));
1009 }
1010 RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
1011 handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
1012 state_memory.size(), seed));
1013
1014 return CudnnDropoutDescriptor(std::move(handle));
1015 }
1016
  cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
1018
1019 private:
1020 DropoutDescriptor handle_; // Owned.
1021 SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
1022 };
1023
1024 class CudnnRnnParamsDescriptor {
1025 typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1026
  CudnnRnnParamsDescriptor(FilterDescriptor handle, int64 params_size_in_bytes,
1028 ParamsRegions weights, ParamsRegions biases)
1029 : handle_(std::move(handle)),
1030 params_size_in_bytes_(params_size_in_bytes),
1031 weights_(std::move(weights)),
1032 biases_(std::move(biases)) {}
1033
1034 public:
1035 CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default;
1036
1037 static port::StatusOr<CudnnRnnParamsDescriptor> Create(
1038 const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1039 cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1040 cudnnDirectionMode_t direction_mode, int num_layers);
1041
  cudnnFilterDescriptor_t handle() const { return handle_.get(); }
  int64 params_size_in_bytes() const { return params_size_in_bytes_; }
  ParamsRegions params_weights() const { return weights_; }
  ParamsRegions params_biases() const { return biases_; }
1046
1047 private:
1048 FilterDescriptor handle_;
1049 int64 params_size_in_bytes_;
1050 ParamsRegions weights_;
1051 ParamsRegions biases_;
1052 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
1053 };
1054
1055 } // namespace
1056
1057 class CudnnRnnDescriptor : public dnn::RnnDescriptor {
  CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
1059 PersistentRnnPlan rnn_plan, int num_layers,
1060 int hidden_size, int input_size, int cell_size,
1061 int batch_size, cudnnRNNInputMode_t input_mode,
1062 cudnnDirectionMode_t direction_mode,
1063 cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
1064 cudnnDataType_t compute_type,
1065 const dnn::AlgorithmConfig& algorithm_config,
1066 CudnnDropoutDescriptor dropout_desc,
1067 CudnnRnnParamsDescriptor params_desc)
1068 : rnn_desc_(std::move(rnn_desc)),
1069 rnn_plan_(std::move(rnn_plan)),
1070 num_layers_(num_layers),
1071 hidden_size_(hidden_size),
1072 input_size_(input_size),
1073 cell_size_(cell_size),
1074 batch_size_(batch_size),
1075 rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
1076 input_mode_(input_mode),
1077 direction_mode_(direction_mode),
1078 rnn_mode_(rnn_mode),
1079 data_type_(data_type),
1080 compute_type_(compute_type),
1081 algorithm_config_(algorithm_config),
1082 dropout_desc_(std::move(dropout_desc)),
1083 params_desc_(std::move(params_desc)) {}
1084
1085 public:
1086 CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default;
1087
  static port::StatusOr<CudnnRnnDescriptor> Create(
1089 const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
1090 int cell_size, int batch_size, cudnnRNNInputMode_t input_mode,
1091 cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
1092 cudnnDataType_t data_type, cudnnDataType_t compute_type,
1093 const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
1094 ScratchAllocator* state_allocator, bool use_padded_io) {
1095 SE_ASSIGN_OR_RETURN(
1096 CudnnDropoutDescriptor dropout_desc,
1097 CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
1098
1099 gpu::RnnDescriptor rnn_desc = CreateRnnDescriptor();
1100 cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
1101
1102 // TODO: allow the user to choose an algorithm.
1103 auto proj_size = hidden_size;
1104 hidden_size = std::max(hidden_size, cell_size);
1105
1106 // Require explicit algorithm config to enable tensor cores. Some configs
1107 // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against
1108 // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799).
1109 // We can only reasonably expect the user to handle the subsequent failure
1110 // in profile mode, which is run with algorithms returned from
1111 // GetRnnAlgorithms() (which are non-default and explicitly set whether to
1112 // use tensor ops). CuDNN 7.2.1 fixed this issue.
1113 // TODO(csigg): Minimal support cuDNN version is 7.3, clean up.
1114 bool allow_tensor_ops = data_type == CUDNN_DATA_HALF;
1115 if (data_type == CUDNN_DATA_FLOAT)
1116 allow_tensor_ops = tensorflow::tensor_float_32_execution_enabled();
1117 bool use_tensor_ops =
1118 algorithm_config.algorithm().has_value()
1119 ? algorithm_config.algorithm()->tensor_ops_enabled()
1120 : allow_tensor_ops;
1121 if (use_tensor_ops && !allow_tensor_ops) {
1122 return port::Status(port::error::INVALID_ARGUMENT,
1123 "Algo requests disallowed tensor op evaluation.");
1124 }
1125
1126 #if CUDNN_VERSION >= 8000
1127 cudnnMathType_t math_type =
1128 use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
1129 #else
1130 cudnnMathType_t math_type =
1131 use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
1132 #endif
1133
1134 #if CUDNN_VERSION >= 8000
1135 cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
1136 uint32_t aux_flags = 0;
1137 if (use_padded_io) aux_flags |= CUDNN_RNN_PADDED_IO_ENABLED;
1138 RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v8(
1139 /*rnnDesc=*/rnn_desc.get(), /*algo=*/rnn_algo, /*cellMode=*/rnn_mode,
1140 /*biasMode=*/bias_mode, /*dirMode=*/direction_mode,
1141 /*inputMode=*/input_mode,
1142 /*dataType=*/data_type, /*mathPrec=*/compute_type,
1143 /*mathType=*/math_type,
1144 /*inputSize=*/input_size,
1145 /*hiddenSize=*/hidden_size, /*projSize=*/proj_size,
1146 /*numLayers=*/num_layers,
1147 /*dropoutDesc=*/dropout_desc.handle(),
1148 /*auxFlags=*/aux_flags));
1149 #else
1150 RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
1151 cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1152 /*hiddenSize=*/hidden_size, /*numLayers=*/num_layers,
1153 /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode,
1154 /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo,
1155 /*dataType=*/compute_type));
1156 CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
1157
1158 if (proj_size < hidden_size) {
1159 RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers(
1160 cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1161 /*recProjSize=*/proj_size, /*outProjSize=*/0));
1162 }
1163
1164 // TODO: For now, we only use cudnnRNN**Ex API to process padded inputs.
1165 // But in the future if these APIs are used to process full length arrays,
1166 // we need to distinguish when to set it.
1167 if (use_padded_io) {
1168 RETURN_IF_CUDNN_ERROR(
1169 cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED));
1170 }
1171 #endif
1172
1173 port::StatusOr<PersistentRnnPlan> rnn_plan_wrapper;
1174 PersistentRnnPlan rnn_plan;
1175 if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
1176 CHECK_GE(batch_size, 0);
1177 rnn_plan_wrapper =
1178 CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type);
1179 if (!rnn_plan_wrapper.ok()) {
1180 return port::StatusOr<CudnnRnnDescriptor>(rnn_plan_wrapper.status());
1181 } else {
1182 rnn_plan = rnn_plan_wrapper.ConsumeValueOrDie();
1183 RETURN_IF_CUDNN_ERROR(
1184 cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get()));
1185 }
1186 }
1187
1188 // Create the params handle.
1189 SE_ASSIGN_OR_RETURN(auto params_desc,
1190 CudnnRnnParamsDescriptor::Create(
1191 cudnn, input_size, data_type, rnn_desc.get(),
1192 rnn_mode, direction_mode, num_layers));
1193
1194 return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
1195 num_layers, hidden_size, input_size, cell_size,
1196 batch_size, input_mode, direction_mode, rnn_mode,
1197 data_type, compute_type, algorithm_config,
1198 std::move(dropout_desc), std::move(params_desc));
1199 }
1200
  cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); }
  int num_layers() const { return num_layers_; }
  int hidden_size() const { return hidden_size_; }
  int input_size() const { return input_size_; }
  int cell_size() const { return cell_size_; }
  int batch_size() const { return batch_size_; }
  cudnnRNNInputMode_t input_mode() const { return input_mode_; }
  cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
  cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
  cudnnDataType_t data_type() const { return data_type_; }
  cudnnDataType_t compute_type() const { return compute_type_; }
  const dnn::AlgorithmConfig& algorithm_config() const {
1213 return algorithm_config_;
1214 }
  int64 ParamsSizeInBytes() const override {
1216 return params_desc_.params_size_in_bytes();
1217 }
  cudnnFilterDescriptor_t params_handle() const {
1219 return params_desc_.handle();
1220 }
  ParamsRegions ParamsWeightRegions() const override {
1222 return params_desc_.params_weights();
1223 }
  ParamsRegions ParamsBiasRegions() const override {
1225 return params_desc_.params_biases();
1226 }
1227
1228 private:
1229 gpu::RnnDescriptor rnn_desc_;
1230 PersistentRnnPlan rnn_plan_;
1231 int num_layers_;
1232 int hidden_size_;
1233 int input_size_;
1234 // cell_size_ is the size of cell state, which will be different from
1235 // hidden_size_ if the projection is used.
1236 int cell_size_;
1237 // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
1238 // algorithm.
1239 int batch_size_;
1240 cudnnRNNAlgo_t rnn_algo_;
1241 cudnnRNNInputMode_t input_mode_;
1242 cudnnDirectionMode_t direction_mode_;
1243 cudnnRNNMode_t rnn_mode_;
1244 cudnnDataType_t data_type_;
1245 cudnnDataType_t compute_type_;
1246 dnn::AlgorithmConfig algorithm_config_;
1247 CudnnDropoutDescriptor dropout_desc_;
1248 CudnnRnnParamsDescriptor params_desc_;
1249 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
1250 };
1251
1252 #if CUDNN_VERSION >= 7603
1253 class CudnnCtcLossDescriptor {
1254 public:
  explicit CudnnCtcLossDescriptor(cudnnDataType_t data_type)
1256 : handle_(CreateCtcLossDescriptor()) {
1257 CHECK_CUDNN_OK(cudnnSetCTCLossDescriptorEx(
1258 /*ctcLossDesc=*/handle_.get(),
1259 /*compType=*/data_type,
1260 /*normMode=*/CUDNN_LOSS_NORMALIZATION_SOFTMAX,
1261 /*gradMode=*/CUDNN_NOT_PROPAGATE_NAN));
1262 }
1263
  cudnnCTCLossDescriptor_t handle() const { return handle_.get(); }
1265
1266 private:
1267 CtcLossDescriptor handle_; // Owned
1268
1269 SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
1270 };
1271 #else
1272 // dummy class
1273 class CudnnCtcLossDescriptor {
1274 public:
  CudnnCtcLossDescriptor(cudnnDataType_t data_type) {}
1276 };
1277 #endif
1278
1279 namespace {
1280
// Checks whether the LSTM projection is used. If so, the additional weight
// matrix (the projection matrix) is appended to 'weights'. Otherwise, nothing
// is done.
port::Status CheckAndFetchProjectionWeights(
1285 const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, const int layer,
1286 const TensorDescriptor& input_desc, const FilterDescriptor& filter_desc,
1287 const FilterDescriptor& region_desc_handle,
1288 dnn::RnnDescriptor::ParamsRegions* weights) {
1289 int hidden_size_v;
1290 int num_layers_v;
1291 cudnnDropoutDescriptor_t dropout_desc;
1292 cudnnRNNInputMode_t input_mode;
1293 cudnnDirectionMode_t direction;
1294 cudnnRNNMode_t mode;
1295 cudnnRNNAlgo_t algo;
1296 cudnnDataType_t data_type;
1297 #if CUDNN_VERSION >= 8000
1298 RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor_v6(
1299 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1300 /*hiddenSize=*/&hidden_size_v,
1301 /*numLayers=*/&num_layers_v,
1302 /*dropoutDesc=*/&dropout_desc,
1303 /*inputMode=*/&input_mode,
1304 /*direction=*/&direction,
1305 /*mode=*/&mode,
1306 /*algo=*/&algo,
1307 /*mathPrec=*/&data_type));
1308 #else
1309 RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
1310 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1311 /*hiddenSize=*/&hidden_size_v,
1312 /*numLayers=*/&num_layers_v,
1313 /*dropoutDesc=*/&dropout_desc,
1314 /*inputMode=*/&input_mode,
1315 /*direction=*/&direction,
1316 /*mode=*/&mode,
1317 /*algo=*/&algo,
1318 /*mathPrec=*/&data_type));
1319 #endif
1320 int rec_proj_size_v;
1321 int out_proj_size_v;
1322 RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
1323 /*handle=*/cudnn.handle(),
1324 /*rnnDesc=*/rnn_desc,
1325 /*recProjSize*/ &rec_proj_size_v,
1326 /*outProjSize*/ &out_proj_size_v));
1327 if (rec_proj_size_v != hidden_size_v) {
1328 void* offset = nullptr;
1329 int region_id = 8;
1330 RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams(
1331 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1332 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1333 /*wDesc=*/filter_desc.get(),
1334 /*w=*/nullptr, /*linLayerID=*/region_id,
1335 /*linLayerMatDesc=*/region_desc_handle.get(),
1336 /*linLayerMat or linLayerBias=*/&offset));
1337 int dims[] = {1, 1, 1};
1338 cudnnDataType_t data_type;
1339 cudnnTensorFormat_t tensor_format;
1340 int n_dims;
1341 RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1342 /*filterDesc=*/region_desc_handle.get(),
1343 /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1344 /*dataType=*/&data_type, /*format=*/&tensor_format,
1345 /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1346 int64 size =
1347 dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1348 dnn::RnnDescriptor::ParamsRegion region = {reinterpret_cast<int64>(offset),
1349 size};
1350 weights->push_back(region);
1351 }
1352 return port::Status::OK();
1353 }
1354
port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
1356 const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1357 cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1358 cudnnDirectionMode_t direction_mode, int num_layers) {
1359 // Query the params size.
1360 TensorDescriptor input_desc = CreateTensorDescriptor();
1361 int tensor_dims[] = {1, input_size, 1};
1362 int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1};
1363 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1364 /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type,
1365 /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]),
1366 /*dimA=*/tensor_dims,
1367 /*strideA=*/strides));
1368
1369 size_t params_size = 0;
1370 RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1371 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1372 /*xDesc=*/input_desc.get(), /*sizeInBytes=*/¶ms_size,
1373 /*dataType=*/data_type));
1374 int64 params_size_in_bytes = static_cast<int64>(params_size);
1375
1376 FilterDescriptor filter_desc = CreateFilterDescriptor();
1377 int64 filter_dim0 = params_size_in_bytes / CudnnDataTypeToByteSize(data_type);
1378 int filter_dims[] = {static_cast<int>(filter_dim0), 1, 1};
1379 RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
1380 /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type,
1381 /*format=*/CUDNN_TENSOR_NCHW,
1382 /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]),
1383 /*filterDimA=*/filter_dims));
1384
1385 // Record the offsets and sizes of the weight and bias regions within the
1385 // params buffer.
1386 int region_count_per_layer = [&] {
1387 switch (rnn_mode) {
1388 case CUDNN_RNN_RELU:
1389 case CUDNN_RNN_TANH:
1390 return 2;
1391 case CUDNN_LSTM:
1392 return 8;
1393 case CUDNN_GRU:
1394 return 6;
1395 default:
1396 LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1397 return 0;
1398 }
1399 }();
1400
1401 FilterDescriptor region_desc_handle = CreateFilterDescriptor();
1402 const int layer_count =
1403 direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers;
1404
1405 ParamsRegions weights;
1406 ParamsRegions biases;
1407
1408 for (int layer = 0; layer < layer_count; layer++) {
1409 for (int region = 0; region < region_count_per_layer; region++) {
1410 for (int type = 0; type < 2; type++) {
1411 void* offset = nullptr;
1412 RETURN_IF_CUDNN_ERROR(
1413 type == 0 ? cudnnGetRNNLinLayerMatrixParams(
1414 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1415 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1416 /*wDesc=*/filter_desc.get(),
1417 /*w=*/nullptr, /*linLayerID=*/region,
1418 /*linLayerMatDesc=*/region_desc_handle.get(),
1419 /*linLayerMat or linLayerBias=*/&offset)
1420 : cudnnGetRNNLinLayerBiasParams(
1421 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1422 /*layer=*/layer, /*xDesc=*/input_desc.get(),
1423 /*wDesc=*/filter_desc.get(),
1424 /*w=*/nullptr, /*linLayerID=*/region,
1425 /*linLayerMatDesc=*/region_desc_handle.get(),
1426 /*linLayerMat or linLayerBias=*/&offset));
1427 int dims[] = {1, 1, 1};
1428 cudnnDataType_t data_type;
1429 cudnnTensorFormat_t tensor_format;
1430 int n_dims;
1431 RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1432 /*filterDesc=*/region_desc_handle.get(),
1433 /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1434 /*dataType=*/&data_type, /*format=*/&tensor_format,
1435 /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1436 int64 size =
1437 dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1438 dnn::RnnDescriptor::ParamsRegion region = {
1439 reinterpret_cast<int64>(offset), size};
1440 (type == 0 ? weights : biases).push_back(region);
1441 }
1442 }
1443 TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights(
1444 cudnn, rnn_desc, layer, input_desc, filter_desc, region_desc_handle,
1445 &weights));
1446 }
1447
1448 return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
1449 weights, biases);
1450 }
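// Sizing note for the function above (derived from the loops): for a
// bidirectional LSTM without projection,
//   layer_count            = 2 * num_layers
//   region_count_per_layer = 8
// so 'weights' and 'biases' each end up with layer_count * 8 regions; with an
// LSTM projection, CheckAndFetchProjectionWeights appends one extra region to
// 'weights' per (pseudo-)layer.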
1451
1452 } // namespace
1453
1454 class CudnnRnnSequenceTensorDescriptor
1455 : public dnn::RnnSequenceTensorDescriptor {
1456 CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length,
1457 int batch_size, int data_size,
1458 cudnnDataType_t data_type,
1459 RNNDataDescriptor data_handle,
1460 TensorDescriptor handle)
1461 : max_seq_length_(max_seq_length),
1462 batch_size_(batch_size),
1463 data_size_(data_size),
1464 data_type_(data_type),
1465 handle_(std::move(handle)),
1466 rnn_data_handle_(std::move(data_handle)),
1467 handles_(max_seq_length, handle_.get()) {
1468 }
1469
1470 public:
1471 CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) =
1472 default;
1473
1474 static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1475 GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1476 cudnnDataType_t data_type) {
1477 if (max_seq_length <= 0) {
1478 return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
1479 }
1480 int dims[] = {batch_size, data_size, 1};
1481 int strides[] = {dims[1] * dims[2], dims[2], 1};
1482 TensorDescriptor tensor_desc = CreateTensorDescriptor();
1483 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1484 /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1485 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1486 /*strideA=*/strides));
1487 return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size,
1488 data_size, data_type,
1489 nullptr,
1490 std::move(tensor_desc));
1491 }
1492
1493 static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1494 GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1495 const absl::Span<const int>& seq_lengths, bool time_major,
1496 cudnnDataType_t data_type) {
1497 if (max_seq_length <= 0) {
1498 return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
1499 }
1500 int dims[] = {batch_size, data_size, 1};
1501 int strides[] = {dims[1] * dims[2], dims[2], 1};
1502 TensorDescriptor tensor_desc = CreateTensorDescriptor();
1503 RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1504 /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1505 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1506 /*strideA=*/strides));
1507 const int* seq_lengths_array = seq_lengths.data();
1508 RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
1509 float padding_fill = 0.0f;
1510 cudnnRNNDataLayout_t layout;
1511 if (time_major) {
1512 layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
1513 } else {
1514 layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
1515 }
1516 RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
1517 /*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
1518 /*layout=*/layout,
1519 /*maxSeqLength=*/max_seq_length,
1520 /*batchSize=*/batch_size, /*vectorSize=*/data_size,
1521 /*seqLengthArray=*/seq_lengths_array,
1522 /*paddingFill*/ (void*)&padding_fill));
1523 return CudnnRnnSequenceTensorDescriptor(
1524 parent, max_seq_length, batch_size, data_size, data_type,
1525 std::move(data_desc), std::move(tensor_desc));
1526 }
1527
1528 const cudnnTensorDescriptor_t* handles() const { return handles_.data(); }
1529 const cudnnRNNDataDescriptor_t data_handle() const {
1530 return rnn_data_handle_.get();
1531 }
1532
1533 int max_seq_length() const { return max_seq_length_; }
1534 int batch_size() const { return batch_size_; }
1535 int data_size() const { return data_size_; }
1536 bool is_var_seq_lengths() const {
1537 return rnn_data_handle_ != nullptr;
1538 }
1539
1540 private:
1541 int max_seq_length_;
1542 int batch_size_;
1543 int data_size_;
1544 cudnnDataType_t data_type_;
1545 TensorDescriptor handle_;
1546 RNNDataDescriptor rnn_data_handle_;
1547 std::vector<cudnnTensorDescriptor_t> handles_; // Copies of handle_.
1548 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
1549 };
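// Illustrative usage sketch (not compiled; guarded by #if 0). It shows how a
// variable-length, batch-major sequence descriptor would be built; 'parent'
// and the size values are placeholders rather than symbols defined here.
#if 0
std::vector<int> seq_lengths = {5, 3, 8, 8};  // one entry per batch element
SE_ASSIGN_OR_RETURN(
    CudnnRnnSequenceTensorDescriptor seq_desc,
    CudnnRnnSequenceTensorDescriptor::Create(
        parent, /*max_seq_length=*/8, /*batch_size=*/4, /*data_size=*/128,
        seq_lengths, /*time_major=*/false, CUDNN_DATA_FLOAT));
// seq_desc.is_var_seq_lengths() is now true, which routes the RNN calls below
// through the cudnn*Ex entry points (e.g. cudnnRNNForwardTrainingEx).
#endif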
1550
1551 class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor {
1552 public:
1553 CudnnRnnStateTensorDescriptor(GpuExecutor* parent, int num_layers,
1554 int batch_size, int data_size,
1555 cudnnDataType_t data_type)
1556 : handle_(CreateTensorDescriptor()),
1557 num_layers_(num_layers),
1558 batch_size_(batch_size),
1559 data_size_(data_size),
1560 data_type_(data_type) {
1561 int dims[] = {num_layers, batch_size, data_size};
1562 int strides[] = {dims[1] * dims[2], dims[2], 1};
1563 CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(
1564 /*tensorDesc=*/handle_.get(), /*dataType=*/data_type,
1565 /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1566 /*strideA=*/strides));
1567 }
1568
1569 cudnnTensorDescriptor_t handle() const { return handle_.get(); }
1570
1571 int num_layers() const { return num_layers_; }
1572 int batch_size() const { return batch_size_; }
1573 int data_size() const { return data_size_; }
1574
1575 private:
1576 TensorDescriptor handle_;
1577 int num_layers_;
1578 int batch_size_;
1579 int data_size_;
1580 cudnnDataType_t data_type_;
1581 SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor);
1582 };
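// Illustrative sketch (not compiled): the hidden-state descriptor for a
// two-layer bidirectional RNN. The first dimension must be num_layers *
// dir_count, matching the check in ExtractAndCheckRnnForward below; the other
// values are placeholders.
#if 0
CudnnRnnStateTensorDescriptor h_desc(parent, /*num_layers=*/2 * 2,
                                     /*batch_size=*/4, /*data_size=*/256,
                                     CUDNN_DATA_FLOAT);
#endif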
1583
1584 namespace {
1585
1586 struct RnnModelDims {
1587 int num_layers = 0;
1588 int batch_size = 0;
1589 int max_seq_length = 0;
1590 int hidden_size = 0;
1591 int input_size = 0;
1592 int cell_size = 0;
1593 int dir_count = 0;
1594 };
1595
1596 template <class T>
1597 port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
1598 const CudnnRnnDescriptor& rnn_desc,
1599 const CudnnRnnSequenceTensorDescriptor& input_desc,
1600 const DeviceMemory<T>& input_data,
1601 const CudnnRnnStateTensorDescriptor& input_h_desc,
1602 const DeviceMemory<T>& input_h_data,
1603 const CudnnRnnStateTensorDescriptor& input_c_desc,
1604 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1605 const CudnnRnnSequenceTensorDescriptor& output_desc,
1606 const DeviceMemory<T>& output_data,
1607 const CudnnRnnStateTensorDescriptor& output_h_desc,
1608 const DeviceMemory<T>& output_h_data,
1609 const CudnnRnnStateTensorDescriptor& output_c_desc,
1610 const DeviceMemory<T>& output_c_data) {
1611 // extract model parameters
1612 RnnModelDims model_dims;
1613 model_dims.num_layers = rnn_desc.num_layers();
1614 model_dims.batch_size = input_desc.batch_size();
1615 model_dims.max_seq_length = input_desc.max_seq_length();
1616 model_dims.hidden_size = rnn_desc.hidden_size();
1617 model_dims.input_size = input_desc.data_size();
1618 model_dims.cell_size = rnn_desc.cell_size();
1619 model_dims.dir_count =
1620 (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
1621
1622 // check parameters
1623 if (!(input_h_desc.num_layers() ==
1624 model_dims.num_layers * model_dims.dir_count &&
1625 input_h_desc.batch_size() == model_dims.batch_size &&
1626 input_h_desc.data_size() == model_dims.hidden_size)) {
1627 return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
1628 }
1629 // The LSTM projection will be used if input_h_desc.data_size() <
1630 // input_c_desc.data_size()
1631 if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
1632 input_h_desc.batch_size() == input_c_desc.batch_size() &&
1633 input_h_desc.data_size() <= input_c_desc.data_size())) {
1634 return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
1635 }
1636 if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
1637 output_desc.batch_size() == model_dims.batch_size &&
1638 output_desc.data_size() ==
1639 model_dims.hidden_size * model_dims.dir_count)) {
1640 return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape");
1641 }
1642 if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
1643 input_h_desc.batch_size() == output_h_desc.batch_size() &&
1644 input_h_desc.data_size() == output_h_desc.data_size())) {
1645 return port::Status(port::error::INVALID_ARGUMENT,
1646 "Invalid output_h shape");
1647 }
1648 if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
1649 input_h_desc.batch_size() == output_c_desc.batch_size() &&
1650 input_h_desc.data_size() <= output_c_desc.data_size())) {
1651 return port::Status(port::error::INVALID_ARGUMENT,
1652 "Invalid output_c shape");
1653 }
1654
1655 return model_dims;
1656 }
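// Shape contract checked above, with L = num_layers, D = dir_count,
// B = batch_size, T = max_seq_length, H = hidden_size:
//   input_h : [L * D, B, H]
//   input_c : [L * D, B, >= H]  (strictly greater only with LSTM projection)
//   output  : [T, B, H * D]
//   output_h: same shape as input_h
//   output_c: [L * D, B, >= H]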
1657
1658 port::Status CheckRNNParameterSize(
1659 const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
1660 const CudnnRnnSequenceTensorDescriptor& input_desc) {
1661 size_t params_size_in_bytes = 0;
1662 RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1663 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1664 /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/¶ms_size_in_bytes,
1665 /*dataType=*/rnn_desc.data_type()));
1666 if (static_cast<int64>(params_size_in_bytes) !=
1667 rnn_desc.ParamsSizeInBytes()) {
1668 return port::Status(port::error::INVALID_ARGUMENT,
1669 "Mismatching RNN parameter size");
1670 }
1671 return port::Status::OK();
1672 }
1673
1674 port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
1675 Stream* stream, const CudnnHandle& cudnn,
1676 const CudnnRnnDescriptor& rnn_desc,
1677 const CudnnRnnSequenceTensorDescriptor& input_desc,
1678 ScratchAllocator* workspace_allocator) {
1679 // Query the workspace size.
1680 size_t workspace_size_in_bytes = 0;
1681 RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize(
1682 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1683 /*seqLength=*/input_desc.max_seq_length(), /*xDesc=*/input_desc.handles(),
1684 /*sizeInBytes=*/&workspace_size_in_bytes));
1685 // Allocate the workspace.
1686 if (workspace_size_in_bytes == 0) {
1687 return DeviceMemory<uint8>();
1688 }
1689 return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1690 }
1691
1692 #if CUDNN_VERSION >= 7402
1693 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormForwardWorkspace(
1694 Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1695 const cudnnBatchNormOps_t& bn_ops,
1696 const cudnnActivationDescriptor_t& activation_desc,
1697 const CudnnTensorDescriptor& x_descriptor,
1698 const CudnnTensorDescriptor& scale_offset_descriptor,
1699 ScratchAllocator* workspace_allocator) {
1700 // Query the workspace size.
1701 size_t workspace_size_in_bytes = 0;
1702 RETURN_IF_CUDNN_ERROR(
1703 cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
1704 /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1705 /*xDesc=*/x_descriptor.handle(), /*zDesc=*/x_descriptor.handle(),
1706 /*yDesc=*/x_descriptor.handle(),
1707 /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
1708 /*activationDesc=*/activation_desc,
1709 /*sizeInBytes=*/&workspace_size_in_bytes));
1710 // Allocate the workspace.
1711 if (workspace_size_in_bytes == 0) {
1712 return DeviceMemory<uint8>();
1713 }
1714 return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1715 }
1716
1717 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
1718 Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1719 const cudnnBatchNormOps_t& bn_ops,
1720 const CudnnTensorDescriptor& x_descriptor,
1721 const CudnnTensorDescriptor& scale_offset_descriptor,
1722 ScratchAllocator* workspace_allocator) {
1723 // Query the workspace size.
1724 size_t workspace_size_in_bytes = 0;
1725 RETURN_IF_CUDNN_ERROR(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
1726 /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1727 /*xDesc=*/x_descriptor.handle(),
1728 /*yDesc=*/x_descriptor.handle(),
1729 /*dyDesc=*/x_descriptor.handle(),
1730 /*dzDesc=*/nullptr,
1731 /*dxDesc=*/x_descriptor.handle(),
1732 /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
1733 /*activationDesc=*/nullptr, /*sizeInBytes=*/&workspace_size_in_bytes));
1734 // Allocate the workspace.
1735 if (workspace_size_in_bytes == 0) {
1736 return DeviceMemory<uint8>();
1737 }
1738 return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1739 }
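// Illustrative call sketch (not compiled): querying the forward-training
// workspace for a fused batchnorm + activation. All descriptor and allocator
// names here are placeholders for objects the caller already owns.
#if 0
SE_ASSIGN_OR_RETURN(
    DeviceMemory<uint8> workspace,
    CreateBatchNormForwardWorkspace(
        stream, cudnn, CUDNN_BATCHNORM_SPATIAL_PERSISTENT,
        CUDNN_BATCHNORM_OPS_BN_ACTIVATION, activation_desc, x_descriptor,
        scale_offset_descriptor, workspace_allocator));
#endif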
1740
1741 #endif
1742
1743 } // namespace
1744
1745 template <class T>
1746 port::Status CudnnSupport::DoRnnForwardImpl(
1747 Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1748 const CudnnRnnSequenceTensorDescriptor& input_desc,
1749 const DeviceMemory<T>& input_data,
1750 const CudnnRnnStateTensorDescriptor& input_h_desc,
1751 const DeviceMemory<T>& input_h_data,
1752 const CudnnRnnStateTensorDescriptor& input_c_desc,
1753 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1754 const CudnnRnnSequenceTensorDescriptor& output_desc,
1755 DeviceMemory<T>* output_data,
1756 const CudnnRnnStateTensorDescriptor& output_h_desc,
1757 DeviceMemory<T>* output_h_data,
1758 const CudnnRnnStateTensorDescriptor& output_c_desc,
1759 DeviceMemory<T>* output_c_data, bool is_training,
1760 ScratchAllocator* reserve_space_allocator,
1761 ScratchAllocator* workspace_allocator,
1762 dnn::ProfileResult* output_profile_result) {
1763 SE_ASSIGN_OR_RETURN(
1764 RnnModelDims model_dims,
1765 ExtractAndCheckRnnForward(
1766 rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
1767 input_c_desc, input_c_data, params, output_desc, *output_data,
1768 output_h_desc, *output_h_data, output_c_desc, *output_c_data));
1769
1770 auto cudnn = cudnn_->GetHandle(parent_, stream);
1771
1772 SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1773 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1774 CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1775 workspace_allocator));
1776
1777 // Query the reserve space size and allocate it. The reserve space is only
1778 // needed for training, where it carries state reused by the backward pass.
1779 DeviceMemory<uint8> reserve_space;
1780 if (is_training) {
1781 size_t reserve_space_size_in_bytes = 0;
1782 RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize(
1783 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1784 /*seqLength=*/model_dims.max_seq_length, /*xDesc=*/input_desc.handles(),
1785 /*sizeInBytes=*/&reserve_space_size_in_bytes));
1786
1787 if (reserve_space_size_in_bytes > 0) {
1788 SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
1789 reserve_space_size_in_bytes));
1790 }
1791 }
1792
1793 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1794 const bool is_profiling = output_profile_result != nullptr;
1795 if (is_profiling) {
1796 timer.reset(new GpuTimer(parent_));
1797 // The start and stop of the timer should be as close to the cuDNN call as
1798 // possible. Other threads may still enqueue work onto this stream, so the
1799 // measurement may include that work and may need to be repeated.
1800 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1801 return port::Status(port::error::INTERNAL, "Failed to start timer");
1802 }
1803 }
1804
1805 if (!is_training) {
1806 if (input_desc.is_var_seq_lengths()) {
1807 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInferenceEx(
1808 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1809 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1810 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1811 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1812 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1813 /*yDesc=*/output_desc.data_handle(),
1814 /*y=*/output_data->opaque(),
1815 /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
1816 /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
1817 nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
1818 nullptr,
1819 /*workspace=*/workspace.opaque(),
1820 /*workSpaceSizeInBytes=*/workspace.size()));
1821 } else {
1822 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference(
1823 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1824 /*seqLength=*/model_dims.max_seq_length,
1825 /*xDesc=*/input_desc.handles(),
1826 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1827 /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
1828 /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
1829 /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
1830 /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
1831 /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
1832 /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
1833 /*workSpaceSizeInBytes=*/workspace.size()));
1834 }
1835 } else {
1836 if (input_desc.is_var_seq_lengths()) {
1837 // cudnnSetRNNPaddingMode(rnn_desc.handle(), CUDNN_RNN_PADDED_IO_ENABLED);
1838 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx(
1839 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1840 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1841 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1842 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1843 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1844 /*yDesc=*/output_desc.data_handle(),
1845 /*y=*/output_data->opaque(),
1846 /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
1847 /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
1848 nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
1849 nullptr,
1850 /*workspace=*/workspace.opaque(),
1851 /*workSpaceSizeInBytes=*/workspace.size(),
1852 /*reserveSpace=*/reserve_space.opaque(),
1853 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
1854 } else {
1855 RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining(
1856 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1857 /*seqLength=*/model_dims.max_seq_length,
1858 /*xDesc=*/input_desc.handles(),
1859 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1860 /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
1861 /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
1862 /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
1863 /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
1864 /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
1865 /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
1866 /*workSpaceSizeInBytes=*/workspace.size(),
1867 /*reserveSpace=*/reserve_space.opaque(),
1868 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
1869 }
1870 }
1871
1872 if (is_profiling) {
1873 if (!timer->Stop(AsGpuStream(stream))) {
1874 return port::Status(port::error::INTERNAL, "Failed to stop timer");
1875 }
1876 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
1877 output_profile_result->set_algorithm(algo_desc);
1878 output_profile_result->set_elapsed_time_in_ms(
1879 timer->GetElapsedMilliseconds());
1880 }
1881
1882 return port::Status::OK();
1883 }
1884
1885 template <class T>
1886 port::Status CudnnSupport::DoRnnBackwardImpl(
1887 Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1888 const CudnnRnnSequenceTensorDescriptor& input_desc,
1889 const DeviceMemory<T>& input_data,
1890 const CudnnRnnStateTensorDescriptor& input_h_desc,
1891 const DeviceMemory<T>& input_h_data,
1892 const CudnnRnnStateTensorDescriptor& input_c_desc,
1893 const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1894 const CudnnRnnSequenceTensorDescriptor& output_desc,
1895 const DeviceMemory<T>& output_data,
1896 const CudnnRnnStateTensorDescriptor& output_h_desc,
1897 const DeviceMemory<T>& output_h_data,
1898 const CudnnRnnStateTensorDescriptor& output_c_desc,
1899 const DeviceMemory<T>& output_c_data,
1900 const DeviceMemory<T>& output_backprop_data,
1901 const DeviceMemory<T>& output_h_backprop_data,
1902 const DeviceMemory<T>& output_c_backprop_data,
1903 DeviceMemory<T>* input_backprop_data,
1904 DeviceMemory<T>* input_h_backprop_data,
1905 DeviceMemory<T>* input_c_backprop_data,
1906 DeviceMemory<T>* params_backprop_data,
1907 DeviceMemory<uint8>* reserve_space_data,
1908 ScratchAllocator* workspace_allocator,
1909 dnn::ProfileResult* output_profile_result) {
1910 SE_ASSIGN_OR_RETURN(
1911 RnnModelDims model_dims,
1912 ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
1913 input_h_data, input_c_desc, input_c_data,
1914 params, output_desc, output_data, output_h_desc,
1915 output_h_data, output_c_desc, output_c_data));
1916
1917 auto cudnn = cudnn_->GetHandle(parent_, stream);
1918
1919 SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1920 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1921 CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1922 workspace_allocator));
1923
1924 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1925 const bool is_profiling = output_profile_result != nullptr;
1926 if (is_profiling) {
1927 timer.reset(new GpuTimer(parent_));
1928 // The start and stop of the timer should be as close to the cuDNN call as
1929 // possible. Other threads may still enqueue work onto this stream, so the
1930 // measurement may include that work and may need to be repeated.
1931 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1932 return port::Status(port::error::INTERNAL, "Failed to start timer");
1933 }
1934 }
1935
1936 if (input_desc.is_var_seq_lengths()) {
1937 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardDataEx(
1938 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1939 /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
1940 /*dyDesc=*/output_desc.data_handle(),
1941 /*dy=*/output_backprop_data.opaque(), nullptr, nullptr,
1942 /*dhyDesc=*/output_h_desc.handle(),
1943 /*dhy=*/output_h_backprop_data.opaque(),
1944 /*dcyDesc=*/output_c_desc.handle(),
1945 /*dcy=*/output_c_backprop_data.opaque(),
1946 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1947 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1948 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1949 /*dxDesc=*/input_desc.data_handle(),
1950 /*dx=*/input_backprop_data->opaque(),
1951 /*dhxDesc=*/input_h_desc.handle(),
1952 /*dhx=*/input_h_backprop_data->opaque(),
1953 /*dcxDesc=*/input_c_desc.handle(),
1954 /*dcx=*/input_c_backprop_data->opaque(), nullptr, nullptr,
1955 /*workspace=*/workspace.opaque(),
1956 /*workSpaceSizeInBytes=*/workspace.size(),
1957 /*reserveSpace=*/reserve_space_data->opaque(),
1958 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1959 } else {
1960 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData(
1961 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1962 /*seqLength=*/model_dims.max_seq_length,
1963 /*yDesc=*/output_desc.handles(),
1964 /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
1965 /*dy=*/output_backprop_data.opaque(),
1966 /*dhyDesc=*/output_h_desc.handle(),
1967 /*dhy=*/output_h_backprop_data.opaque(),
1968 /*dcyDesc=*/output_c_desc.handle(),
1969 /*dcy=*/output_c_backprop_data.opaque(),
1970 /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1971 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1972 /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1973 /*dxDesc=*/input_desc.handles(), /*dx=*/input_backprop_data->opaque(),
1974 /*dhxDesc=*/input_h_desc.handle(),
1975 /*dhx=*/input_h_backprop_data->opaque(),
1976 /*dcxDesc=*/input_c_desc.handle(),
1977 /*dcx=*/input_c_backprop_data->opaque(),
1978 /*workspace=*/workspace.opaque(),
1979 /*workSpaceSizeInBytes=*/workspace.size(),
1980 /*reserveSpace=*/reserve_space_data->opaque(),
1981 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1982 }
1983
1984 if (params_backprop_data != nullptr) {
1985 // Clear the dw to zeros.
1986 stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
1987 if (input_desc.is_var_seq_lengths()) {
1988 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx(
1989 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1990 /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1991 /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1992 /*yDesc=*/output_desc.data_handle(),
1993 /*y=*/output_data.opaque(),
1994 /*workspace=*/workspace.opaque(),
1995 /*workSpaceSizeInBytes=*/workspace.size(),
1996 /*dwDesc=*/rnn_desc.params_handle(),
1997 /*dw=*/params_backprop_data->opaque(),
1998 /*reserveSpace=*/reserve_space_data->opaque(),
1999 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2000 } else {
2001 // make the backward weight call
2002 RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights(
2003 /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2004 /*seqLength=*/model_dims.max_seq_length,
2005 /*xDesc=*/input_desc.handles(),
2006 /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2007 /*hx=*/input_h_data.opaque(), /*yDesc=*/output_desc.handles(),
2008 /*y=*/output_data.opaque(), /*workspace=*/workspace.opaque(),
2009 /*workSpaceSizeInBytes=*/workspace.size(),
2010 /*dwDesc=*/rnn_desc.params_handle(),
2011 /*dw=*/params_backprop_data->opaque(),
2012 /*reserveSpace=*/reserve_space_data->opaque(),
2013 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2014 }
2015 }
2016
2017 if (is_profiling) {
2018 if (!timer->Stop(AsGpuStream(stream))) {
2019 return port::Status(port::error::INTERNAL, "Failed to stop timer");
2020 }
2021 auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2022 output_profile_result->set_algorithm(algo_desc);
2023 output_profile_result->set_elapsed_time_in_ms(
2024 timer->GetElapsedMilliseconds());
2025 }
2026
2027 return port::Status::OK();
2028 }
2029
2030 port::Status CudnnSupport::DoCtcLossImpl(
2031 Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
2032 const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2033 absl::Span<const int> labels_lengths_data,
2034 absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2035 const CudnnRnnStateTensorDescriptor& grads_desc,
2036 DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
2037 DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
2038 auto cudnn = cudnn_->GetHandle(parent_, stream);
2039
2040 int kNumTimestamps = probs_desc.num_layers();
2041 int kBatchSize = probs_desc.batch_size();
2042 int kNumLabels = probs_desc.data_size();
2043 int total_size = kNumLabels * kNumTimestamps * kBatchSize;
2044 (void)total_size;
2045
2046 #if CUDNN_VERSION >= 7603
2047 cudnnCTCLossAlgo_t ctc_loss_algo =
2048 static_cast<cudnnCTCLossAlgo_t>(ctc_loss_algo_id);
2049 RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
2050 /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
2051 /*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
2052 /*labelLengths=*/labels_lengths_data.data(),
2053 /*inputLengths=*/input_lengths_data.data(),
2054 /*costs=*/costs_data.opaque(), /*gradientsDesc=*/grads_desc.handle(),
2055 /*gradients=*/grads_data.opaque(),
2056 /*algo=*/ctc_loss_algo,
2057 /*ctcLossDesc=*/ctc_loss_desc.handle(),
2058 /*workspace=*/scratch_memory.opaque(),
2059 /*workSpaceSizeInBytes=*/scratch_memory.size()));
2060 #else
2061 return port::Status(port::error::INVALID_ARGUMENT,
2062 "No supported cudnnCTCLoss when "
2063 "CUDNN_VERSION < 7.6.3");
2064 #endif
2065
2066 return port::Status::OK();
2067 }
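// Note on the descriptor reuse above: the CTC probabilities and gradients are
// described with CudnnRnnStateTensorDescriptor, so its three dimensions are
// reinterpreted as
//   num_layers() -> number of timesteps
//   batch_size() -> batch size
//   data_size()  -> label alphabet size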
2068
2069 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2070 CudnnSupport::createRnnDescriptor(
2071 int num_layers, int hidden_size, int input_size, int cell_size,
2072 int batch_size, dnn::RnnInputMode input_mode,
2073 dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
2074 dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
2075 float dropout, uint64 seed, ScratchAllocator* state_allocator,
2076 bool use_padded_io) {
2077 // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
2078 // not enqueueing anything into a stream, we pass in the null stream.
2079 auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
2080 SE_ASSIGN_OR_RETURN(
2081 CudnnRnnDescriptor rnn_desc,
2082 CudnnRnnDescriptor::Create(
2083 cudnn, num_layers, hidden_size, input_size, cell_size, batch_size,
2084 ToCudnnRnnInputMode(input_mode),
2085 ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
2086 ToCudnnDataType(data_type), GetRnnComputeType(data_type),
2087 algorithm_config, dropout, seed, state_allocator, use_padded_io));
2088 return std::unique_ptr<dnn::RnnDescriptor>(
2089 new CudnnRnnDescriptor(std::move(rnn_desc)));
2090 }
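// Illustrative usage sketch (not compiled): creating an LSTM descriptor
// through the generic dnn:: interface. The argument values are placeholders,
// and 'dnn_support' stands for the dnn-support instance obtained from the
// StreamExecutor; it is not a symbol defined in this file.
#if 0
auto rnn_desc_status = dnn_support->createRnnDescriptor(
    /*num_layers=*/1, /*hidden_size=*/256, /*input_size=*/128,
    /*cell_size=*/0, /*batch_size=*/8, dnn::RnnInputMode::kRnnLinearSkip,
    dnn::RnnDirectionMode::kRnnUnidirectional, dnn::RnnMode::kRnnLstm,
    dnn::DataType::kFloat, dnn::AlgorithmConfig(), /*dropout=*/0.0f,
    /*seed=*/0, /*state_allocator=*/nullptr, /*use_padded_io=*/false);
#endif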
2091
2092 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2093 CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
2094 int batch_size, int data_size,
2095 dnn::DataType data_type) {
2096 SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2097 CudnnRnnSequenceTensorDescriptor::Create(
2098 parent_, max_seq_length, batch_size, data_size,
2099 ToCudnnDataType(data_type)));
2100 return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2101 new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2102 }
2103
2104 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2105 CudnnSupport::createRnnSequenceTensorDescriptor(
2106 int max_seq_length, int batch_size, int data_size,
2107 const absl::Span<const int>& seq_lengths, bool time_major,
2108 dnn::DataType data_type) {
2109 SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2110 CudnnRnnSequenceTensorDescriptor::Create(
2111 parent_, max_seq_length, batch_size, data_size,
2112 seq_lengths, time_major, ToCudnnDataType(data_type)));
2113 return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2114 new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2115 }
2116
2117 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2118 CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
2119 int data_size,
2120 dnn::DataType data_type) {
2121 return std::unique_ptr<dnn::RnnStateTensorDescriptor>(
2122 new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
2123 data_size, ToCudnnDataType(data_type)));
2124 }
2125
2126 bool CudnnSupport::DoRnnForward(
2127 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2128 const dnn::RnnSequenceTensorDescriptor& input_desc,
2129 const DeviceMemory<Eigen::half>& input_data,
2130 const dnn::RnnStateTensorDescriptor& input_h_desc,
2131 const DeviceMemory<Eigen::half>& input_h_data,
2132 const dnn::RnnStateTensorDescriptor& input_c_desc,
2133 const DeviceMemory<Eigen::half>& input_c_data,
2134 const DeviceMemory<Eigen::half>& params,
2135 const dnn::RnnSequenceTensorDescriptor& output_desc,
2136 DeviceMemory<Eigen::half>* output_data,
2137 const dnn::RnnStateTensorDescriptor& output_h_desc,
2138 DeviceMemory<Eigen::half>* output_h_data,
2139 const dnn::RnnStateTensorDescriptor& output_c_desc,
2140 DeviceMemory<Eigen::half>* output_c_data, bool is_training,
2141 ScratchAllocator* reserve_space_allocator,
2142 ScratchAllocator* workspace_allocator,
2143 dnn::ProfileResult* output_profile_result) {
2144 const CudnnRnnDescriptor& cudnn_rnn_desc =
2145 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2146 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2147 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2148 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2149 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2150 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2151 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2152 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2153 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2154 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2155 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2156 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2157 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2158 return IsStatusOk(
2159 DoRnnForwardImpl<Eigen::half>(
2160 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2161 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2162 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2163 output_h_data, cudnn_output_c_desc, output_c_data, is_training,
2164 reserve_space_allocator, workspace_allocator, output_profile_result),
2165 /*report_error=*/!output_profile_result);
2166 }
2167
2168 bool CudnnSupport::DoRnnForward(
2169 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2170 const dnn::RnnSequenceTensorDescriptor& input_desc,
2171 const DeviceMemory<float>& input_data,
2172 const dnn::RnnStateTensorDescriptor& input_h_desc,
2173 const DeviceMemory<float>& input_h_data,
2174 const dnn::RnnStateTensorDescriptor& input_c_desc,
2175 const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2176 const dnn::RnnSequenceTensorDescriptor& output_desc,
2177 DeviceMemory<float>* output_data,
2178 const dnn::RnnStateTensorDescriptor& output_h_desc,
2179 DeviceMemory<float>* output_h_data,
2180 const dnn::RnnStateTensorDescriptor& output_c_desc,
2181 DeviceMemory<float>* output_c_data, bool is_training,
2182 ScratchAllocator* reserve_space_allocator,
2183 ScratchAllocator* workspace_allocator,
2184 dnn::ProfileResult* output_profile_result) {
2185 const CudnnRnnDescriptor& cudnn_rnn_desc =
2186 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2187 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2188 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2189 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2190 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2191 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2192 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2193 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2194 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2195 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2196 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2197 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2198 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2199 return IsStatusOk(
2200 DoRnnForwardImpl<float>(
2201 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2202 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2203 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2204 output_h_data, cudnn_output_c_desc, output_c_data, is_training,
2205 reserve_space_allocator, workspace_allocator, output_profile_result),
2206 /*report_error=*/!output_profile_result);
2207 }
2208
2209 bool CudnnSupport::DoRnnForward(
2210 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2211 const dnn::RnnSequenceTensorDescriptor& input_desc,
2212 const DeviceMemory<double>& input_data,
2213 const dnn::RnnStateTensorDescriptor& input_h_desc,
2214 const DeviceMemory<double>& input_h_data,
2215 const dnn::RnnStateTensorDescriptor& input_c_desc,
2216 const DeviceMemory<double>& input_c_data,
2217 const DeviceMemory<double>& params,
2218 const dnn::RnnSequenceTensorDescriptor& output_desc,
2219 DeviceMemory<double>* output_data,
2220 const dnn::RnnStateTensorDescriptor& output_h_desc,
2221 DeviceMemory<double>* output_h_data,
2222 const dnn::RnnStateTensorDescriptor& output_c_desc,
2223 DeviceMemory<double>* output_c_data, bool is_training,
2224 ScratchAllocator* reserve_space_allocator,
2225 ScratchAllocator* workspace_allocator,
2226 dnn::ProfileResult* output_profile_result) {
2227 const CudnnRnnDescriptor& cudnn_rnn_desc =
2228 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2229 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2230 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2231 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2232 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2233 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2234 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2235 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2236 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2237 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2238 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2239 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2240 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2241 return IsStatusOk(
2242 DoRnnForwardImpl<double>(
2243 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2244 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2245 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2246 output_h_data, cudnn_output_c_desc, output_c_data, is_training,
2247 reserve_space_allocator, workspace_allocator, output_profile_result),
2248 /*report_error=*/!output_profile_result);
2249 }
2250
2251 bool CudnnSupport::DoRnnBackward(
2252 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2253 const dnn::RnnSequenceTensorDescriptor& input_desc,
2254 const DeviceMemory<Eigen::half>& input_data,
2255 const dnn::RnnStateTensorDescriptor& input_h_desc,
2256 const DeviceMemory<Eigen::half>& input_h_data,
2257 const dnn::RnnStateTensorDescriptor& input_c_desc,
2258 const DeviceMemory<Eigen::half>& input_c_data,
2259 const DeviceMemory<Eigen::half>& params,
2260 const dnn::RnnSequenceTensorDescriptor& output_desc,
2261 const DeviceMemory<Eigen::half>& output_data,
2262 const dnn::RnnStateTensorDescriptor& output_h_desc,
2263 const DeviceMemory<Eigen::half>& output_h_data,
2264 const dnn::RnnStateTensorDescriptor& output_c_desc,
2265 const DeviceMemory<Eigen::half>& output_c_data,
2266 const DeviceMemory<Eigen::half>& output_backprop_data,
2267 const DeviceMemory<Eigen::half>& output_h_backprop_data,
2268 const DeviceMemory<Eigen::half>& output_c_backprop_data,
2269 DeviceMemory<Eigen::half>* input_backprop_data,
2270 DeviceMemory<Eigen::half>* input_h_backprop_data,
2271 DeviceMemory<Eigen::half>* input_c_backprop_data,
2272 DeviceMemory<Eigen::half>* params_backprop_data,
2273 DeviceMemory<uint8>* reserve_space_data,
2274 ScratchAllocator* workspace_allocator,
2275 dnn::ProfileResult* output_profile_result) {
2276 const CudnnRnnDescriptor& cudnn_rnn_desc =
2277 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2278 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2279 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2280 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2281 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2282 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2283 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2284 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2285 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2286 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2287 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2288 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2289 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2290 return IsStatusOk(
2291 DoRnnBackwardImpl<Eigen::half>(
2292 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2293 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2294 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2295 output_h_data, cudnn_output_c_desc, output_c_data,
2296 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2297 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2298 params_backprop_data, reserve_space_data, workspace_allocator,
2299 output_profile_result),
2300 /*report_error=*/!output_profile_result);
2301 }
2302
2303 bool CudnnSupport::DoRnnBackward(
2304 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2305 const dnn::RnnSequenceTensorDescriptor& input_desc,
2306 const DeviceMemory<float>& input_data,
2307 const dnn::RnnStateTensorDescriptor& input_h_desc,
2308 const DeviceMemory<float>& input_h_data,
2309 const dnn::RnnStateTensorDescriptor& input_c_desc,
2310 const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2311 const dnn::RnnSequenceTensorDescriptor& output_desc,
2312 const DeviceMemory<float>& output_data,
2313 const dnn::RnnStateTensorDescriptor& output_h_desc,
2314 const DeviceMemory<float>& output_h_data,
2315 const dnn::RnnStateTensorDescriptor& output_c_desc,
2316 const DeviceMemory<float>& output_c_data,
2317 const DeviceMemory<float>& output_backprop_data,
2318 const DeviceMemory<float>& output_h_backprop_data,
2319 const DeviceMemory<float>& output_c_backprop_data,
2320 DeviceMemory<float>* input_backprop_data,
2321 DeviceMemory<float>* input_h_backprop_data,
2322 DeviceMemory<float>* input_c_backprop_data,
2323 DeviceMemory<float>* params_backprop_data,
2324 DeviceMemory<uint8>* reserve_space_data,
2325 ScratchAllocator* workspace_allocator,
2326 dnn::ProfileResult* output_profile_result) {
2327 const CudnnRnnDescriptor& cudnn_rnn_desc =
2328 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2329 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2330 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2331 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2332 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2333 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2334 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2335 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2336 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2337 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2338 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2339 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2340 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2341 return IsStatusOk(
2342 DoRnnBackwardImpl<float>(
2343 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2344 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2345 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2346 output_h_data, cudnn_output_c_desc, output_c_data,
2347 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2348 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2349 params_backprop_data, reserve_space_data, workspace_allocator,
2350 output_profile_result),
2351 /*report_error=*/!output_profile_result);
2352 }
2353
2354 bool CudnnSupport::DoRnnBackward(
2355 Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2356 const dnn::RnnSequenceTensorDescriptor& input_desc,
2357 const DeviceMemory<double>& input_data,
2358 const dnn::RnnStateTensorDescriptor& input_h_desc,
2359 const DeviceMemory<double>& input_h_data,
2360 const dnn::RnnStateTensorDescriptor& input_c_desc,
2361 const DeviceMemory<double>& input_c_data,
2362 const DeviceMemory<double>& params,
2363 const dnn::RnnSequenceTensorDescriptor& output_desc,
2364 const DeviceMemory<double>& output_data,
2365 const dnn::RnnStateTensorDescriptor& output_h_desc,
2366 const DeviceMemory<double>& output_h_data,
2367 const dnn::RnnStateTensorDescriptor& output_c_desc,
2368 const DeviceMemory<double>& output_c_data,
2369 const DeviceMemory<double>& output_backprop_data,
2370 const DeviceMemory<double>& output_h_backprop_data,
2371 const DeviceMemory<double>& output_c_backprop_data,
2372 DeviceMemory<double>* input_backprop_data,
2373 DeviceMemory<double>* input_h_backprop_data,
2374 DeviceMemory<double>* input_c_backprop_data,
2375 DeviceMemory<double>* params_backprop_data,
2376 DeviceMemory<uint8>* reserve_space_data,
2377 ScratchAllocator* workspace_allocator,
2378 dnn::ProfileResult* output_profile_result) {
2379 const CudnnRnnDescriptor& cudnn_rnn_desc =
2380 static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2381 const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2382 static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2383 const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2384 static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2385 const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2386 static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2387 const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2388 static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2389 const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2390 static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2391 const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2392 static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2393 return IsStatusOk(
2394 DoRnnBackwardImpl<double>(
2395 stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2396 cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2397 params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2398 output_h_data, cudnn_output_c_desc, output_c_data,
2399 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2400 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2401 params_backprop_data, reserve_space_data, workspace_allocator,
2402 output_profile_result),
2403 /*report_error=*/!output_profile_result);
2404 }
2405
2406 namespace {
2407
2408 // TODO(csigg): Merge a lot of duplicate code below for forward, backward data,
2409 // and backward filter.
2410
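// Picks a forward-convolution algorithm via cuDNN's heuristics. On cuDNN 8+
// this queries cudnnGetConvolutionForwardAlgorithm_v7 and returns the first
// candidate that succeeded, fits within the workspace limit, and is not
// WINOGRAD_NONFUSED; on older versions it uses the legacy
// cudnnGetConvolutionForwardAlgorithm preference API.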
2411 port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
2412 const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
2413 const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
2414 const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
2415 size_t memory_limit_bytes) {
2416 #if CUDNN_VERSION >= 8000
2417 const int num_requested_algos = 5;
2418 int num_returned_algos = 0;
2419 cudnnConvolutionFwdAlgoPerf_t perf_results[num_requested_algos];
2420
2421 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
2422 cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2423 output_nd.handle(), num_requested_algos, &num_returned_algos,
2424 perf_results));
2425
2426 size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2427 for (int r = 0; r < num_returned_algos; r++) {
2428 if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2429 perf_results[r].algo != CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
2430 perf_results[r].memory <= mem_limit) {
2431 return perf_results[r].algo;
2432 }
2433 }
2434 return port::Status(port::error::INTERNAL,
2435 "cudnnGetConvolutionForwardAlgorithm_v7 returned "
2436 "no suitable algorithms. This could be a cudnn bug.");
2437 #else
2438 cudnnConvolutionFwdPreference_t preference =
2439 specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
2440 : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
2441 cudnnConvolutionFwdAlgo_t algo_to_use;
2442 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm(
2443 cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
2444 output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2445 return algo_to_use;
2446 #endif
2447 }
2448
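// Backward-data counterpart of GetCudnnConvolutionForwardAlgo: same candidate
// filtering, but based on the cudnnGetConvolutionBackwardDataAlgorithm(_v7)
// APIs.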
2449 port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
2450 GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
2451 const CudnnTensorDescriptor& input_nd,
2452 const CudnnFilterDescriptor& filter,
2453 const CudnnConvolutionDescriptor& conv,
2454 const CudnnTensorDescriptor& output_nd,
2455 bool specify_workspace_limit,
2456 size_t memory_limit_bytes) {
2457 #if CUDNN_VERSION >= 8000
2458 const int num_requested_algos = 5;
2459 int num_returned_algos = 0;
2460 cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_requested_algos];
2461
2462 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(
2463 cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2464 input_nd.handle(), num_requested_algos, &num_returned_algos,
2465 perf_results));
2466
2467 size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2468 for (int r = 0; r < num_returned_algos; r++) {
2469 if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2470 perf_results[r].algo !=
2471 CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED &&
2472 perf_results[r].memory <= mem_limit) {
2473 return perf_results[r].algo;
2474 }
2475 }
2476 return port::Status(port::error::INTERNAL,
2477 "cudnnGetConvolutionBackwardDataAlgorithm_v7 returned "
2478 "no suitable algorithms. This could be a cudnn bug.");
2479 #else
2480 cudnnConvolutionBwdDataPreference_t preference =
2481 specify_workspace_limit
2482 ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
2483 : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
2484 cudnnConvolutionBwdDataAlgo_t algo_to_use;
2485 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm(
2486 cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
2487 input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
2488 return algo_to_use;
2489 #endif
2490 }
2491
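// Backward-filter counterpart of the two helpers above, based on
// cudnnGetConvolutionBackwardFilterAlgorithm(_v7).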
2492 port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
2493 GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
2494 const CudnnTensorDescriptor& input_nd,
2495 const CudnnFilterDescriptor& filter,
2496 const CudnnConvolutionDescriptor& conv,
2497 const CudnnTensorDescriptor& output_nd,
2498 bool specify_workspace_limit,
2499 size_t memory_limit_bytes) {
2500 #if CUDNN_VERSION >= 8000
2501 const int num_requested_algos = 5;
2502 int num_returned_algos = 0;
2503 cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_requested_algos];
2504
2505 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
2506 cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2507 filter.handle(), num_requested_algos, &num_returned_algos, perf_results));
2508
2509 size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
2510 for (int r = 0; r < num_returned_algos; r++) {
2511 if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
2512 perf_results[r].algo !=
2513 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED &&
2514 perf_results[r].memory <= mem_limit) {
2515 return perf_results[r].algo;
2516 }
2517 }
2518 return port::Status(port::error::INTERNAL,
2519 "cudnnGetConvolutionBackwardFilterAlgorithm_v7 returned "
2520 "no suitable algorithms. This could be a cudnn bug.");
2521 #else
2522 cudnnConvolutionBwdFilterPreference_t preference =
2523 specify_workspace_limit
2524 ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
2525 : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
2526 cudnnConvolutionBwdFilterAlgo_t algo_to_use;
2527 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm(
2528 cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
2529 filter.handle(), preference, memory_limit_bytes, &algo_to_use));
2530 return algo_to_use;
2531 #endif
2532 }
2533
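// Queries the workspace size that 'algorithm_desc' needs for this forward
// convolution and allocates it from 'scratch_allocator'. Returns an empty
// DeviceMemory when no workspace is required.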
2534 port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
2535 Stream* stream, const CudnnHandle& cudnn,
2536 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2537 const CudnnConvolutionDescriptor& conv,
2538 const CudnnTensorDescriptor& output_nd,
2539 const dnn::AlgorithmDesc& algorithm_desc,
2540 ScratchAllocator* scratch_allocator) {
2541 if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2542 return port::Status(
2543 port::error::INTERNAL,
2544 "Mismatch between cudnn conv and algorithm descriptors.");
2545 }
2546
2547 // Query the size of the workspace and allocate it.
2548 size_t size_in_bytes;
2549 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
2550 cudnn.handle(),
2551 /*xDesc=*/input_nd.handle(),
2552 /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
2553 /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
2554 /*sizeInBytes=*/&size_in_bytes));
2555
2556 int64 size_in_bytes_int64 = size_in_bytes;
2557
2558 if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2559 return port::Status(
2560 port::error::INTERNAL,
2561 "cudnnGetConvolutionForwardWorkspaceSize() returned "
2562 "negative sizeInBytes value. This could be a cudnn bug.");
2563 }
2564
2565 if (size_in_bytes_int64 == 0) {
2566 return DeviceMemory<uint8>();
2567 }
2568
2569 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2570 return port::Status(port::error::INVALID_ARGUMENT,
2571 "No scratch allocator provided");
2572 }
2573
2574 return scratch_allocator->AllocateBytes(size_in_bytes);
2575 }
2576
2577 port::StatusOr<DeviceMemory<uint8>>
2578 AllocateCudnnConvolutionBackwardDataWorkspace(
2579 Stream* stream, const CudnnHandle& cudnn,
2580 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2581 const CudnnConvolutionDescriptor& conv,
2582 const CudnnTensorDescriptor& output_nd,
2583 const dnn::AlgorithmDesc& algorithm_desc,
2584 ScratchAllocator* scratch_allocator) {
2585 if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2586 return port::Status(
2587 port::error::INTERNAL,
2588 "Mismatch between cudnn conv and algorithm descriptors.");
2589 }
2590
2591 // Query the size of the workspace and allocate it.
2592 size_t size_in_bytes;
2593 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
2594 cudnn.handle(),
2595 /*wDesc=*/filter.handle(),
2596 /*dyDesc=*/output_nd.handle(),
2597 /*convDesc=*/conv.handle(),
2598 /*dxDesc=*/input_nd.handle(),
2599 /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
2600 /*sizeInBytes=*/&size_in_bytes));
2601
2602 int64 size_in_bytes_int64 = size_in_bytes;
2603
2604 if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2605 return port::Status(
2606 port::error::INTERNAL,
2607 "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
2608 "negative sizeInBytes value. This could be a cudnn bug.");
2609 }
2610
2611 if (size_in_bytes_int64 == 0) {
2612 return DeviceMemory<uint8>();
2613 }
2614
2615 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2616 return port::Status(port::error::INVALID_ARGUMENT,
2617 "No scratch allocator provided");
2618 }
2619
2620 return scratch_allocator->AllocateBytes(size_in_bytes);
2621 }
2622
2623 port::StatusOr<DeviceMemory<uint8>>
2624 AllocateCudnnConvolutionBackwardFilterWorkspace(
2625 Stream* stream, const CudnnHandle& cudnn,
2626 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2627 const CudnnConvolutionDescriptor& conv,
2628 const CudnnTensorDescriptor& output_nd,
2629 const dnn::AlgorithmDesc& algorithm_desc,
2630 ScratchAllocator* scratch_allocator) {
2631 if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
2632 return port::Status(
2633 port::error::INTERNAL,
2634 "Mismatch between cudnn conv and algorithm descriptors.");
2635 }
2636
2637 // Query the size of the workspace and allocate it.
2638 size_t size_in_bytes;
2639 RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
2640 cudnn.handle(),
2641 /*xDesc=*/input_nd.handle(),
2642 /*dyDesc=*/output_nd.handle(),
2643 /*convDesc=*/conv.handle(),
2644 /*gradDesc=*/filter.handle(),
2645 /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
2646 /*sizeInBytes=*/&size_in_bytes));
2647
2648 int64 size_in_bytes_int64 = size_in_bytes;
2649
2650 if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
2651 return port::Status(
2652 port::error::INTERNAL,
2653 "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
2654 "negative sizeInBytes value. This could be a cudnn bug.");
2655 }
2656
2657 if (size_in_bytes_int64 == 0) {
2658 return DeviceMemory<uint8>();
2659 }
2660
2661 if (TF_PREDICT_FALSE(!scratch_allocator)) {
2662 return port::Status(port::error::INVALID_ARGUMENT,
2663 "No scratch allocator provided");
2664 }
2665
2666 return scratch_allocator->AllocateBytes(size_in_bytes);
2667 }
2668
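// Decides whether tensor-op (Tensor Core) math may be used. If an algorithm
// descriptor is given, its tensor-op flag is honored but rejected when tensor
// math is disabled for this stream and data type; otherwise the global
// setting is used.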
2669 port::StatusOr<bool> UseTensorOps(Stream* stream, dnn::DataType type,
2670 absl::optional<dnn::AlgorithmDesc> desc) {
2671 bool use_tensor_ops;
2672 if (desc.has_value()) {
2673 use_tensor_ops = desc->tensor_ops_enabled();
2674 if (use_tensor_ops && !IsTensorMathEnabled(stream, type)) {
2675 return port::Status(port::error::INVALID_ARGUMENT,
2676 "Algo requests disabled tensor op evaluation.");
2677 }
2678 } else {
2679 use_tensor_ops = IsTensorMathEnabled(stream, type);
2680 }
2681 return use_tensor_ops;
2682 }
2683
2684 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
2685 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type);
2686
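// Resolves the forward-convolution algorithm from 'algorithm_config' (or from
// cuDNN's heuristics when none is specified) and allocates the scratch space
// it needs. If that allocation fails, falls back to the config's no-scratch
// algorithm.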
2687 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
2688 Stream* stream, const CudnnHandle& cudnn,
2689 const dnn::AlgorithmConfig& algorithm_config,
2690 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2691 dnn::DataType element_type,
2692 const dnn::ConvolutionDescriptor& convolution_descriptor,
2693 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2694 DeviceMemory<uint8>* scratch) {
2695 absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2696
2697 CudnnConvolutionDescriptor conv(
2698 convolution_descriptor,
2699 ToCudnnDataType(GetConvAccumulatorType(element_type)));
2700 bool use_tensor_ops;
2701 SE_ASSIGN_OR_RETURN(use_tensor_ops,
2702 UseTensorOps(stream, element_type, algo_desc));
2703 conv.set_use_tensor_op_math(use_tensor_ops);
2704
2705 if (!algo_desc.has_value()) {
2706 // Pick fastest algorithm within memory limit according to cuDNN's
2707 // heuristics.
2708 bool specify_workspace_limit = scratch_allocator != nullptr;
2709 auto memory_limit_bytes =
2710 specify_workspace_limit
2711 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
2712 : int64{0};
2713 SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
2714 GetCudnnConvolutionForwardAlgo(
2715 cudnn, input_nd, filter, conv, output_nd,
2716 specify_workspace_limit, memory_limit_bytes));
2717 algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
2718 }
2719
2720 const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
2721 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2722 scratch_allocator);
2723
2724 if (scratch_or.ok()) {
2725 *scratch = scratch_or.ValueOrDie();
2726 return *algo_desc;
2727 }
2728
2729 algo_desc = algorithm_config.algorithm_no_scratch();
2730
2731   // Failed to allocate workspace for the first algorithm; fall back to the
2732   // no_scratch algorithm.
2733 if (!algo_desc.has_value()) {
2734 return port::Status(
2735 scratch_or.status().code(),
2736 absl::StrCat("The primary convolution algorithm failed, ",
2737 "while a secondary algorithm is not provided. ",
2738 "Returned status: ", scratch_or.status().ToString()));
2739 }
2740
2741 SE_ASSIGN_OR_RETURN(use_tensor_ops,
2742 UseTensorOps(stream, element_type, algo_desc));
2743 conv.set_use_tensor_op_math(use_tensor_ops);
2744 SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
2745 stream, cudnn, input_nd, filter, conv,
2746 output_nd, *algo_desc, scratch_allocator));
2747 return *algo_desc;
2748 }
2749
2750 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
2751 Stream* stream, const CudnnHandle& cudnn,
2752 const dnn::AlgorithmConfig& algorithm_config,
2753 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2754 dnn::DataType element_type,
2755 const dnn::ConvolutionDescriptor& convolution_descriptor,
2756 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2757 DeviceMemory<uint8>* scratch) {
2758 absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2759 CudnnConvolutionDescriptor conv(
2760 convolution_descriptor,
2761 ToCudnnDataType(GetConvAccumulatorType(element_type)));
2762 bool use_tensor_ops;
2763 SE_ASSIGN_OR_RETURN(use_tensor_ops,
2764 UseTensorOps(stream, element_type, algo_desc));
2765 conv.set_use_tensor_op_math(use_tensor_ops);
2766
2767 if (!algo_desc.has_value()) {
2768 // Pick fastest algorithm within memory limit according to cuDNN's
2769 // heuristics.
2770 bool specify_workspace_limit = scratch_allocator != nullptr;
2771 auto memory_limit_bytes =
2772 specify_workspace_limit
2773 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
2774 : int64{0};
2775 SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
2776 GetCudnnConvolutionBackwardDataAlgo(
2777 cudnn, input_nd, filter, conv, output_nd,
2778 specify_workspace_limit, memory_limit_bytes));
2779 algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
2780 }
2781
2782 const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
2783 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2784 scratch_allocator);
2785
2786 if (scratch_or.ok()) {
2787 *scratch = scratch_or.ValueOrDie();
2788 return *algo_desc;
2789 }
2790
2791 algo_desc = algorithm_config.algorithm_no_scratch();
2792
2793   // Failed to allocate workspace for the first algorithm; fall back to the
2794   // no_scratch algorithm.
2795 if (!algo_desc.has_value()) {
2796 return port::Status(
2797 port::error::INVALID_ARGUMENT,
2798 "The primary convolution algorithm failed memory allocation, "
2799 "while a secondary algorithm is not provided.");
2800 }
2801
2802 SE_ASSIGN_OR_RETURN(use_tensor_ops,
2803 UseTensorOps(stream, element_type, algo_desc));
2804 conv.set_use_tensor_op_math(use_tensor_ops);
2805 SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
2806 stream, cudnn, input_nd, filter, conv,
2807 output_nd, *algo_desc, scratch_allocator));
2808 return *algo_desc;
2809 }
2810
2811 port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
2812 Stream* stream, const CudnnHandle& cudnn,
2813 const dnn::AlgorithmConfig& algorithm_config,
2814 const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
2815 dnn::DataType element_type,
2816 const dnn::ConvolutionDescriptor& convolution_descriptor,
2817 const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
2818 DeviceMemory<uint8>* scratch) {
2819 absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2820 CudnnConvolutionDescriptor conv(
2821 convolution_descriptor,
2822 ToCudnnDataType(GetConvAccumulatorType(element_type)));
2823 bool use_tensor_ops;
2824 SE_ASSIGN_OR_RETURN(use_tensor_ops,
2825 UseTensorOps(stream, element_type, algo_desc));
2826 conv.set_use_tensor_op_math(use_tensor_ops);
2827
2828 if (!algo_desc.has_value()) {
2829 // Pick fastest algorithm within memory limit according to cuDNN's
2830 // heuristics.
2831 bool specify_workspace_limit = scratch_allocator != nullptr;
2832 auto memory_limit_bytes =
2833 specify_workspace_limit
2834 ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
2835 : int64{0};
2836 SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
2837 GetCudnnConvolutionBackwardFilterAlgo(
2838 cudnn, input_nd, filter, conv, output_nd,
2839 specify_workspace_limit, memory_limit_bytes));
2840 algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
2841 }
2842
2843 auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace(
2844 stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
2845 scratch_allocator);
2846
2847 if (scratch_or.ok()) {
2848 *scratch = scratch_or.ValueOrDie();
2849 return *algo_desc;
2850 }
2851
2852 algo_desc = algorithm_config.algorithm_no_scratch();
2853
2854   // Failed to allocate workspace for the first algorithm; fall back to the
2855   // no_scratch algorithm.
2856 if (!algo_desc.has_value()) {
2857 return port::Status(
2858 port::error::INVALID_ARGUMENT,
2859 "The primary convolution algorithm failed memory allocation, "
2860 "while a secondary algorithm is not provided.");
2861 }
2862
2863 SE_ASSIGN_OR_RETURN(use_tensor_ops,
2864 UseTensorOps(stream, element_type, algo_desc));
2865 conv.set_use_tensor_op_math(use_tensor_ops);
2866 SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
2867 stream, cudnn, input_nd, filter, conv,
2868 output_nd, *algo_desc, scratch_allocator));
2869 return *algo_desc;
2870 }
2871
2872 // A helper class that reads boolean env-vars to choose options for
2873 // cudnn-related algorithms.
2874 template <typename EnvVar>
2875 class CudnnEnvVar {
2876 public:
2877   static bool IsEnabled() {
2878 static bool is_enabled = IsEnabledImpl();
2879 return is_enabled;
2880 }
2881
2882 private:
2883   static bool IsEnabledImpl() {
2884 const char* tf_env_var_val = getenv(EnvVar::kName);
2885 if (tf_env_var_val != nullptr) {
2886 absl::string_view tf_env_var_val_str(tf_env_var_val);
2887 if (tf_env_var_val_str == "0") {
2888 return false;
2889 }
2890 return true;
2891 }
2892 return EnvVar::kDefaultFlag;
2893 }
2894 };
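// Usage sketch for CudnnEnvVar (illustrative only; 'MyFeature' and its
// env-var are hypothetical, not real TensorFlow flags):
//
//   struct MyFeature {
//     static constexpr const char* kName = "TF_ENABLE_MY_FEATURE";
//     static constexpr bool kDefaultFlag = true;
//   };
//   ...
//   if (CudnnEnvVar<MyFeature>::IsEnabled()) { /* use the feature */ }
//
// The structs below follow exactly this pattern.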
2895
2896 // A helper struct to decide whether to enable the FFT_TILING algorithms for
2897 // forward convolution. They were disabled for cuDNN < 7 due to memory
2898 // corruption caused by some shapes with this algorithm. They default to
2899 // enabled; users can explicitly disable them through the env-var
// "TF_ENABLE_FFT_TILING_FORWARD=0".
2900 struct FftTilingForward {
2901 static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
2902 static constexpr bool kDefaultFlag = true;
2903 };
2904
2905 // A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
2906 // They are on by default; users can explicitly disable them through the
2907 // env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
2908 // https://github.com/tensorflow/tensorflow/pull/4901
2909 struct WinogradNonfused {
2910 static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
2911   // NVIDIA fixed the winograd nonfused bug for cuDNN v>=7. For older
2912   // versions, we have a workaround.
2913 static constexpr bool kDefaultFlag = true;
2914 };
2915
2916 // A helper struct to decide whether to use FP32 as the internal compute type
2917 // for convolution when the input data type is FP16. It is on by default;
2918 // users can explicitly disable it (i.e. use FP16 as the internal compute
2919 // type) through the env-var "TF_FP16_CONV_USE_FP32_COMPUTE=0".
2920 struct ConvDoFP32ComputationFP16Input {
2921 static constexpr const char* kName = "TF_FP16_CONV_USE_FP32_COMPUTE";
2922   // Using FP16 as the internal compute type for convolution when the input
2923   // data type is FP16 is only supported on architectures with true fp16
2924   // support (compute capability 5.3 and 6.0). Setting this to false on an
2925   // unsupported architecture will cause internal errors.
2926 static constexpr bool kDefaultFlag = true;
2927 };
2928
2929 // A helper struct to decide whether to use FP32 as the internal compute type
2930 // for RNNs when the input data type is FP16. Users can explicitly control it
2931 // through the env-var TF_FP16_RNN_USE_FP32_COMPUTE; the default depends on
2932 // the cuDNN version (see kDefaultFlag below).
2933 // After the TODO below is fixed, users should almost always use fp32 compute
2934 // type for training. Using fp16 may give suboptimal accuracy due to loss of
2935 // precision.
2936 struct RnnDoFP32ComputationFP16Input {
2937 static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
2938   // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
2939   // Before cudnn 7.1.4, RNNs are always done in fp32, no matter what math
2940   // precision is set.
2941   // Set it temporarily to false so that no error is raised when using fp16
2942   // inputs with fp32 math precision.
2943 //
2944 // cuDNN == 7.5.0 is verified to have this fixed.
2945 static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7500;
2946 };
2947
2948 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
2949 switch (data_type) {
2950 case dnn::DataType::kFloat:
2951 return CUDNN_DATA_FLOAT;
2952 case dnn::DataType::kDouble:
2953 return CUDNN_DATA_DOUBLE;
2954 case dnn::DataType::kHalf:
2955 if (CudnnEnvVar<RnnDoFP32ComputationFP16Input>::IsEnabled()) {
2956 return CUDNN_DATA_FLOAT;
2957 } else {
2958 return CUDNN_DATA_HALF;
2959 }
2960 default:
2961 LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type);
2962 }
2963 }
2964
2965 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) {
2966 switch (data_type) {
2967 case dnn::DataType::kFloat:
2968 case dnn::DataType::kDouble:
2969 return data_type;
2970 case dnn::DataType::kHalf:
2971 return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
2972 ? dnn::DataType::kFloat
2973 : dnn::DataType::kHalf;
2974 case dnn::DataType::kInt8:
2975 case dnn::DataType::kInt32:
2976 return dnn::DataType::kInt32;
2977 default:
2978 LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
2979 }
2980 }
2981 } // namespace
2982
2983 port::Status CudnnSupport::DoPrepareForConvolution(
2984 dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2985 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2986 const dnn::FilterDescriptor& filter_descriptor,
2987 DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2988 DeviceMemoryBase output_data,
2989 const dnn::ConvolutionDescriptor& convolution_descriptor,
2990 const dnn::AlgorithmConfig& algorithm_config,
2991 ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
2992 DeviceMemory<uint8>* scratch_memory) {
2993 CudnnTensorDescriptor input_nd(
2994 input_descriptor,
2995 ToCudnnDataType(element_type, input_descriptor.layout()));
2996 CudnnFilterDescriptor filter_nd(
2997 filter_descriptor,
2998 ToCudnnDataType(element_type, filter_descriptor.layout()));
2999 CudnnTensorDescriptor output_nd(
3000 output_descriptor,
3001 ToCudnnDataType(element_type, output_descriptor.layout()));
3002
3003 auto cudnn = cudnn_->GetHandle(parent_, stream);
3004
3005 switch (kind) {
3006 case dnn::ConvolutionKind::FORWARD: {
3007 SE_ASSIGN_OR_RETURN(*algorithm_desc,
3008 GetCudnnConvolutionForwardAlgorithm(
3009 stream, cudnn, algorithm_config, input_nd,
3010 filter_nd, element_type, convolution_descriptor,
3011 output_nd, scratch_allocator, scratch_memory));
3012 break;
3013 }
3014 case dnn::ConvolutionKind::BACKWARD_DATA: {
3015 SE_ASSIGN_OR_RETURN(*algorithm_desc,
3016 GetCudnnConvolutionBackwardDataAlgorithm(
3017 stream, cudnn, algorithm_config, input_nd,
3018 filter_nd, element_type, convolution_descriptor,
3019 output_nd, scratch_allocator, scratch_memory));
3020 break;
3021 }
3022 case dnn::ConvolutionKind::BACKWARD_FILTER: {
3023 SE_ASSIGN_OR_RETURN(*algorithm_desc,
3024 GetCudnnConvolutionBackwardFilterAlgorithm(
3025 stream, cudnn, algorithm_config, input_nd,
3026 filter_nd, element_type, convolution_descriptor,
3027 output_nd, scratch_allocator, scratch_memory));
3028 break;
3029 }
3030 default:
3031 return port::InternalError(
3032 absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
3033 }
3034
3035 return port::Status::OK();
3036 }
3037
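// Runs the convolution that was selected in DoPrepareForConvolution, using
// the already-allocated 'scratch_memory'. When 'output_profile_result' is
// non-null, the call is timed with a GpuTimer and the result is recorded.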
3038 port::Status CudnnSupport::DoConvolve(
3039 dnn::ConvolutionKind kind, dnn::DataType element_type,
3040 dnn::DataType output_type, Stream* stream,
3041 const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3042 const dnn::FilterDescriptor& filter_descriptor,
3043 DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3044 DeviceMemoryBase output_data,
3045 const dnn::ConvolutionDescriptor& convolution_descriptor,
3046 dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
3047 dnn::ProfileResult* output_profile_result) {
3048 cudnnDataType_t cudnn_type = ToCudnnDataType(element_type);
3049 CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
3050 CudnnTensorDescriptor output_nd(output_descriptor,
3051 ToCudnnDataType(output_type));
3052 CudnnFilterDescriptor filter_nd(filter_descriptor, cudnn_type);
3053 auto accumulator_type = GetConvAccumulatorType(element_type);
3054 CudnnConvolutionDescriptor conv(convolution_descriptor,
3055 ToCudnnDataType(accumulator_type));
3056 SE_ASSIGN_OR_RETURN(bool use_tensor_ops,
3057 UseTensorOps(stream, element_type, algorithm_desc));
3058 conv.set_use_tensor_op_math(use_tensor_ops);
3059
3060 auto cudnn = cudnn_->GetHandle(parent_, stream);
3061 // Alpha is the scaling factor for input.
3062 float falpha = 1.0;
3063 double dalpha = 1.0;
3064 void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dalpha)
3065 : static_cast<void*>(&falpha);
3066 // Beta is the scaling factor for output.
3067 float fbeta = 0.0;
3068 double dbeta = 0.0;
3069 void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
3070 : static_cast<void*>(&fbeta);
3071
3072 const bool is_profiling = output_profile_result != nullptr;
3073
3074 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
3075 if (is_profiling) {
3076 timer.reset(new GpuTimer(parent_)); // NOLINT
3077     // The start and stop of the timer should be as close to the cuDNN call
3078     // as possible. Other threads can still issue work onto this stream, so
3079     // the measurement may need to be repeated across multiple profiling runs.
3080 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
3081 return port::Status(port::error::INTERNAL, "Failed to start timer");
3082 }
3083 }
3084
3085 const auto get_fwd_bugs = [&]() -> port::Status {
3086 if (CUDNN_VERSION < 8000) {
3087 if (algorithm_desc.algo_id() ==
3088 CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM &&
3089 ToCudnnDataType(element_type) == CUDNN_DATA_INT8 &&
3090 ToCudnnDataType(output_type) == CUDNN_DATA_FLOAT) {
3091 return port::Status(
3092 port::error::FAILED_PRECONDITION,
3093 "This configuration potentially produces incorrect results.");
3094 }
3095 }
3096 return port::Status::OK();
3097 };
3098
3099 auto get_bwd_data_bugs = [&]() -> port::Status {
3100 return port::Status::OK();
3101 };
3102
3103 const auto get_bwd_filter_bugs = [&]() -> port::Status {
3104 return port::Status::OK();
3105 };
3106
3107 switch (kind) {
3108 case dnn::ConvolutionKind::FORWARD: {
3109 SE_RETURN_IF_ERROR(get_fwd_bugs());
3110 RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
3111 cudnn.handle(),
3112 /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
3113 /*srcData=*/input_data.opaque(), /*filterDesc=*/filter_nd.handle(),
3114 /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
3115 /*algo=*/ToConvForwardAlgo(algorithm_desc),
3116 /*workSpace=*/scratch_memory.opaque(),
3117 /*workSpaceSizeInBytes=*/scratch_memory.size(), /*beta=*/beta,
3118 /*yDesc=*/output_nd.handle(), /*y=*/output_data.opaque()));
3119 break;
3120 }
3121 case dnn::ConvolutionKind::BACKWARD_DATA: {
3122 SE_RETURN_IF_ERROR(get_bwd_data_bugs());
3123 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardData(
3124 cudnn.handle(),
3125 /*alpha=*/alpha,
3126 /*wDesc=*/filter_nd.handle(),
3127 /*w=*/filter_data.opaque(),
3128 /*dyDesc=*/output_nd.handle(),
3129 /*dy=*/output_data.opaque(),
3130 /*convDesc=*/conv.handle(),
3131 /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
3132 /*workSpace=*/scratch_memory.opaque(),
3133 /*workSpaceSizeInBytes=*/scratch_memory.size(),
3134 /*beta=*/beta,
3135 /*dxDesc=*/input_nd.handle(),
3136 /*dx=*/input_data.opaque()));
3137 break;
3138 }
3139 case dnn::ConvolutionKind::BACKWARD_FILTER: {
3140 SE_RETURN_IF_ERROR(get_bwd_filter_bugs());
3141 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
3142 cudnn.handle(),
3143 /*alpha=*/alpha,
3144 /*srcDesc=*/input_nd.handle(),
3145 /*srcData=*/input_data.opaque(),
3146 /*diffDesc=*/output_nd.handle(),
3147 /*diffData=*/output_data.opaque(),
3148 /*convDesc=*/conv.handle(),
3149 /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
3150 /*workSpace=*/scratch_memory.opaque(),
3151 /*workSpaceSizeInBytes=*/scratch_memory.size(),
3152 /*beta=*/beta,
3153 /*gradDesc=*/filter_nd.handle(),
3154 /*dw=*/filter_data.opaque()));
3155 break;
3156 }
3157 default:
3158 return port::InternalError(
3159 absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
3160 }
3161
3162 if (is_profiling) {
3163 if (!timer->Stop(AsGpuStream(stream))) {
3164 return port::Status(port::error::INTERNAL, "Failed to stop timer");
3165 }
3166 output_profile_result->set_algorithm(algorithm_desc);
3167 output_profile_result->set_elapsed_time_in_ms(
3168 timer->GetElapsedMilliseconds());
3169 output_profile_result->set_scratch_size(scratch_memory.size());
3170 }
3171
3172 return port::Status::OK();
3173 }
3174
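// Fused convolution + side input + bias + activation, implemented with
// cudnnConvolutionBiasActivationForward. Only kRelu and kNone activation
// modes are supported by this path.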
3175 template <typename ElementType, typename BiasType, typename ScaleType,
3176 typename OutputType>
3177 port::Status CudnnSupport::DoFusedConvolveImpl(
3178 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3179 const DeviceMemory<ElementType>& conv_input_data,
3180 ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor,
3181 const DeviceMemory<ElementType>& filter_data,
3182 const dnn::ConvolutionDescriptor& convolution_descriptor,
3183 const DeviceMemory<OutputType>& side_input_data, ScaleType side_input_scale,
3184 const dnn::BatchDescriptor& bias_descriptor,
3185 const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
3186 const dnn::BatchDescriptor& output_descriptor,
3187 DeviceMemory<OutputType>* output_data, dnn::DataType accumulator_type,
3188 ScratchAllocator* scratch_allocator,
3189 const dnn::AlgorithmConfig& algorithm_config,
3190 dnn::ProfileResult* output_profile_result) {
3191 if (activation_mode != dnn::ActivationMode::kRelu &&
3192 activation_mode != dnn::ActivationMode::kNone) {
3193 return port::Status(port::error::INVALID_ARGUMENT,
3194 "cudnnConvolutionBiasActivationForward() only supports "
3195 "Relu or None activation.");
3196 }
3197
3198 CudnnTensorDescriptor conv_input_nd(
3199 conv_input_descriptor,
3200 GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
3201 CudnnTensorDescriptor output_nd(
3202 output_descriptor,
3203 GetCudnnDataType<OutputType>(conv_input_descriptor.layout()));
3204 CudnnFilterDescriptor filter(
3205 filter_descriptor,
3206 GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
3207 CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
3208
3209 auto cudnn = cudnn_->GetHandle(parent_, stream);
3210
3211 const bool is_profiling = output_profile_result != nullptr;
3212
3213 DeviceMemory<uint8> scratch;
3214 SE_ASSIGN_OR_RETURN(
3215 dnn::AlgorithmDesc algo_desc,
3216 GetCudnnConvolutionForwardAlgorithm(
3217 stream, cudnn, algorithm_config, conv_input_nd, filter,
3218 dnn::ToDataType<ElementType>::value, convolution_descriptor,
3219 output_nd, scratch_allocator, &scratch));
3220
3221 CudnnConvolutionDescriptor conv(convolution_descriptor,
3222 ToCudnnDataType(accumulator_type));
3223 conv.set_use_tensor_op_math(algo_desc.tensor_ops_enabled());
3224
3225 std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
3226 if (is_profiling) {
3227 timer.reset(new GpuTimer(parent_)); // NOLINT
3228     // The start and stop of the timer should be as close to the cuDNN call
3229     // as possible. Other threads can still issue work onto this stream, so
3230     // the measurement may need to be repeated across multiple profiling runs.
3231 if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
3232 return port::Status(port::error::INTERNAL, "Failed to start timer");
3233 }
3234 }
3235   // CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for the
3236   // activation descriptor. Note that this changes the NaN propagation
3237   // behavior relative to separate conv, bias, and relu (which by default is
3238   // CUDNN_PROPAGATE_NAN).
3239 CudnnActivationDescriptor activation_desc(
3240 activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
3241 auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
3242 : side_input_data.opaque();
3243
3244 VLOG(2) << "\nconv_input_scale = " << conv_input_scale
3245 << "\nconv_input_nd.handle() = " << conv_input_nd.handle()
3246 << "\nconv_input_data.opaque() = " << conv_input_data.opaque()
3247 << "\nfilter.handle() = " << filter.handle()
3248 << "\nfilter_data.opaque() = " << filter_data.opaque()
3249 << "\nconv.handle() = " << conv.handle()
3250 << "\nalgo = " << algo_desc.algo_id()
3251 << "\nscratch.opaque() = " << scratch.opaque()
3252 << "\nscratch.size() = " << scratch.size()
3253 << "\nside_input_scale = " << side_input_scale
3254 << "\noutput_nd.handle() = " << output_nd.handle()
3255 << "\nside_input_data_ptr = " << side_input_data_ptr
3256 << "\nbias_nd.handle() = " << bias_nd.handle()
3257 << "\nbiases.opaque() = " << biases.opaque()
3258 << "\nactivation_desc.handle() = " << activation_desc.handle()
3259 << "\noutput_nd.handle() = " << output_nd.handle()
3260 << "\noutput_data->opaque() = " << output_data->opaque();
3261
3262 if (IsTensorMathOpSet(conv) != algo_desc.tensor_ops_enabled()) {
3263 return port::Status(port::error::FAILED_PRECONDITION,
3264 "Tensor op math type in dnn::AlgorithmDesc does not "
3265 "match that of the CudnnConvolutionDescriptor");
3266 }
3267
3268 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward(
3269 cudnn.handle(),
3270 /*alpha1=*/&conv_input_scale,
3271 /*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
3272 /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
3273 /*convDesc=*/conv.handle(), ToConvForwardAlgo(algo_desc),
3274 /*workSpace=*/scratch.opaque(),
3275 /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale,
3276 /*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
3277 /*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
3278 /*activationDesc=*/activation_desc.handle(),
3279 /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque()));
3280
3281 if (is_profiling) {
3282 if (!timer->Stop(AsGpuStream(stream))) {
3283 return port::Status(port::error::INTERNAL, "Failed to stop timer");
3284 }
3285 output_profile_result->set_algorithm(algo_desc);
3286 output_profile_result->set_elapsed_time_in_ms(
3287 timer->GetElapsedMilliseconds());
3288 output_profile_result->set_scratch_size(scratch.size());
3289 }
3290
3291 return port::Status::OK();
3292 }
3293
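// Returns the candidate forward-convolution algorithms. Each algorithm is
// listed with tensor-op math (when available for this compute capability) and
// without, so autotuning can try both variants.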
3294 bool CudnnSupport::GetConvolveAlgorithms(
3295 bool with_winograd_nonfused, int cc_major, int cc_minor,
3296 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3297 // Preload sub libs for cudnn 8.0.4+
3298 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
3299 cudnnOpsInferVersionCheck();
3300 cudnnCnnInferVersionCheck();
3301 #endif
3302 bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3303 out_algorithms->clear();
3304
3305 std::vector<dnn::AlgorithmDesc::Index> algo_types;
3306 if (ConvUseDefaultAlgorithm()) {
3307 // Force a fallback algorithm.
3308 algo_types = {CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM};
3309 } else {
3310 algo_types = {
3311 // clang-format off
3312 CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
3313 CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
3314 CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
3315 CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
3316 CUDNN_CONVOLUTION_FWD_ALGO_FFT,
3317 CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
3318 // clang-format on
3319 };
3320 if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
3321 algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
3322 }
3323 if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3324 algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
3325 }
3326 }
3327
3328 // The algorithms are intentionally ordered for deterministic operation
3329 for (auto i : algo_types) {
3330 if (tensor_op_math_available) {
3331 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3332 }
3333 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3334 }
3335
3336 return true;
3337 }
3338
3339 bool CudnnSupport::GetRnnAlgorithms(
3340 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3341 // Preload sub libs for cudnn 8.0.4+
3342 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
3343 cudnnOpsInferVersionCheck();
3344 cudnnOpsTrainVersionCheck();
3345 cudnnAdvInferVersionCheck();
3346 cudnnAdvTrainVersionCheck();
3347 #endif
3348 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3349 // clang-format off
3350 CUDNN_RNN_ALGO_STANDARD,
3351 CUDNN_RNN_ALGO_PERSIST_STATIC,
3352 CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
3353 // clang-format on
3354 };
3355
3356 out_algorithms->clear();
3357 for (auto i : algo_types) {
3358 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3359 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3360 }
3361 return true;
3362 }
3363
3364 bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
3365 bool with_winograd_nonfused, int cc_major, int cc_minor,
3366 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3367 // Preload sub libs for cudnn 8.0.4+
3368 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
3369 cudnnOpsInferVersionCheck();
3370 cudnnOpsTrainVersionCheck();
3371 cudnnCnnInferVersionCheck();
3372 cudnnCnnTrainVersionCheck();
3373 #endif
3374 bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3375 out_algorithms->clear();
3376
3377 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3378 // clang-format off
3379 CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
3380 CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
3381 CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
3382 CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
3383 // clang-format on
3384 };
3385 if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3386 algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
3387 }
3388 if (!RequireCudnnDeterminism()) {
3389 algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0);
3390 }
3391
3392 // The algorithms are intentionally ordered for deterministic operation
3393 for (auto i : algo_types) {
3394 if (tensor_op_math_available) {
3395 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3396 }
3397 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3398 }
3399
3400 return true;
3401 }
3402
3403 bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
3404 bool with_winograd_nonfused, int cc_major, int cc_minor,
3405 std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3406 // Preload sub libs for cudnn 8.0.4+
3407 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
3408 cudnnOpsInferVersionCheck();
3409 cudnnOpsTrainVersionCheck();
3410 cudnnCnnInferVersionCheck();
3411 cudnnCnnTrainVersionCheck();
3412 #endif
3413 bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3414 out_algorithms->clear();
3415
3416 std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3417 // clang-format off
3418 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
3419 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
3420 // Based on cudnn.h, the following is not implemented.
3421 // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
3422
3423 // Produces incorrect results for some shapes. Disabled for now, see
3424 // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
3425 // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
3426 // clang-format on
3427 };
3428 if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3429 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
3430 }
3431 if (!RequireCudnnDeterminism()) {
3432 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0);
3433 algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3);
3434 }
3435
3436 // The algorithms are intentionally ordered for deterministic operation
3437 for (auto i : algo_types) {
3438 if (tensor_op_math_available) {
3439 out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3440 }
3441 out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3442 }
3443
3444 return true;
3445 }
3446
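// Batch-normalization forward entry points for float and half inputs; both
// delegate to DoBatchNormalizationForwardImpl below.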
3447 bool CudnnSupport::DoBatchNormalizationForward(
3448 Stream* stream, const DeviceMemory<float>& x,
3449 const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3450 const DeviceMemory<float>& estimated_mean,
3451 const DeviceMemory<float>& estimated_variance,
3452 const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
3453 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3454 const double exponential_average_factor,
3455 dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
3456 DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3457 DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3458 bool is_training, ScratchAllocator* reserve_space_allocator,
3459 ScratchAllocator* workspace_allocator) {
3460 return IsStatusOk(
3461 DoBatchNormalizationForwardImpl<float, float>(
3462 stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
3463 offset, estimated_mean, estimated_variance, side_input, x_desc,
3464 scale_offset_desc, epsilon, exponential_average_factor,
3465 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
3466 is_training, reserve_space_allocator, workspace_allocator),
3467 /*report_error=*/true);
3468 }
3469
3470 bool CudnnSupport::DoBatchNormalizationForward(
3471 Stream* stream, const DeviceMemory<Eigen::half>& x,
3472 const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3473 const DeviceMemory<float>& estimated_mean,
3474 const DeviceMemory<float>& estimated_variance,
3475 const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
3476 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3477 const double exponential_average_factor,
3478 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
3479 DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3480 DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3481 bool is_training, ScratchAllocator* reserve_space_allocator,
3482 ScratchAllocator* workspace_allocator) {
3483 return IsStatusOk(
3484 DoBatchNormalizationForwardImpl<Eigen::half, float>(
3485 stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
3486 estimated_mean, estimated_variance, side_input, x_desc,
3487 scale_offset_desc, epsilon, exponential_average_factor,
3488 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
3489 is_training, reserve_space_allocator, workspace_allocator),
3490 /*report_error=*/true);
3491 }
3492
3493 template <class T, class U>
3494 port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
3495 Stream* stream, dnn::DataType input_data_type,
3496 dnn::DataType scale_data_type, const DeviceMemory<T>& x,
3497 const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
3498 const DeviceMemory<U>& estimated_mean,
3499 const DeviceMemory<U>& estimated_variance,
3500 const DeviceMemory<U>& side_input, const dnn::BatchDescriptor& x_desc,
3501 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3502 const double exponential_average_factor,
3503 dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
3504 DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
3505 DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
3506 bool is_training, ScratchAllocator* reserve_space_allocator,
3507 ScratchAllocator* workspace_allocator) {
3508 CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
3509 CudnnTensorDescriptor scale_offset_descriptor(
3510 scale_offset_desc, ToCudnnDataType(scale_data_type));
3511 cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
3512 if (BatchnormSpatialPersistentEnabled() && is_training) {
3513 mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
3514 }
3515 float one = 1.0;
3516 float zero = 0.0;
3517 auto cudnn = cudnn_->GetHandle(parent_, stream);
3518
3519 DeviceMemory<uint8> workspace;
3520 DeviceMemory<uint8> reserve_space;
3521
3522 #if CUDNN_VERSION >= 7402
3523 const auto get_bn_ops = [&]() -> cudnnBatchNormOps_t {
3524 if (side_input.is_null()) {
3525 return activation_mode == dnn::ActivationMode::kNone
3526 ? CUDNN_BATCHNORM_OPS_BN
3527 : CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
3528 } else {
3529 return CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
3530 }
3531 };
3532 const cudnnBatchNormOps_t bn_ops = get_bn_ops();
3533
3534   // We use NaN propagation to be consistent with CudnnSupport::DoActivate(...).
3535 CudnnActivationDescriptor activation_desc(
3536 activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max());
3537
3538 if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
3539 SE_ASSIGN_OR_RETURN(
3540 workspace,
3541 CreateBatchNormForwardWorkspace(
3542 stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor,
3543 scale_offset_descriptor, workspace_allocator))
3544 if (is_training) {
3545 size_t reserve_space_size_in_bytes = 0;
3546 RETURN_IF_CUDNN_ERROR(
3547 cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
3548 /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
3549 /*activationDesc=*/activation_desc.handle(),
3550 /*xDesc=*/x_descriptor.handle(),
3551 /*sizeInBytes=*/&reserve_space_size_in_bytes));
3552 SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
3553 reserve_space_size_in_bytes));
3554 }
3555 }
3556 #endif
3557
3558 auto check_no_side_input_or_activation = [&]() -> port::Status {
3559 if (activation_mode != dnn::ActivationMode::kNone ||
3560 !side_input.is_null()) {
3561 return port::Status(
3562 port::error::INTERNAL,
3563 absl::StrCat(
3564 "Side input and activation are not supported by cuDNN version: ",
3565 CUDNN_VERSION));
3566 } else {
3567 return port::Status::OK();
3568 }
3569 };
3570
3571 if (is_training) {
3572 CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
3573 << "batch_mean and batch_var must both be null or both be non-null";
3574
3575 void* batch_mean_opaque;
3576 void* batch_var_opaque;
3577 if (!batch_mean->is_null() && !batch_var->is_null()) {
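      // cuDNN computes the running statistics as
      //   running = (1 - factor) * running + factor * batch,
      // so with factor == 1.0 the old contents are irrelevant; they are
      // zeroed here, presumably so that uninitialized values (e.g. NaN,
      // where 0 * NaN is still NaN) cannot leak into the result.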
3578 if (exponential_average_factor == 1.0) {
3579 stream->ThenMemZero(batch_mean, batch_mean->size());
3580 stream->ThenMemZero(batch_var, batch_var->size());
3581 }
3582 batch_mean_opaque = batch_mean->opaque();
3583 batch_var_opaque = batch_var->opaque();
3584 } else {
3585 batch_mean_opaque = nullptr;
3586 batch_var_opaque = nullptr;
3587 }
3588
3589 bool called = false;
3590 #if CUDNN_VERSION >= 7402
3591 if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
3592 called = true;
3593 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTrainingEx(
3594 /*handle=*/cudnn.handle(),
3595 /*mode=*/mode,
3596 /*bnOps=*/bn_ops,
3597 /*alpha=*/&one,
3598 /*beta=*/&zero,
3599 /*xDesc=*/x_descriptor.handle(),
3600 /*xData=*/x.opaque(),
3601 /*zDesc=*/x_descriptor.handle(),
3602 /*zData=*/side_input.opaque(),
3603 /*yDesc=*/x_descriptor.handle(),
3604 /*yData=*/y->opaque(),
3605 /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
3606 /*bnScale=*/scale.opaque(),
3607 /*bnBias=*/offset.opaque(),
3608 /*exponentialAverageFactor=*/exponential_average_factor,
3609 /*resultRunningMean=*/batch_mean_opaque,
3610 /*resultRunningVariance=*/batch_var_opaque,
3611 /*epsilon=*/epsilon,
3612 /*resultSaveMean=*/saved_mean->opaque(),
3613 /*resultSaveInvVariance=*/saved_inv_var->opaque(),
3614 /*activationDesc=*/activation_desc.handle(),
3615 /*workspace=*/workspace.opaque(),
3616 /*workSpaceSizeInBytes=*/workspace.size(),
3617 /*reserveSpace=*/reserve_space.opaque(),
3618 /*reserveSpaceSizeInBytes=*/reserve_space.size()));
3619 }
3620 #endif
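    // Fall back to the older cuDNN batch-norm API when the fused (Ex) path
    // above was not used; that API supports neither side inputs nor fused
    // activations, which is checked first.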
3621 if (!called) {
3622 SE_RETURN_IF_ERROR(check_no_side_input_or_activation());
3623 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining(
3624 cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3625 x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3626 scale.opaque(), offset.opaque(), exponential_average_factor,
3627 batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(),
3628 saved_inv_var->opaque()));
3629 }
3630 } else {
3631 const void* maybe_inv_var = estimated_variance.opaque();
3632 SE_RETURN_IF_ERROR(check_no_side_input_or_activation());
3633 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference(
3634 cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3635 x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3636 scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
3637 epsilon));
3638 }
3639 return port::Status::OK();
3640 }
3641
3642 bool CudnnSupport::DoBatchNormalizationBackward(
3643 Stream* stream, const DeviceMemory<float>& y_backprop,
3644 const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
3645 const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3646 const dnn::BatchDescriptor& x_desc,
3647 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3648 DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
3649 DeviceMemory<float>* offset_backprop,
3650 DeviceMemory<uint8>* reserve_space_data,
3651 ScratchAllocator* workspace_allocator) {
3652 return IsStatusOk(DoBatchNormalizationBackwardImpl(
3653 stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop,
3654 x, scale, mean, inv_var, x_desc, scale_offset_desc,
3655 epsilon, x_backprop, scale_backprop, offset_backprop,
3656 reserve_space_data, workspace_allocator),
3657 /*report_error=*/true);
3658 }
3659
3660 bool CudnnSupport::DoBatchNormalizationBackward(
3661 Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
3662 const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
3663 const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3664 const dnn::BatchDescriptor& x_desc,
3665 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3666 DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
3667 DeviceMemory<float>* offset_backprop,
3668 DeviceMemory<uint8>* reserve_space_data,
3669 ScratchAllocator* workspace_allocator) {
3670 return IsStatusOk(DoBatchNormalizationBackwardImpl(
3671 stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop,
3672 x, scale, mean, inv_var, x_desc, scale_offset_desc,
3673 epsilon, x_backprop, scale_backprop, offset_backprop,
3674 reserve_space_data, workspace_allocator),
3675 /*report_error=*/true);
3676 }
3677
3678 template <class T, class U>
3679 port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
3680 Stream* stream, int cudnn_input_type, int cudnn_scale_type,
3681 const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
3682 const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
3683 const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
3684 const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3685 DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
3686 DeviceMemory<U>* offset_backprop, DeviceMemory<uint8>* reserve_space_data,
3687 ScratchAllocator* workspace_allocator) {
3688 CudnnTensorDescriptor x_descriptor(
3689 x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
3690 CudnnTensorDescriptor scale_offset_descriptor(
3691 scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
3692 cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
3693 if (BatchnormSpatialPersistentEnabled()) {
3694 mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
3695 }
3696 float one = 1.0;
3697 float zero = 0.0;
3698
3699 auto cudnn = cudnn_->GetHandle(parent_, stream);
3700
3701 bool called = false;
3702 #if CUDNN_VERSION >= 7402
3703 if (reserve_space_data != nullptr && workspace_allocator != nullptr) {
3704 called = true;
3705 const cudnnBatchNormOps_t bn_ops = CUDNN_BATCHNORM_OPS_BN;
3706 SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
3707 CreateBatchNormBackwardWorkspace(
3708 stream, cudnn, mode, bn_ops, x_descriptor,
3709 scale_offset_descriptor, workspace_allocator))
3710 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackwardEx(
3711 /*handle=*/cudnn.handle(),
3712 /*mode=*/mode,
3713 /*bnOps=*/bn_ops,
3714 /*alphaDataDiff=*/&one,
3715 /*betaDataDiff=*/&zero,
3716 /*alphaParamDiff=*/&one,
3717 /*betaParamDiff=*/&zero,
3718 /*xDesc=*/x_descriptor.handle(),
3719 /*xData=*/x.opaque(),
3720 /*yDesc=*/nullptr,
3721 /*yData=*/nullptr,
3722 /*dyDesc=*/x_descriptor.handle(),
3723 /*dyData=*/y_backprop.opaque(),
3724 /*dzDesc=*/nullptr,
3725 /*dzData=*/nullptr,
3726 /*dxDesc=*/x_descriptor.handle(),
3727 /*dxData=*/x_backprop->opaque(),
3728 /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
3729 /*bnScaleData=*/scale.opaque(),
3730 /*bnBiasData=*/nullptr,
3731 /*dBnScaleData=*/scale_backprop->opaque(),
3732 /*dBnBiasData=*/offset_backprop->opaque(),
3733 /*epsilon=*/epsilon,
3734 /*savedMean=*/mean.opaque(),
3735 /*savedInvVariance=*/inv_var.opaque(),
3736 /*activationDesc=*/nullptr,
3737 /*workspace=*/workspace.opaque(),
3738 /*workSpaceSizeInBytes=*/workspace.size(),
3739 /*reserveSpace=*/reserve_space_data->opaque(),
3740 /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
3741 }
3742 #endif
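  // As in the forward pass, fall back to the plain (non-Ex) backward API
  // when the workspace/reserve-space path was not taken.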
3743 if (!called) {
3744 RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward(
3745 cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
3746 x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
3747 x_descriptor.handle(), x_backprop->opaque(),
3748 scale_offset_descriptor.handle(), scale.opaque(),
3749 scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
3750 mean.opaque(), inv_var.opaque()));
3751 }
3752
3753 return port::Status::OK();
3754 }
3755
3756 port::Status CudnnSupport::DoFusedConvolve(
3757 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3758 const DeviceMemory<double>& conv_input_data, double conv_input_scale,
3759 const dnn::FilterDescriptor& filter_descriptor,
3760 const DeviceMemory<double>& filter_data,
3761 const dnn::ConvolutionDescriptor& convolution_descriptor,
3762 const DeviceMemory<double>& side_input_data, double side_input_scale,
3763 const dnn::BatchDescriptor& bias_descriptor,
3764 const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
3765 const dnn::BatchDescriptor& output_descriptor,
3766 DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
3767 const dnn::AlgorithmConfig& algorithm_config,
3768 dnn::ProfileResult* output_profile_result) {
3769 return DoFusedConvolveImpl(
3770 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3771 filter_descriptor, filter_data, convolution_descriptor, side_input_data,
3772 side_input_scale, bias_descriptor, biases, activation_mode,
3773 output_descriptor, output_data,
3774 GetConvAccumulatorType(dnn::DataType::kDouble), scratch_allocator,
3775 algorithm_config, output_profile_result);
3776 }
3777
3778 port::Status CudnnSupport::DoFusedConvolve(
3779 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3780 const DeviceMemory<float>& conv_input_data, float conv_input_scale,
3781 const dnn::FilterDescriptor& filter_descriptor,
3782 const DeviceMemory<float>& filter_data,
3783 const dnn::ConvolutionDescriptor& convolution_descriptor,
3784 const DeviceMemory<float>& side_input_data, float side_input_scale,
3785 const dnn::BatchDescriptor& bias_descriptor,
3786 const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3787 const dnn::BatchDescriptor& output_descriptor,
3788 DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
3789 const dnn::AlgorithmConfig& algorithm_config,
3790 dnn::ProfileResult* output_profile_result) {
3791 return DoFusedConvolveImpl(
3792 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3793 filter_descriptor, filter_data, convolution_descriptor, side_input_data,
3794 side_input_scale, bias_descriptor, biases, activation_mode,
3795 output_descriptor, output_data,
3796 GetConvAccumulatorType(dnn::DataType::kFloat), scratch_allocator,
3797 algorithm_config, output_profile_result);
3798 }
3799
3800 port::Status CudnnSupport::DoFusedConvolve(
3801 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3802 const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
3803 const dnn::FilterDescriptor& filter_descriptor,
3804 const DeviceMemory<Eigen::half>& filter_data,
3805 const dnn::ConvolutionDescriptor& convolution_descriptor,
3806 const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
3807 const dnn::BatchDescriptor& bias_descriptor,
3808 const DeviceMemory<Eigen::half>& biases,
3809 dnn::ActivationMode activation_mode,
3810 const dnn::BatchDescriptor& output_descriptor,
3811 DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
3812 const dnn::AlgorithmConfig& algorithm_config,
3813 dnn::ProfileResult* output_profile_result) {
3814 return DoFusedConvolveImpl(
3815 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3816 filter_descriptor, filter_data, convolution_descriptor, side_input_data,
3817 side_input_scale, bias_descriptor, biases, activation_mode,
3818 output_descriptor, output_data,
3819 GetConvAccumulatorType(dnn::DataType::kHalf), scratch_allocator,
3820 algorithm_config, output_profile_result);
3821 }
3822
3823 port::Status CudnnSupport::DoFusedConvolve(
3824 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3825 const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
3826 const dnn::FilterDescriptor& filter_descriptor,
3827 const DeviceMemory<int8>& filter_data,
3828 const dnn::ConvolutionDescriptor& convolution_descriptor,
3829 const DeviceMemory<int8>& side_input_data, float side_input_scale,
3830 const dnn::BatchDescriptor& bias_descriptor,
3831 const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3832 const dnn::BatchDescriptor& output_descriptor,
3833 DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
3834 const dnn::AlgorithmConfig& algorithm_config,
3835 dnn::ProfileResult* output_profile_result) {
3836 int cc_major, cc_minor;
3837 std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
3838
3839 if (cc_major < 6 || (cc_major == 6 && cc_minor < 1)) {
3840 return port::UnimplementedError(
3841 "cudnnConvolutionBiasActivationForward() for int8 is only supported on "
3842 "GPUs with compute capability 6.1 or later.");
3843 }
3844
3845 return DoFusedConvolveImpl(
3846 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3847 filter_descriptor, filter_data, convolution_descriptor, side_input_data,
3848 side_input_scale, bias_descriptor, biases, activation_mode,
3849 output_descriptor, output_data,
3850 GetConvAccumulatorType(dnn::DataType::kInt8), scratch_allocator,
3851 algorithm_config, output_profile_result);
3852 }
3853
3854 port::Status CudnnSupport::DoFusedConvolve(
3855 Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3856 const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
3857 const dnn::FilterDescriptor& filter_descriptor,
3858 const DeviceMemory<int8>& filter_data,
3859 const dnn::ConvolutionDescriptor& convolution_descriptor,
3860 const DeviceMemory<float>& side_input_data, float side_input_scale,
3861 const dnn::BatchDescriptor& bias_descriptor,
3862 const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3863 const dnn::BatchDescriptor& output_descriptor,
3864 DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
3865 const dnn::AlgorithmConfig& algorithm_config,
3866 dnn::ProfileResult* output_profile_result) {
3867 int cc_major, cc_minor;
3868 stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
3869 &cc_minor);
3870 if (cc_major < 6 || (cc_major == 6 && cc_minor < 1)) {
3871 return port::UnimplementedError(
3872 "cudnnConvolutionBiasActivationForward() for int8 is only supported on "
3873 "GPUs with compute capability 6.1 or later.");
3874 }
3875
3876 return DoFusedConvolveImpl(
3877 stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3878 filter_descriptor, filter_data, convolution_descriptor, side_input_data,
3879 side_input_scale, bias_descriptor, biases, activation_mode,
3880 output_descriptor, output_data,
3881 GetConvAccumulatorType(dnn::DataType::kInt8), scratch_allocator,
3882 algorithm_config, output_profile_result);
3883 }
3884
3885 port::Status CudnnSupport::DoPrepareForCtcLoss(
3886 Stream* stream, dnn::DataType element_type,
3887 const dnn::RnnStateTensorDescriptor& probs_desc,
3888 const dnn::RnnStateTensorDescriptor& grads_desc,
3889 absl::Span<const int> labels_data,
3890 absl::Span<const int> labels_lengths_data,
3891 absl::Span<const int> input_lengths_data,
3892 ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
3893 int* ctc_loss_algo_id) {
3894 auto cudnn = cudnn_->GetHandle(parent_, stream);
3895 // Query the workspace size.
3896 size_t workspace_size_in_bytes = 0;
3897 #if CUDNN_VERSION >= 7603
3898 CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
3899 const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
3900 static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
3901 const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
3902 static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
3903
3904   // Try the workspace-size query with `algo` first and use it if the query
3905   // succeeds. The non-deterministic algorithm is tried first, so it is
3906   // preferentially picked when determinism is not required.
3907 auto algo = RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
3908 : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC;
3909 cudnnStatus_t status = cudnnGetCTCLossWorkspaceSize(
3910 /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
3911 /*gradientsDesc=*/cudnn_grads_desc.handle(),
3912 /*labels=*/labels_data.data(),
3913 /*labelLengths=*/labels_lengths_data.data(),
3914 /*inputLengths=*/input_lengths_data.data(),
3915 /*algo=*/algo,
3916 /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
3917 /*sizeInBytes=*/&workspace_size_in_bytes);
3918 if (RequireCudnnDeterminism()) {
3919 RETURN_IF_CUDNN_ERROR(status);
3920 }
3921
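  // If the preferred (non-deterministic) algorithm was rejected and
  // determinism is not mandatory, retry the query with the deterministic
  // algorithm before giving up.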
3922 if (status != CUDNN_STATUS_SUCCESS) {
3923 algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
3924 RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
3925 /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
3926 /*gradientsDesc=*/cudnn_grads_desc.handle(),
3927 /*labels=*/labels_data.data(),
3928 /*labelLengths=*/labels_lengths_data.data(),
3929 /*inputLengths=*/input_lengths_data.data(),
3930 /*algo=*/algo,
3931 /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
3932 /*sizeInBytes=*/&workspace_size_in_bytes));
3933 }
3934 *ctc_loss_algo_id = algo;
3935 #else
3936 return port::Status(port::error::INVALID_ARGUMENT,
3937                       "cudnnGetCTCLossWorkspaceSize is not supported when "
3938 "CUDNN_VERSION < 7.6.3");
3939 #endif
3940 // Allocate the workspace.
3941 if (workspace_size_in_bytes == 0) {
3942 *scratch_memory = DeviceMemory<uint8>();
3943 return port::Status::OK();
3944 }
3945 const auto scratch_or =
3946 scratch_allocator->AllocateBytes(workspace_size_in_bytes);
3947 if (scratch_or.ok()) {
3948 *scratch_memory = scratch_or.ValueOrDie();
3949 return port::Status::OK();
3950 }
3951 return port::InternalError(
3952 "Failed to allocate scratch memory for the CuDNN CTC Loss");
3953 }
3954
3955 port::Status CudnnSupport::DoCtcLoss(
3956 Stream* stream, dnn::DataType element_type,
3957 const dnn::RnnStateTensorDescriptor& probs_desc,
3958 const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
3959 absl::Span<const int> labels_lengths_data,
3960 absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
3961 const dnn::RnnStateTensorDescriptor& grads_desc,
3962 DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
3963 int ctc_loss_algo_id) {
3964   // The current cuDNN CTC loss only supports the float data type.
3965   if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
3966     return port::Status(port::error::INVALID_ARGUMENT,
3967                         "CudnnCtcLossDescriptor is supported only when "
3968                         "CUDNN_VERSION >= 7.6.3 and the data type is float");
3969 }
3970 CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
3971 const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
3972 static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
3973 const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
3974 static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
3975 return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
3976 labels_lengths_data, input_lengths_data, costs_data,
3977 cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
3978 scratch_memory, ctc_loss_algo_id);
3979 }
3980
3981 bool CudnnSupport::DoTransformTensor(Stream* stream,
3982 const dnn::BatchDescriptor& input_desc,
3983 dnn::DataType input_type,
3984 const DeviceMemoryBase& input_data,
3985 const dnn::BatchDescriptor& output_desc,
3986 dnn::DataType output_type, float scale,
3987 DeviceMemoryBase* output_data) {
3988 float beta = 0.0f;
3989 CudnnTensorDescriptor input_tensor_desc(
3990 input_desc, ToCudnnDataType(input_type, input_desc.layout()));
3991 CudnnTensorDescriptor output_tensor_desc(
3992 output_desc, ToCudnnDataType(output_type, output_desc.layout()));
3993 auto cudnn = cudnn_->GetHandle(parent_, stream);
3994 const auto status = [&] {
3995 RETURN_IF_CUDNN_ERROR(cudnnTransformTensor(
3996 cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
3997 &beta, output_tensor_desc.handle(), output_data->opaque()));
3998 return port::Status::OK();
3999 }();
4000 return IsStatusOk(status, /*report_error=*/true);
4001 }
4002
4003 template <class T>
4004 port::Status CudnnSupport::DoConvolveBackwardBiasImpl(
4005 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4006 const DeviceMemory<T>& input_data,
4007 const dnn::BatchDescriptor& bias_descriptor,
4008 DeviceMemory<T>* backward_bias_data) {
4009 cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
4010 CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
4011 CudnnTensorDescriptor bias_nd(bias_descriptor, cudnn_type);
4012
4013 // Alpha is the scaling factor for input.
4014 float alpha = 1.0;
4015 // Beta is the scaling factor for output.
4016 float beta = 0.0;
4017
4018 auto cudnn = cudnn_->GetHandle(parent_, stream);
4019 RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardBias(
4020 cudnn.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta,
4021 bias_nd.handle(), backward_bias_data->opaque()));
4022 return port::Status::OK();
4023 }
4024
4025 bool CudnnSupport::DoConvolveBackwardBias(
4026 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4027 const DeviceMemory<double>& input_data,
4028 const dnn::BatchDescriptor& bias_descriptor,
4029 DeviceMemory<double>* backward_bias_data) {
4030 return IsStatusOk(
4031 DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
4032 bias_descriptor, backward_bias_data),
4033 /*report_error=*/true);
4034 }
4035
4036 bool CudnnSupport::DoConvolveBackwardBias(
4037 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4038 const DeviceMemory<float>& input_data,
4039 const dnn::BatchDescriptor& bias_descriptor,
4040 DeviceMemory<float>* backward_bias_data) {
4041 return IsStatusOk(
4042 DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
4043 bias_descriptor, backward_bias_data),
4044 /*report_error=*/true);
4045 }
4046
4047 bool CudnnSupport::DoConvolveBackwardBias(
4048 Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4049 const DeviceMemory<Eigen::half>& input_data,
4050 const dnn::BatchDescriptor& bias_descriptor,
4051 DeviceMemory<Eigen::half>* backward_bias_data) {
4052 return IsStatusOk(
4053 DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
4054 bias_descriptor, backward_bias_data),
4055 /*report_error=*/true);
4056 }
4057
4058 bool CudnnSupport::DoMatMul(Stream* stream,
4059 const DeviceMemory<float>& input_data,
4060 const DeviceMemory<float>& weights,
4061 const dnn::BatchDescriptor& input_dimensions,
4062 const dnn::BatchDescriptor& output_dimensions,
4063 DeviceMemory<float>* output_data) {
4064 if (input_dimensions.count() != output_dimensions.count()) {
4065 LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
4066 return false;
4067 }
4068
4069   // We do not permute the input or output; instead, we just
4070 // reinterpret the layout. We are working with row-major matrices
4071 // and the rows of the input and output correspond to batch, so
4072 // batch has to be outermost in both the input and output.
4073 //
4074 // By adding transposes to the BLAS gemm call we could perhaps make
4075 // the kYXDepthBatch layout work as well, but there has been no need
4076 // for that so far.
4077 if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
4078 input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4079 LOG(ERROR) << "Unsupported MatMul input layout.";
4080 return false;
4081 }
4082 if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
4083 output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4084 LOG(ERROR) << "Unsupported MatMul output layout.";
4085 return false;
4086 }
4087
4088 if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
4089 // This is a fast path that also supports the kBatchYXDepth layout.
4090
4091 // The matrices here are in row-major format while BLAS expects
4092 // column-major, i.e. our matrices are transposed as far as BLAS
4093 // is concerned. So we need to compute output^T =
4094 // input^T*weights^T. There is no parameter for transposing the
4095 // output in BLAS gemm, but instead we can transpose both sides of
4096 // the equality to see that this is equivalent to
4097 // output=weights*input. So we only need to swap the order of
4098 // weights and input in the matrix product to correct for the
4099 // row-major versus column-major difference.
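    // For example (hypothetical sizes): with a batch of 32, 1024 input nodes
    // and 10 output nodes, this issues one GEMM with m = 10, n = 32, k = 1024,
    // computing weights * input as a 10 x 32 column-major product, which is
    // exactly the 32 x 10 row-major output.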
4100 const float alpha = 1.0f; // Take the matrix product without scaling it.
4101 const float beta = 0.0f; // Ignore the original values in output_data.
4102 const int64 m = output_dimensions.NodesAcrossFeatureMaps();
4103 const int64 n = input_dimensions.count();
4104 const int64 k = input_dimensions.NodesAcrossFeatureMaps();
4105 stream->ThenBlasGemm(blas::Transpose::kNoTranspose,
4106 blas::Transpose::kNoTranspose, m, n, k, alpha, weights,
4107 m, input_data, k, beta, output_data, m);
4108 } else {
4109 // This is a slower and more complex path that supports output
4110 // width() * height() > 1, though it only supports the
4111     // kBatchYXDepth layout. It does support kBatchDepthYX if output
4112 // feature_map_count() == 1, as then there is no difference
4113 // between the two layouts.
4114 //
4115 // The operation here is the same as above, except that we have to
4116 // do the matrix multiplication for each (y,x) output coordinate
4117 // separately. We then interpret weights as containing K = width()
4118 // * height() different matrices, which we all multiply onto the
4119 // matrix from input_data, yielding K matrix products. We then
4120 // combine these together into one matrix by concatenating all the
4121     // first rows of these matrices, then all the second rows and so
4122 // on. We can do this with a batched matrix multiplication, where
4123 // the result is written to a different submatrix of the output
4124 // for each matrix multiplication.
4125 //
4126 // The reason that we only support the kBatchYXDepth output layout
4127 // is that we have to do something in the depth for each (y,x)
4128 // coordinate. The kBatchYXDepth layout has the depth information
4129 // for each point (y,x) in contiguous memory while the
4130 // kBatchDepthYX layout does not.
4131 //
4132 // TODO(broune): Consider a special case for when output depth ==
4133 // 1, as then possibly this could all be done as one matrix
4134 // multiplication instead of a batched one, which should be
4135 // faster. Another possibility would be to add a weights layout
4136 // parameter and then support kBatchDepthYX for a different
4137 // weights layout.
4138 if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
4139 !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
4140 output_dimensions.feature_map_count() == 1)) {
4141 LOG(ERROR) << "Unsupported MatMul output layout.";
4142 return false;
4143 }
4144
4145 const float alpha = 1.0f; // Take the matrix product without scaling it.
4146 const float beta = 0.0f; // Ignore the original values in output_data.
4147 const uint64 m = output_dimensions.feature_map_count();
4148 const uint64 n = input_dimensions.count();
4149 const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
4150 const int lda = m;
4151 const int ldb = k;
4152 const int ldc = output_dimensions.NodesAcrossFeatureMaps();
4153 const int batch_count = output_dimensions.NodesPerFeatureMap();
4154
4155 std::vector<DeviceMemory<float>> a(batch_count);
4156 std::vector<DeviceMemory<float>> b(batch_count);
4157 std::vector<DeviceMemory<float>> c(batch_count);
4158 for (int i = 0; i < batch_count; ++i) {
4159 const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
4160 output_dimensions.feature_map_count();
4161 a[i] = DeviceMemory<float>::MakeFromByteSize(
4162 const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
4163 weights_offset,
4164 weights.ElementCount() - weights_offset);
4165
4166 b[i] = input_data;
4167
4168 const int output_offset = i * output_dimensions.feature_map_count();
4169 c[i] = DeviceMemory<float>::MakeFromByteSize(
4170 const_cast<float*>(
4171 reinterpret_cast<const float*>(output_data->opaque())) +
4172 output_offset,
4173 output_data->ElementCount() - output_offset);
4174 }
4175 const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
4176 std::vector<DeviceMemory<float>*> ptrs;
4177 ptrs.reserve(v.size());
4178 for (auto& mem : v) {
4179 ptrs.push_back(&mem);
4180 }
4181 return ptrs;
4182 };
4183
4184 stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
4185 blas::Transpose::kNoTranspose, m, n, k, alpha,
4186 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
4187 ldc, batch_count);
4188 }
4189
4190 return stream->ok();
4191 }
4192
4193 bool CudnnSupport::DoBiasAdd(Stream* stream,
4194 const DeviceMemory<float>& input_data,
4195 const DeviceMemory<float>& biases,
4196 const dnn::BatchDescriptor& dimensions,
4197 DeviceMemory<float>* output_data) {
4198 CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
4199
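  // The bias is described as a single 1x1 "image" with one value per feature
  // map; cuDNN broadcasts dimensions of size one in cudnnAddTensor, so the
  // bias is added across the batch and spatial dimensions.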
4200 dnn::BatchDescriptor bias_dimensions;
4201 bias_dimensions.set_count(1)
4202 .set_feature_map_count(dimensions.feature_map_count())
4203 .set_height(1)
4204 .set_width(1)
4205 .set_layout(dnn::DataLayout::kBatchYXDepth);
4206 CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
4207
4208 // cudnnAddTensor after R3 is in-place, so we need to copy input_data to
4209 // output_data before doing the addition, unless the input and
4210 // output are at the same address.
4211 if (input_data.opaque() != output_data->opaque()) {
4212 stream->ThenMemcpy(output_data, input_data,
4213 dimensions.ElementCount() * sizeof(float));
4214 if (!stream->ok()) {
4215 LOG(ERROR)
4216 << "stream " << stream
4217 << " could not enqueue a tensor copy as part of bias addition.";
4218 return false;
4219 }
4220 }
4221
4222 const float alpha = 1.0f;
4223 const float beta = 1.0f;
4224
4225 auto cudnn = cudnn_->GetHandle(parent_, stream);
4226
4227 const auto status = [&] {
4228 RETURN_IF_CUDNN_ERROR(cudnnAddTensor(
4229 cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(),
4230 &beta, input_descriptor.handle(), output_data->opaque()));
4231 return port::Status::OK();
4232 }();
4233 return IsStatusOk(status, /*report_error=*/true);
4234 }
4235
4236 bool CudnnSupport::DoActivate(Stream* stream,
4237 dnn::ActivationMode activation_mode,
4238 const dnn::BatchDescriptor& dimensions,
4239 const DeviceMemory<float>& input_data,
4240 DeviceMemory<float>* output_data,
4241 uint64 options) {
4242 CudnnActivationDescriptor activation_desc(
4243 activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
4244
4245 CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
4246 // Alpha is the input scaling factor.
4247 float alpha = 1.0;
4248 // Beta is the output scaling factor.
4249 float beta = 0.0;
4250
4251 auto cudnn = cudnn_->GetHandle(parent_, stream);
4252 const auto status = [&] {
4253 RETURN_IF_CUDNN_ERROR(cudnnActivationForward(
4254 cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
4255 input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()));
4256 return port::Status::OK();
4257 }();
4258 return IsStatusOk(status, /*report_error=*/true);
4259 }
4260
4261 bool CudnnSupport::DoPoolForward(
4262 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4263 const dnn::BatchDescriptor& input_dimensions,
4264 const DeviceMemory<double>& input_data,
4265 const dnn::BatchDescriptor& output_dimensions,
4266 DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
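  // Note: workspace_allocator is unused here; cudnnPoolingForward does not
  // take a scratch buffer.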
4267 // Alpha is the scaling factor for input.
4268 double alpha = 1.0;
4269 // Beta is the scaling factor for output.
4270 double beta = 0.0;
4271
4272 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
4273 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
4274 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4275
4276 auto cudnn = cudnn_->GetHandle(parent_, stream);
4277 const auto status = [&] {
4278 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
4279 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4280 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
4281 return port::Status::OK();
4282 }();
4283 return IsStatusOk(status, /*report_error=*/true);
4284 }
4285
4286 bool CudnnSupport::DoPoolForward(
4287 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4288 const dnn::BatchDescriptor& input_dimensions,
4289 const DeviceMemory<float>& input_data,
4290 const dnn::BatchDescriptor& output_dimensions,
4291 DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
4292 // Alpha is the scaling factor for input.
4293 float alpha = 1.0;
4294 // Beta is the scaling factor for output.
4295 float beta = 0.0;
4296
4297 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
4298 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
4299 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4300
4301 auto cudnn = cudnn_->GetHandle(parent_, stream);
4302 const auto status = [&] {
4303 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
4304 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4305 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
4306 return port::Status::OK();
4307 }();
4308 return IsStatusOk(status, /*report_error=*/true);
4309 }
4310
4311 bool CudnnSupport::DoPoolForward(
4312 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4313 const dnn::BatchDescriptor& input_dimensions,
4314 const DeviceMemory<Eigen::half>& input_data,
4315 const dnn::BatchDescriptor& output_dimensions,
4316 DeviceMemory<Eigen::half>* output_data,
4317 ScratchAllocator* workspace_allocator) {
4318 // Alpha is the scaling factor for input.
4319 float alpha = 1.0;
4320 // Beta is the scaling factor for output.
4321 float beta = 0.0;
4322
4323 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
4324 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
4325 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4326 auto cudnn = cudnn_->GetHandle(parent_, stream);
4327 const auto status = [&] {
4328 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
4329 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4330 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
4331 return port::Status::OK();
4332 }();
4333 return IsStatusOk(status, /*report_error=*/true);
4334 }
4335
4336 bool CudnnSupport::DoPoolForward(
4337 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4338 const dnn::BatchDescriptor& input_dimensions,
4339 const DeviceMemory<int8>& input_data,
4340 const dnn::BatchDescriptor& output_dimensions,
4341 DeviceMemory<int8>* output_data, ScratchAllocator* workspace_allocator) {
4342 // Alpha is the scaling factor for input.
4343 float alpha = 1.0;
4344 // Beta is the scaling factor for output.
4345 float beta = 0.0;
4346
4347 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_INT8);
4348 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_INT8);
4349 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4350
4351 auto cudnn = cudnn_->GetHandle(parent_, stream);
4352 const auto status = [&] {
4353 RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
4354 cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4355 input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
4356 return port::Status::OK();
4357 }();
4358 return IsStatusOk(status, /*report_error=*/true);
4359 }
4360
4361 bool CudnnSupport::DoPoolBackward(
4362 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4363 const dnn::BatchDescriptor& input_dimensions,
4364 const DeviceMemory<double>& input_data,
4365 const dnn::BatchDescriptor& output_dimensions,
4366 const DeviceMemory<double>& output_data,
4367 const DeviceMemory<double>& input_diff_data,
4368 DeviceMemory<double>* output_diff_data,
4369 ScratchAllocator* workspace_allocator) {
4370 // Alpha is the scaling factor for input.
4371 double alpha = 1.0;
4372 // Beta is the scaling factor for output.
4373 double beta = 0.0;
4374
4375 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
4376 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
4377 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4378
4379 auto cudnn = cudnn_->GetHandle(parent_, stream);
4380 const auto status = [&] {
4381 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
4382 cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
4383 output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
4384 src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
4385 output_diff_data->opaque()));
4386 return port::Status::OK();
4387 }();
4388 return IsStatusOk(status, /*report_error=*/true);
4389 }
4390
4391 bool CudnnSupport::DoPoolBackward(
4392 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4393 const dnn::BatchDescriptor& input_dimensions,
4394 const DeviceMemory<float>& input_data,
4395 const dnn::BatchDescriptor& output_dimensions,
4396 const DeviceMemory<float>& output_data,
4397 const DeviceMemory<float>& input_diff_data,
4398 DeviceMemory<float>* output_diff_data,
4399 ScratchAllocator* workspace_allocator) {
4400 // Alpha is the scaling factor for input.
4401 float alpha = 1.0;
4402 // Beta is the scaling factor for output.
4403 float beta = 0.0;
4404
4405 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
4406 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
4407 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4408
4409 auto cudnn = cudnn_->GetHandle(parent_, stream);
4410 const auto status = [&] {
4411 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
4412 cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
4413 output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
4414 src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
4415 output_diff_data->opaque()));
4416 return port::Status::OK();
4417 }();
4418 return IsStatusOk(status, /*report_error=*/true);
4419 }
4420
4421 bool CudnnSupport::DoPoolBackward(
4422 Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4423 const dnn::BatchDescriptor& input_dimensions,
4424 const DeviceMemory<Eigen::half>& input_data,
4425 const dnn::BatchDescriptor& output_dimensions,
4426 const DeviceMemory<Eigen::half>& output_data,
4427 const DeviceMemory<Eigen::half>& input_diff_data,
4428 DeviceMemory<Eigen::half>* output_diff_data,
4429 ScratchAllocator* workspace_allocator) {
4430 // Alpha is the scaling factor for input.
4431 float alpha = 1.0;
4432 // Beta is the scaling factor for output.
4433 float beta = 0.0;
4434
4435 CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
4436 CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
4437 CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
4438
4439 auto cudnn = cudnn_->GetHandle(parent_, stream);
4440 const auto status = [&] {
4441 RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
4442 cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
4443 output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
4444 src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
4445 output_diff_data->opaque()));
4446 return port::Status::OK();
4447 }();
4448 return IsStatusOk(status, /*report_error=*/true);
4449 }
4450
4451 bool CudnnSupport::DoNormalizeWithDimensions(
4452 Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
4453 const dnn::BatchDescriptor& dimensions,
4454 const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
4455 // Check for unsupported modes.
4456 if (normalize_descriptor.wrap_around()) {
4457     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
4458 return false;
4459 }
4460 if (normalize_descriptor.segment_size()) {
4461 LOG(ERROR) << "CUDA LRN does not support segmentation";
4462 return false;
4463 }
4464
4465 CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
4466 CudnnNormalizeDescriptor normalize(normalize_descriptor);
4467
4468 // Alpha is the scaling factor for input.
4469 float alpha = 1.0f;
4470 // Beta is the scaling factor for output.
4471 float beta = 0.0f;
4472
4473 auto cudnn = cudnn_->GetHandle(parent_, stream);
4474
4475 // Launch the normalization.
4476 const auto status = [&] {
4477 RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward(
4478 cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
4479 &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(),
4480 output_data->opaque()));
4481 return port::Status::OK();
4482 }();
4483 return IsStatusOk(status, /*report_error=*/true);
4484 }
4485
4486 bool CudnnSupport::DoNormalizeBackwardWithDimensions(
4487 Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
4488 const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
4489 const DeviceMemory<float>& normalized_data,
4490 const DeviceMemory<float>& normalized_variable_gradient,
4491 DeviceMemory<float>* raw_variable_gradient,
4492 ScratchAllocator* workspace_allocator) {
4493 // Check for unsupported modes.
4494 if (normalize_descriptor.wrap_around()) {
4495     LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
4496 return false;
4497 }
4498 if (normalize_descriptor.segment_size()) {
4499 LOG(ERROR) << "CUDA LRN does not support segmentation";
4500 return false;
4501 }
4502
4503 CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
4504 CudnnNormalizeDescriptor normalize(normalize_descriptor);
4505
4506 float alpha = 1.0f;
4507 float beta = 0.0f;
4508
4509 auto cudnn = cudnn_->GetHandle(parent_, stream);
4510 const auto status = [&] {
4511 RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward(
4512 cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
4513 &alpha, dims.handle(), normalized_data.opaque(), dims.handle(),
4514 normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
4515 &beta, dims.handle(), raw_variable_gradient->opaque()));
4516 return port::Status::OK();
4517 }();
4518 return IsStatusOk(status, /*report_error=*/true);
4519 }
4520
4521 bool CudnnSupport::DoDepthConcatenate(
4522 Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
4523 port::ArraySlice<const DeviceMemory<float>*> input_data,
4524 DeviceMemory<float>* output_data) {
4525 CHECK_EQ(input_dimensions.size(), input_data.size());
4526
4527 for (const auto& dimensions : input_dimensions) {
4528 if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4529 LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only "
4530 "supports the kBatchDepthYX layout.";
4531 return false;
4532 }
4533 }
4534
4535 if (input_dimensions.empty()) {
4536 return true; // Nothing to do.
4537 }
4538
4539 dnn::BatchDescriptor output_dimensions =
4540 dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
4541
4542 const int64 area = output_dimensions.width() * output_dimensions.height();
4543 const auto index = [area](int64 batch, int64 depth, int64 yx,
4544 int64 max_depth) {
4545 return (batch * max_depth + depth) * area + yx;
4546 };
4547
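  // The concatenation itself is performed on the host: each input is copied
  // device-to-host, written into the appropriate depth slice of output_host,
  // and the assembled buffer is copied back to the device in one transfer.
  // Simple, but it synchronizes the stream once per input.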
4548 std::vector<float> output_host(output_dimensions.ElementCount());
4549 std::vector<float> tmp;
4550 int64 depth_sum = 0;
4551 for (size_t i = 0; i < input_data.size(); ++i) {
4552 const auto& dimensions = input_dimensions[i];
4553 tmp.resize(dimensions.ElementCount());
4554 stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
4555 port::Status block_status = stream->BlockHostUntilDone();
4556 if (!block_status.ok()) {
4557 LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
4558 return false;
4559 }
4560
4561 for (int64 batch = 0; batch < output_dimensions.count(); ++batch) {
4562 for (int64 yx = 0; yx < area; ++yx) {
4563 for (int64 depth = 0; depth < dimensions.feature_map_count(); ++depth) {
4564 LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' '
4565 << yx << ' ' << depth;
4566 output_host[index(batch, depth + depth_sum, yx,
4567 output_dimensions.feature_map_count())] =
4568 tmp[index(batch, depth, yx, dimensions.feature_map_count())];
4569 }
4570 }
4571 }
4572 depth_sum += dimensions.feature_map_count();
4573 }
4574 stream->ThenMemcpyH2D<float>(output_host, output_data);
4575 return true;
4576 }
4577
4578 bool CudnnSupport::DoElementwiseOperate(
4579 Stream* stream, dnn::ElementwiseOperation operation,
4580 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
4581 port::ArraySlice<const DeviceMemory<float>*> input_data,
4582 const dnn::BatchDescriptor& output_dimensions,
4583 DeviceMemory<float>* output_data) {
4584 LOG(FATAL) << "not yet implemented"; // TODO(leary)
4585 return false;
4586 }
4587
4588 bool CudnnSupport::DoXYPad(Stream* stream,
4589 const dnn::BatchDescriptor& dimensions,
4590 const DeviceMemory<float>& input_data,
4591 int64 left_pad, int64 right_pad, int64 top_pad,
4592 int64 bottom_pad, DeviceMemory<float>* output_data) {
4593 LOG(FATAL) << "not yet implemented"; // TODO(leary)
4594 return false;
4595 }
4596
bool CudnnSupport::DoXYSlice(Stream* stream,
                             const dnn::BatchDescriptor& dimensions,
                             const DeviceMemory<float>& input_data,
                             int64 left_trim, int64 right_trim, int64 top_trim,
                             int64 bottom_trim,
                             DeviceMemory<float>* output_data) {
  LOG(FATAL) << "not yet implemented";  // TODO(leary)
  return false;
}

bool CudnnSupport::DoMemcpyD2HQuantized(
    Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
    dnn::QuantizedActivationMode mode, void* host_dst, int64 size) {
  LOG(ERROR) << "quantized memcpy not supported by cuDNN";
  return false;
}

bool CudnnSupport::DoMemcpyH2DQuantized(
    Stream* stream, const void* host_src, int64 size,
    dnn::QuantizedActivationMode mode,
    DeviceMemory<float>* gpu_unquantized_dst) {
  LOG(ERROR) << "quantized memcpy not supported by cuDNN";
  return false;
}

bool CudnnSupport::DeriveOutputBatchDescriptor(
    const dnn::BatchDescriptor& batch_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    dnn::BatchDescriptor* output_batch_descriptor) {
  CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
  CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
  CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);

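  // cuDNN reports one output dimension per spatial dimension plus two more for
  // the batch and feature-map dimensions.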
  int dn = batch_descriptor.ndims() + 2;
  std::vector<int> dims(dn);  // in BDYX
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
        conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()));
    output_batch_descriptor->set_count(dims[0])
        .set_feature_map_count(dims[1])
        .set_layout(batch_descriptor.layout());

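    // dims lists batch, feature maps, then spatial dimensions outermost first,
    // while dnn::DimIndex addresses spatial dimensions starting from the
    // innermost one, hence the reversed read below.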
    for (int i = 0; i < batch_descriptor.ndims(); i++) {
      output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
                                               dims.rbegin()[i]);
    }
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

}  // namespace gpu

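// Registers a factory that builds a CudnnSupport instance for each CUDA
// StreamExecutor, then makes cuDNN the default DNN plugin for the CUDA
// platform.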
void initialize_cudnn() {
  port::Status status =
      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
          cuda::kCudaPlatformId, gpu::kCuDnnPlugin, "cuDNN",
          [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
            gpu::GpuExecutor* cuda_executor =
                dynamic_cast<gpu::GpuExecutor*>(parent);
            if (cuda_executor == nullptr) {
              LOG(ERROR) << "Attempting to initialize an instance of the cuDNN "
                         << "support library with a non-CUDA StreamExecutor";
              return nullptr;
            }

            gpu::CudnnSupport* dnn = new gpu::CudnnSupport(cuda_executor);
            if (!dnn->Init().ok()) {
              // Note: Init() will log a more specific error.
              delete dnn;
              return nullptr;
            }
            return dnn;
          });

  if (!status.ok()) {
    LOG(ERROR) << "Unable to register cuDNN factory: "
               << status.error_message();
  }

  PluginRegistry::Instance()->SetDefaultFactory(
      cuda::kCudaPlatformId, PluginKind::kDnn, gpu::kCuDnnPlugin);
}

}  // namespace stream_executor

#pragma clang diagnostic pop

REGISTER_MODULE_INITIALIZER(register_cudnn,
                            { stream_executor::initialize_cudnn(); });