1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/cuda/cuda_dnn.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <utility>
21 
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "third_party/eigen3/Eigen/Core"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/platform/tensor_float_32_utils.h"
27 #include "tensorflow/core/util/env_var.h"
28 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
29 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
30 #include "tensorflow/stream_executor/cuda/cuda_driver.h"
31 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
32 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
33 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
34 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
35 #include "tensorflow/stream_executor/cuda/cudnn_version.h"
36 #include "tensorflow/stream_executor/dnn.h"
37 #include "tensorflow/stream_executor/lib/env.h"
38 #include "tensorflow/stream_executor/lib/error.h"
39 #include "tensorflow/stream_executor/lib/initialize.h"
40 #include "tensorflow/stream_executor/lib/mathutil.h"
41 #include "tensorflow/stream_executor/lib/threadpool.h"
42 #include "tensorflow/stream_executor/platform/logging.h"
43 #include "tensorflow/stream_executor/plugin_registry.h"
44 #include "tensorflow/stream_executor/scratch_allocator.h"
45 #include "tensorflow/stream_executor/stream.h"
46 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
47 // clang-format off
48 #include "third_party/gpus/cudnn/cudnn.h"
49 #include "absl/strings/string_view.h"
50 // clang-format on
51 
52 #pragma clang diagnostic push
53 
54 // Make sure that Eigen::half forward declaration in dnn.h matches the
55 // declaration in Eigen.
56 #pragma clang diagnostic warning "-Wmismatched-tags"
57 
58 namespace stream_executor {
59 namespace gpu {
60 
61 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
62 
63 namespace {
64 
65 static_assert(CUDNN_VERSION >= 7300, "cuDNN needs to be version 7.3 or higher");
66 
67 // Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS.
68 #define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS)
69 
70 // If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current
71 // function with a non-successful port::Status.
72 #define RETURN_IF_CUDNN_ERROR(expr)                                      \
73   do {                                                                   \
74     cudnnStatus_t _status = expr;                                        \
75     if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) {             \
76       std::ostringstream oss;                                            \
77       oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \
78           << "): '" << #expr << "'";                                     \
79       return port::Status(port::error::UNKNOWN, oss.str().c_str());      \
80     }                                                                    \
81   } while (false)
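// Illustrative (non-compiled) sketch of how the two macros above are used:
// CHECK_CUDNN_OK aborts on failure, while RETURN_IF_CUDNN_ERROR propagates the
// failure from a function returning port::Status. The function name below is
// hypothetical:
//
//   port::Status SetStreamOrFail(cudnnHandle_t handle, cudaStream_t stream) {
//     RETURN_IF_CUDNN_ERROR(cudnnSetStream(handle, stream));
//     return port::Status::OK();
//   }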
82 
83 // Converts (via narrowing) a value of type WideT to type NarrowT, and checks
84 // that the value is unchanged by the conversion.
85 template <typename WideT, typename NarrowT>
86 NarrowT CheckedNarrowing(const WideT& wide) {
87   NarrowT narrow = wide;
88   CHECK_EQ(narrow, wide)
89       << "checked narrowing failed; values not equal post-conversion";
90   return narrow;
91 }
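// For example, CheckedNarrowing<int64, int>(5) returns 5, whereas narrowing a
// value that does not fit into an int (e.g. int64{1} << 40) would CHECK-fail.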
92 
93 std::string ToString(cudnnStatus_t status) {
94   switch (status) {
95     case CUDNN_STATUS_SUCCESS:
96       return "CUDNN_STATUS_SUCCESS";
97     case CUDNN_STATUS_NOT_INITIALIZED:
98       return "CUDNN_STATUS_NOT_INITIALIZED";
99     case CUDNN_STATUS_ALLOC_FAILED:
100       return "CUDNN_STATUS_ALLOC_FAILED";
101     case CUDNN_STATUS_BAD_PARAM:
102       return "CUDNN_STATUS_BAD_PARAM";
103     case CUDNN_STATUS_INTERNAL_ERROR:
104       return "CUDNN_STATUS_INTERNAL_ERROR";
105     case CUDNN_STATUS_INVALID_VALUE:
106       return "CUDNN_STATUS_INVALID_VALUE";
107     case CUDNN_STATUS_ARCH_MISMATCH:
108       return "CUDNN_STATUS_ARCH_MISMATCH";
109     case CUDNN_STATUS_MAPPING_ERROR:
110       return "CUDNN_STATUS_MAPPING_ERROR";
111     case CUDNN_STATUS_EXECUTION_FAILED:
112       return "CUDNN_STATUS_EXECUTION_FAILED";
113     case CUDNN_STATUS_NOT_SUPPORTED:
114       return "CUDNN_STATUS_NOT_SUPPORTED";
115     case CUDNN_STATUS_LICENSE_ERROR:
116       return "CUDNN_STATUS_LICENSE_ERROR";
117     case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
118       return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
119     case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
120       return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
121     case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
122       return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
123     default:
124       return absl::StrCat("<unknown cudnn status: ", static_cast<int>(status),
125                           ">");
126   }
127 }
128 
129 // RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
130 //
131 // See CudnnAccess::GetHandle() for details.
132 class CudnnHandle {
133  public:
134   // Takes ownership of the executor context and of the lock that guards
135   // access to cuDNN through the handle.
136   CudnnHandle(gpu::ScopedActivateExecutorContext context,
137               std::unique_ptr<absl::MutexLock> lock, cudnnHandle_t handle)
138       : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
139 
140   // Returns the cuDNN handle. Pass it directly to cuDNN APIs; don't keep
141   // a copy.
142   cudnnHandle_t handle() const { return handle_; }
143 
144  private:
145   gpu::ScopedActivateExecutorContext context_;
146   std::unique_ptr<absl::MutexLock> lock_;
147   cudnnHandle_t handle_;  // Not owned.
148 };
149 
150 }  // namespace
151 
152 // Wraps a cuDNN handle and provides access to it through CudnnHandle
153 // instances. Each CudnnHandle also locks a mutex, acquires the CUDA context,
154 // and sets the stream that cuDNN should use to enqueue any work.
155 //
156 // Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
157 class CudnnAccess {
158  public:
159   // Takes ownership of the handle.
160   explicit CudnnAccess(cudnnHandle_t handle) : handle_(handle) {}
161 
162   ~CudnnAccess() {
163     absl::MutexLock lock(&mutex_);
164     cudnnDestroy(handle_);
165   }
166 
167   // Creates a CudnnHandle instance for stream.
168   //
169   // cuDNN API calls using the same handle instance need to be serialized
170   // across threads. This is guaranteed by CudnnHandle instances locking the
171   // mutex owned by this class.
172   //
173   // Most cuDNN APIs taking a handle perform work on a CUDA stream. The
174   // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
175   // to use the provided stream.
176   //
177   // The stream argument may be null, which translates to the legacy default
178   // stream. See
179   // https://docs.nvidia.com/cuda/cuda-driver-api/stream-sync-behavior.html.
180   // The legacy default stream synchronizes with all other streams, so it is
181   // a bad idea (performance-wise) to call any cuDNN APIs that enqueue work
182   // in that stream.
183   CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) {
184     auto lock = absl::make_unique<absl::MutexLock>(&mutex_);
185     mutex_.AssertHeld();
186     gpu::ScopedActivateExecutorContext context(executor);
187     CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy;
188     const auto status = cudnnSetStream(handle_, cu_stream);
189     CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
190     return CudnnHandle(std::move(context), std::move(lock), handle_);
191   }
192 
193  private:
194   // Guards the enqueueing of cuDNN operations via the handle_ below.
195   absl::Mutex mutex_;
196 
197   // cuDNN library handle.
198   cudnnHandle_t handle_ TF_GUARDED_BY(mutex_);  // Owned.
199 };
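// A minimal usage sketch (the surrounding context is hypothetical, and
// cudnnSomeQuery below is a placeholder): callers obtain a scoped CudnnHandle
// from the CudnnAccess instance and pass its raw handle to cuDNN while the
// lock is held, e.g.
//
//   auto cudnn = cudnn_->GetHandle(parent_, stream);
//   RETURN_IF_CUDNN_ERROR(cudnnSomeQuery(cudnn.handle(), ...));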
200 
201 namespace {
202 
203 // A helper function to return the internal compute type for
204 // RNNs in cudnn.
205 cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
206 
207 cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
208   cudnnConvolutionFwdAlgo_t algo =
209       cudnnConvolutionFwdAlgo_t(algorithm.algo_id());
210   switch (algo) {
211     case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
212     case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
213     case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
214     case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
215     case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
216     case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
217     case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
218     case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
219       return algo;
220     default:
221       LOG(FATAL) << "Unsupported Cudnn convolution forward algorithm: "
222                  << algorithm.algo_id();
223   }
224 }
225 
226 cudnnConvolutionBwdDataAlgo_t ToConvBackwardDataAlgo(
227     dnn::AlgorithmDesc algorithm) {
228   cudnnConvolutionBwdDataAlgo_t algo =
229       cudnnConvolutionBwdDataAlgo_t(algorithm.algo_id());
230   switch (algo) {
231     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
232     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
233     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
234     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
235     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
236     case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
237       return algo;
238     default:
239       LOG(FATAL)
240           << "Unsupported Cudnn convolution backward algorithm for data: "
241           << algorithm.algo_id();
242   }
243 }
244 
245 cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
246     dnn::AlgorithmDesc algorithm) {
247   cudnnConvolutionBwdFilterAlgo_t algo =
248       cudnnConvolutionBwdFilterAlgo_t(algorithm.algo_id());
249   switch (algo) {
250     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
251     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
252     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
253     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
254     // Based on cudnn.h, the following is not implemented.
255     // case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD:
256     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
257     case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
258       return algo;
259     default:
260       LOG(FATAL)
261           << "Unsupported Cudnn convolution backward algorithm for filter: "
262           << algorithm.algo_id();
263   }
264 }
265 
266 port::StatusOr<int> GetCudnnProperty(libraryPropertyType type) {
267   int value;
268   RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));
269   return value;
270 }
271 
272 cudnnRNNAlgo_t ToCudnnRNNAlgo(absl::optional<dnn::AlgorithmDesc> algorithm) {
273   if (!algorithm.has_value()) {
274     return CUDNN_RNN_ALGO_STANDARD;
275   }
276   cudnnRNNAlgo_t algo = static_cast<cudnnRNNAlgo_t>(algorithm->algo_id());
277   switch (algo) {
278     case CUDNN_RNN_ALGO_STANDARD:
279     case CUDNN_RNN_ALGO_PERSIST_STATIC:
280     case CUDNN_RNN_ALGO_PERSIST_DYNAMIC:
281       return algo;
282     default:
283       LOG(FATAL) << "Unsupported Cudnn RNN algorithm: " << algorithm->algo_id();
284   }
285 }
286 
287 port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
288   SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION));
289   SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION));
290   SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL));
291   return port::Status::OK();
292 }
293 
294 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
295 void PreloadCudnnLibrary(cudnnStatus_t (*version_check_fn)(),
296                          absl::string_view sub_library) {
297   cudnnStatus_t status = version_check_fn();
298   if (status != CUDNN_STATUS_SUCCESS) {
299     VLOG(1) << "Could not pre-initialize cuDNN sub-library " << sub_library
300             << ".  Error: " << cudnnGetErrorString(status) << ".";
301   }
302 }
303 #endif
304 
305 }  // namespace
306 
307 CudnnSupport::CudnnSupport(GpuExecutor* parent) : parent_(parent) {}
308 
309 port::Status CudnnSupport::Init() {
310   ScopedActivateExecutorContext context(parent_);
311   cudnnHandle_t cudnn_handle = nullptr;
312   const auto status = cudnnCreate(&cudnn_handle);
313   if (status == CUDNN_STATUS_SUCCESS) {
314     CudnnVersion source_version(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
315 
316     CudnnVersion loaded_version;
317     TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&loaded_version));
318     if (!IsSourceCompatibleWithCudnnLibrary(source_version, loaded_version)) {
319       const std::string error = absl::StrCat(
320           "Loaded runtime CuDNN library: ", loaded_version.ToString(),
321           " but source was compiled with: ", source_version.ToString(),
322           ".  CuDNN library needs to have matching major version and equal or "
323           "higher minor version. If using a binary install, upgrade your CuDNN "
324           "library.  If building from sources, make sure the library loaded at "
325           "runtime is compatible with the version specified during compile "
326           "configuration.");
327       LOG(ERROR) << error;
328       cudnnDestroy(cudnn_handle);
329       return port::Status(port::error::INTERNAL, error);
330     }
331 
332     cudnn_.reset(new CudnnAccess(cudnn_handle));
333 
334     LOG(INFO) << "Loaded cuDNN version " << cudnnGetVersion();
335     return port::Status::OK();
336   }
337 
338   CHECK_EQ(cudnn_handle, nullptr);
339   LOG(ERROR) << "Could not create cudnn handle: " << ToString(status);
340   if (status == CUDNN_STATUS_NOT_INITIALIZED) {
341     auto result = gpu::Diagnostician::FindKernelDriverVersion();
342     if (!result.ok()) {
343       LOG(ERROR) << "Error retrieving driver version: "
344                  << cuda::DriverVersionStatusToString(result);
345     } else {
346       const auto& version = result.ValueOrDie();
347       LOG(ERROR) << "Possibly insufficient driver version: "
348                  << cuda::DriverVersionToString(version);
349     }
350   }
351 
352   return port::Status(port::error::INTERNAL,
353                       absl::StrCat("cudnn library could not create a handle: ",
354                                    ToString(status)));
355 }
356 
357 port::StatusOr<perftools::gputools::dnn::VersionInfo>
358 CudnnSupport::GetVersion() {
359   CudnnVersion version;
360   TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version));
361   return perftools::gputools::dnn::VersionInfo(
362       version.major_version, version.minor_version, version.patch_level);
363 }
364 
365 namespace {
366 
367 // Deleter functors for cuDNN types that need to be deleted.
368 struct TensorDescriptorDeleter {
369   void operator()(cudnnTensorDescriptor_t descriptor) const {
370     CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor));
371   }
372 };
373 struct RNNDataDescriptorDeleter {
374   void operator()(cudnnRNNDataDescriptor_t descriptor) const {
375     CHECK_CUDNN_OK(cudnnDestroyRNNDataDescriptor(descriptor));
376   }
377 };
378 struct FilterDescriptorDeleter {
379   void operator()(cudnnFilterDescriptor_t descriptor) const {
380     CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor));
381   }
382 };
383 struct ConvolutionDescriptorDeleter {
384   void operator()(cudnnConvolutionDescriptor_t descriptor) const {
385     CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor));
386   }
387 };
388 struct PoolingDescriptorDeleter {
389   void operator()(cudnnPoolingDescriptor_t descriptor) const {
390     CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor));
391   }
392 };
393 struct LrnDescriptorDeleter {
394   void operator()(cudnnLRNDescriptor_t descriptor) const {
395     CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor));
396   }
397 };
398 
399 struct ActivationDescriptorDeleter {
400   void operator()(cudnnActivationDescriptor_t descriptor) const {
401     CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor));
402   }
403 };
404 struct DropoutDescriptorDeleter {
405   void operator()(cudnnDropoutDescriptor_t descriptor) const {
406     CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor));
407   }
408 };
409 struct RnnDescriptorDeleter {
410   void operator()(cudnnRNNDescriptor_t descriptor) const {
411     CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor));
412   }
413 };
414 struct PersistentRnnPlanDeleter {
415   void operator()(cudnnPersistentRNNPlan_t plan) const {
416     CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
417   }
418 };
419 #if CUDNN_VERSION >= 7603
420 struct CtcLossDescriptorDeleter {
421   void operator()(cudnnCTCLossDescriptor_t descriptor) const {
422     CHECK_CUDNN_OK(cudnnDestroyCTCLossDescriptor(descriptor));
423   }
424 };
425 #endif
426 
427 // RAII wrappers for cuDNN types.
428 using TensorDescriptor =
429     std::unique_ptr<cudnnTensorStruct, TensorDescriptorDeleter>;
430 using RNNDataDescriptor =
431     std::unique_ptr<cudnnRNNDataStruct, RNNDataDescriptorDeleter>;
432 using FilterDescriptor =
433     std::unique_ptr<cudnnFilterStruct, FilterDescriptorDeleter>;
434 using ConvolutionDescriptor =
435     std::unique_ptr<cudnnConvolutionStruct, ConvolutionDescriptorDeleter>;
436 using PoolingDescriptor =
437     std::unique_ptr<cudnnPoolingStruct, PoolingDescriptorDeleter>;
438 using LrnDescriptor = std::unique_ptr<cudnnLRNStruct, LrnDescriptorDeleter>;
439 using ActivationDescriptor =
440     std::unique_ptr<cudnnActivationStruct, ActivationDescriptorDeleter>;
441 using DropoutDescriptor =
442     std::unique_ptr<cudnnDropoutStruct, DropoutDescriptorDeleter>;
443 using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
444 using PersistentRnnPlan =
445     std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
446 #if CUDNN_VERSION >= 7603
447 using CtcLossDescriptor =
448     std::unique_ptr<cudnnCTCLossStruct, CtcLossDescriptorDeleter>;
449 #endif
450 
451 // Factory methods for cuDNN types.
452 TensorDescriptor CreateTensorDescriptor() {
453   cudnnTensorDescriptor_t result;
454   CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
455   return TensorDescriptor(result);
456 }
457 RNNDataDescriptor CreateRNNDataDescriptor() {
458   cudnnRNNDataDescriptor_t result;
459   CHECK_CUDNN_OK(cudnnCreateRNNDataDescriptor(&result));
460   return RNNDataDescriptor(result);
461 }
462 FilterDescriptor CreateFilterDescriptor() {
463   cudnnFilterDescriptor_t result;
464   CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result));
465   return FilterDescriptor(result);
466 }
467 ConvolutionDescriptor CreateConvolutionDescriptor() {
468   cudnnConvolutionDescriptor_t result;
469   CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result));
470   return ConvolutionDescriptor(result);
471 }
472 PoolingDescriptor CreatePoolingDescriptor() {
473   cudnnPoolingDescriptor_t result;
474   CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result));
475   return PoolingDescriptor(result);
476 }
477 LrnDescriptor CreateLrnDescriptor() {
478   cudnnLRNDescriptor_t result;
479   CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result));
480   return LrnDescriptor(result);
481 }
482 ActivationDescriptor CreateActivationDescriptor() {
483   cudnnActivationDescriptor_t result;
484   CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result));
485   return ActivationDescriptor(result);
486 }
487 DropoutDescriptor CreateDropoutDescriptor() {
488   cudnnDropoutDescriptor_t result;
489   CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result));
490   return DropoutDescriptor(result);
491 }
492 RnnDescriptor CreateRnnDescriptor() {
493   cudnnRNNDescriptor_t result;
494   CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
495   return RnnDescriptor(result);
496 }
497 #if CUDNN_VERSION >= 7603
498 CtcLossDescriptor CreateCtcLossDescriptor() {
499   cudnnCTCLossDescriptor_t result;
500   CHECK_CUDNN_OK(cudnnCreateCTCLossDescriptor(&result));
501   return CtcLossDescriptor(result);
502 }
503 #endif
504 
505 port::StatusOr<PersistentRnnPlan> CreatePersistentRnnPlan(
506     cudnnRNNDescriptor_t rnn_desc, int batch_size, cudnnDataType_t data_type) {
507   cudnnPersistentRNNPlan_t result;
508   RETURN_IF_CUDNN_ERROR(
509       cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result));
510   return port::StatusOr<PersistentRnnPlan>(PersistentRnnPlan(result));
511 }
512 
513 // Turns a BatchDescriptor structure into a cudnn tensor handle within a
514 // scope.
515 class CudnnTensorDescriptor {
516  public:
517   CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
518                         cudnnDataType_t elem_type)
519       : handle_(CreateTensorDescriptor()) {
520     switch (batch_descriptor.layout()) {
521       case dnn::DataLayout::kBatchYXDepth:
522       case dnn::DataLayout::kBatchDepthYX: {
523         const int nd = batch_descriptor.ndims() + 2;
524         // cuDNN requires the strides and dims to be ordered as BDYX.
525         std::vector<int64> strides64 =
526             batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
527         std::vector<int64> dims64 =
528             batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
529 
530         // cuDNN requires arrays of ints.
531         std::vector<int> strides(nd);
532         std::vector<int> dims(nd);
533         std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
534                        &CheckedNarrowing<int64, int>);
535         std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
536                        &CheckedNarrowing<int64, int>);
537         CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd,
538                                                   dims.data(), strides.data()))
539             << "batch_descriptor: " << batch_descriptor.ToString();
540       } break;
541       case dnn::DataLayout::kBatchDepthYX4: {
542         CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(
543             handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type,
544             batch_descriptor.count(), batch_descriptor.feature_map_count(),
545             batch_descriptor.height(), batch_descriptor.width()))
546             << "batch_descriptor: " << batch_descriptor.ToString();
547       } break;
548       default:
549         LOG(FATAL) << "Unsupported tensor format "
550                    << DataLayoutString(batch_descriptor.layout());
551         break;
552     }
553   }
554 
555   cudnnTensorDescriptor_t handle() const { return handle_.get(); }
556 
557  private:
558   TensorDescriptor handle_;
559 
560   SE_DISALLOW_COPY_AND_ASSIGN(CudnnTensorDescriptor);
561 };
562 
563 // Turns a FilterDescriptor structure into a cudnn filter handle within a
564 // scope.
565 class CudnnFilterDescriptor {
566  public:
567   CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
568                         cudnnDataType_t elem_type)
569       : handle_(CreateFilterDescriptor()) {
570     // TODO(b/23032134): Even if the filter layout is not supported,
571     // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
572     // it does not take layout as an input. Maybe force cuDNN by giving wrong
573     // inputs intentionally?
574     cudnnTensorFormat_t format;
575     switch (filter_descriptor.layout()) {
576       case dnn::FilterLayout::kOutputInputYX:
577         format = CUDNN_TENSOR_NCHW;
578         break;
579       case dnn::FilterLayout::kOutputYXInput:
580         format = CUDNN_TENSOR_NHWC;
581         break;
582       case dnn::FilterLayout::kOutputInputYX4:
583         format = CUDNN_TENSOR_NCHW_VECT_C;
584         break;
585       default:
586         LOG(FATAL) << "Unsupported filter format "
587                    << FilterLayoutString(filter_descriptor.layout());
588         break;
589     }
590 
591     std::vector<int> dims(2 + filter_descriptor.ndims());
592     dims[0] = filter_descriptor.output_feature_map_count();
593     dims[1] = filter_descriptor.input_feature_map_count();
594     absl::Span<const int64> spatial_dims =
595         filter_descriptor.input_filter_dims();
596     std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
597 
598     CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format,
599                                               dims.size(), dims.data()));
600   }
601 
602   cudnnFilterDescriptor_t handle() const { return handle_.get(); }
603 
604  private:
605   FilterDescriptor handle_;  // Owned.
606 
607   SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
608 };
609 
610 // A helper function to decide whether to use
611 // CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
612 // some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
613 // and CUDNN_DATA_HALF data types, compute capability 6.0 or higher. The
614 // reason we set it to false by default is that this mode may use scaled
615 // atomic integer reduction that may cause numerical overflow for certain
616 // input data ranges.
617 // TODO(yangzihao): Use autotune to choose between this mode and
618 // CUDNN_BATCHNORM_SPATIAL mode.
619 bool BatchnormSpatialPersistentEnabled() {
620   static bool is_enabled = [] {
621     bool is_enabled = false;
622     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
623         "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
624         /*default_val=*/false, &is_enabled));
625     return is_enabled;
626   }();
627   return is_enabled;
628 }
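// For example, assuming the usual shell environment-variable mechanism, the
// persistent mode can be opted into for a (hypothetical) training script with:
//
//   TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT=1 python train.py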
629 
630 // The following function allows deterministic ops to be implemented relatively
631 // quickly using environment variables. It is intended to be temporary. The
632 // longer-term intention is to enable deterministic ops via tf.config and
633 // appropriate plumbing. See the discussion on PR 34951 for more information:
634 // https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316
635 // This function and associated comment are replicated in the following three
636 // places:
637 //   1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
638 //   2. tensorflow/core/kernels/gpu_utils.cc
639 //   3. tensorflow/stream_executor/cuda/cuda_dnn.cc
640 // When implementing the plumbing, you should also search for the use of
641 // TF_DETERMINISTIC_OPS on its own.
642 // TODO(duncanriach): move to an API that uses tf.config and implement the first
643 //                    phase of plumbing.
644 bool RequireCudnnDeterminism() {
645   static bool require_cudnn_determinism = [] {
646     bool deterministic_ops = false;
647     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
648                                                /*default_val=*/false,
649                                                &deterministic_ops));
650     bool cudnn_deterministic = false;
651     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
652                                                /*default_val=*/false,
653                                                &cudnn_deterministic));
654     return deterministic_ops || cudnn_deterministic;
655   }();
656   return require_cudnn_determinism;
657 }
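// For example, either of the following (with a hypothetical binary) requests
// deterministic cuDNN behavior:
//
//   TF_DETERMINISTIC_OPS=1 ./trainer
//   TF_CUDNN_DETERMINISTIC=1 ./trainer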
658 
659 // A helper function to decide whether to force the default conv algorithm.
660 bool ConvUseDefaultAlgorithm() {
661   static bool use_default = [] {
662     bool use_default = false;
663     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_USE_DEFAULT_CONV_ALGO",
664                                                /*default_val=*/false,
665                                                &use_default));
666     return use_default;
667   }();
668   return use_default;
669 }
670 
671 std::tuple<int, int> GetCcMajorMinor(Stream* stream) {
672   int cc_major, cc_minor;
673   stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
674                                                                    &cc_minor);
675   return std::make_tuple(cc_major, cc_minor);
676 }
677 
678 // Turns a ConvolutionDescriptor structure into a cudnn convolution handle
679 // within a scope.
680 class CudnnConvolutionDescriptor {
681  public:
682   CudnnConvolutionDescriptor(
683       const dnn::ConvolutionDescriptor& convolution_descriptor,
684       cudnnDataType_t data_type)
685       : handle_(CreateConvolutionDescriptor()) {
686     absl::Span<const int64> strides64 = convolution_descriptor.strides();
687     absl::Span<const int64> padding64 = convolution_descriptor.padding();
688     absl::Span<const int64> dilations64 = convolution_descriptor.dilations();
689     CHECK_NE(convolution_descriptor.pad_alignment(),
690              dnn::PadAlignment::kTensorFlowPadding)
691         << "TensorFlow padding alignment is not supported.";
692 
693     // cuDNN requires arrays of ints.
694     std::vector<int> strides(convolution_descriptor.ndims());
695     std::vector<int> padding(convolution_descriptor.ndims());
696     std::vector<int> dilations(convolution_descriptor.ndims());
697     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
698                    &CheckedNarrowing<int64, int>);
699     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
700                    &CheckedNarrowing<int64, int>);
701     // TODO(yangzihao): Test with negative dilation to make sure that cudnn
702     // doesn't crash.
703     std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
704                    &CheckedNarrowing<int64, int>);
705 
706     CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor(
707         handle_.get(), convolution_descriptor.ndims(), padding.data(),
708         strides.data(), dilations.data(),
709         convolution_descriptor.convolution_not_crosscorr()
710             ? CUDNN_CONVOLUTION
711             : CUDNN_CROSS_CORRELATION,
712         data_type));
713 
714 #if CUDNN_MAJOR >= 7
715     VLOG(2) << "Requesting grouped convolution: "
716             << convolution_descriptor.group_count();
717     CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount(
718         handle_.get(), convolution_descriptor.group_count()));
719 #else
720     CHECK_EQ(convolution_descriptor.group_count(), 1)
721         << "Requested grouped convolution for cuDNN version < 7";
722 #endif
723   }
724 
725   void set_use_tensor_op_math(bool use_tensor_op_math) {
726     cudnnMathType_t math_type =
727 #if CUDNN_VERSION >= 8000
728         (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH);
729 #else
730         (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
731 #endif
732     CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
733   }
734 
735   cudnnConvolutionDescriptor_t handle() const { return handle_.get(); }
736 
737  private:
738   ConvolutionDescriptor handle_;  // Owned.
739 
740   SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
741 };
742 
743 // A helper function to query if a CudnnConvolutionDescriptor has tensor_op_math
744 // set
745 static bool IsTensorMathOpSet(const CudnnConvolutionDescriptor& conv) {
746   cudnnMathType_t math_type;
747   CHECK_CUDNN_OK(cudnnGetConvolutionMathType(conv.handle(), &math_type));
748 #if CUDNN_VERSION >= 8000
749   return math_type != CUDNN_FMA_MATH;
750 #else
751   return math_type == CUDNN_TENSOR_OP_MATH;
752 #endif
753 }
754 
755 static bool TensorOpMathAvailable(int cc_major) { return cc_major >= 7; }
756 
757 static bool IsTensorMathEnabled(Stream* stream, dnn::DataType input_type) {
758   int cc_major, cc_minor;
759   std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
760   if (!TensorOpMathAvailable(cc_major)) {
761     return false;
762   }
763   if (input_type == dnn::DataType::kFloat) {
764 #if CUDNN_VERSION < 8000
765     return false;
766 #else
767     if (!tensorflow::tensor_float_32_execution_enabled()) {
768       return false;
769     }
770 #endif
771   }
772   return true;
773 }
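// Summarizing the checks above: tensor op math requires compute capability 7.0
// or higher, and for kFloat inputs it additionally requires cuDNN 8 with
// TensorFloat-32 execution enabled.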
774 
775 // Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
776 // within a scope.
777 class CudnnPoolingDescriptor {
778  public:
779   explicit CudnnPoolingDescriptor(
780       const dnn::PoolingDescriptor& pooling_descriptor)
781       : handle_(CreatePoolingDescriptor()) {
782     absl::Span<const int64> strides64 = pooling_descriptor.strides();
783     absl::Span<const int64> padding64 = pooling_descriptor.padding();
784     absl::Span<const int64> shape64 = pooling_descriptor.window();
785 
786     const int nd = pooling_descriptor.ndims();
787     std::vector<int> shape(nd);
788     std::vector<int> padding(nd);
789     std::vector<int> strides(nd);
790     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
791                    &CheckedNarrowing<int64, int>);
792     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
793                    &CheckedNarrowing<int64, int>);
794     std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
795                    &CheckedNarrowing<int64, int>);
796     bool propagate_nans = pooling_descriptor.propagate_nans();
797     const auto cudnn_max_pooling_mode = RequireCudnnDeterminism()
798                                             ? CUDNN_POOLING_MAX_DETERMINISTIC
799                                             : CUDNN_POOLING_MAX;
800     CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
801         handle_.get(),
802         (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
803              ? cudnn_max_pooling_mode
804              : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
805         propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
806         shape.data(), padding.data(), strides.data()));
807   }
808 
809   cudnnPoolingDescriptor_t handle() const { return handle_.get(); }
810 
811  private:
812   PoolingDescriptor handle_;  // Owned.
813 
814   SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
815 };
816 
817 // Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
818 class CudnnNormalizeDescriptor {
819  public:
820   explicit CudnnNormalizeDescriptor(
821       const dnn::NormalizeDescriptor& normalize_descriptor)
822       : handle_(CreateLrnDescriptor()) {
823     // The range specifies that the indices in the closed range
824     // [i - range, i + range] should be included in the normalization for index
825     // i. The lrnN value is the total number of elements in the range, so
826     // lrnN = 2*range + 1.
827     unsigned lrnN = 2 * normalize_descriptor.range() + 1;
828 
829     // Note that SE defines the normalization operation as
830     //
831     //  U_i = V_i / ((bias +  alpha      * (sum_j V_j^2)) ^ beta)
832     //
833     // but cuDNN defines it as
834     //
835     //  U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
836     //
837     // i.e. there is a factor of n difference between the meaning of the alphas
838     // in the two contexts. The cuDNN alpha is n times the SE alpha.
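    // Worked example of the conversion below: with range = 2, lrnN = 2*2 + 1
    // = 5, so a StreamExecutor alpha of 0.001 becomes a cuDNN lrnAlpha of
    // 5 * 0.001 = 0.005.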
839     double lrnAlpha = lrnN * normalize_descriptor.alpha();
840 
841     double lrnBeta = normalize_descriptor.beta();
842     double lrnK = normalize_descriptor.bias();
843     CHECK_CUDNN_OK(
844         cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK));
845   }
846 
847   cudnnLRNDescriptor_t handle() const { return handle_.get(); }
848 
849  private:
850   LrnDescriptor handle_;  // Owned.
851 
852   SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
853 };
854 
855 // Turns a ActivationDescriptor structure into a cudnn activation
856 // descriptor handle within a scope.
857 class CudnnActivationDescriptor {
858  public:
859   CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
860                             cudnnNanPropagation_t nan_propagation,
861                             double value_max)
862       : handle_(CreateActivationDescriptor()) {
863     double relu_ceiling = 0.0;
864     cudnnActivationMode_t mode;
865     switch (activation_mode) {
866       case dnn::ActivationMode::kNone:
867         mode = CUDNN_ACTIVATION_IDENTITY;
868         break;
869       case dnn::ActivationMode::kRelu6:
870         relu_ceiling = 6.0;
871         mode = CUDNN_ACTIVATION_CLIPPED_RELU;
872         break;
873       case dnn::ActivationMode::kReluX:
874         relu_ceiling = value_max;
875         mode = CUDNN_ACTIVATION_CLIPPED_RELU;
876         break;
877       case dnn::ActivationMode::kRelu:
878         mode = CUDNN_ACTIVATION_RELU;
879         break;
880       case dnn::ActivationMode::kSigmoid:
881         mode = CUDNN_ACTIVATION_SIGMOID;
882         break;
883       case dnn::ActivationMode::kTanh:
884         mode = CUDNN_ACTIVATION_TANH;
885         break;
886       default:
887         LOG(FATAL) << "unrecognized activation mode: "
888                    << static_cast<int>(activation_mode);
889     }
890 
891     CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode,
892                                                 nan_propagation, relu_ceiling));
893   }
894 
895   cudnnActivationDescriptor_t handle() const { return handle_.get(); }
896 
897  private:
898   ActivationDescriptor handle_;  // Owned.
899 
900   SE_DISALLOW_COPY_AND_ASSIGN(CudnnActivationDescriptor);
901 };
902 
903 cudnnDataType_t ToCudnnDataType(
904     dnn::DataType data_type,
905     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
906   switch (data_type) {
907     case dnn::DataType::kFloat:
908       return CUDNN_DATA_FLOAT;
909     case dnn::DataType::kDouble:
910       return CUDNN_DATA_DOUBLE;
911     case dnn::DataType::kHalf:
912       return CUDNN_DATA_HALF;
913     case dnn::DataType::kInt8:
914       return data_layout == dnn::DataLayout::kBatchDepthYX4 ? CUDNN_DATA_INT8x4
915                                                             : CUDNN_DATA_INT8;
916     case dnn::DataType::kInt32:
917       return CUDNN_DATA_INT32;
918     default:
919       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
920   }
921 }
922 
923 cudnnDataType_t ToCudnnDataType(dnn::DataType data_type,
924                                 dnn::FilterLayout filter_layout) {
925   if (data_type == dnn::DataType::kInt8 &&
926       filter_layout == dnn::FilterLayout::kOutputInputYX4) {
927     return CUDNN_DATA_INT8x4;
928   }
929   return ToCudnnDataType(data_type);
930 }
931 
932 template <typename T>
933 cudnnDataType_t GetCudnnDataType(
934     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
935   return ToCudnnDataType(dnn::ToDataType<T>::value, data_layout);
936 }
937 
938 cudnnRNNInputMode_t ToCudnnRnnInputMode(dnn::RnnInputMode input_mode) {
939   switch (input_mode) {
940     case dnn::RnnInputMode::kRnnLinearSkip:
941     case dnn::RnnInputMode::kRnnSkipInput:
942       return static_cast<cudnnRNNInputMode_t>(input_mode);
943     default:
944       LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
945   }
946 }
947 
948 cudnnDirectionMode_t ToCudnnRnnDirectionMode(
949     dnn::RnnDirectionMode direction_mode) {
950   switch (direction_mode) {
951     case dnn::RnnDirectionMode::kRnnUnidirectional:
952     case dnn::RnnDirectionMode::kRnnBidirectional:
953       return static_cast<cudnnDirectionMode_t>(direction_mode);
954     default:
955       LOG(FATAL) << "Invalid RNN direction mode: "
956                  << static_cast<int>(direction_mode);
957   }
958 }
959 
960 cudnnRNNMode_t ToCudnnRnnMode(dnn::RnnMode rnn_mode) {
961   switch (rnn_mode) {
962     case dnn::RnnMode::kRnnRelu:
963     case dnn::RnnMode::kRnnTanh:
964     case dnn::RnnMode::kRnnLstm:
965     case dnn::RnnMode::kRnnGru:
966       return static_cast<cudnnRNNMode_t>(rnn_mode);
967     default:
968       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
969   }
970 }
971 
972 int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
973   switch (data_type) {
974     case CUDNN_DATA_FLOAT:
975       return sizeof(float);
976     case CUDNN_DATA_DOUBLE:
977       return sizeof(double);
978     case CUDNN_DATA_HALF:
979       return sizeof(Eigen::half);
980     default:
981       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
982   }
983 }
984 
985 class CudnnDropoutDescriptor {
986   explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
987       : handle_(std::move(handle)) {}
988 
989  public:
990   CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;
991 
992   static port::StatusOr<CudnnDropoutDescriptor> Create(
993       const CudnnHandle& cudnn, float dropout, uint64 seed,
994       ScratchAllocator* state_allocator) {
995     DropoutDescriptor handle = CreateDropoutDescriptor();
996 
997     if (dropout == 0.0f) {
998       // Return 'empty' dropout descriptor.
999       return CudnnDropoutDescriptor(std::move(handle));
1000     }
1001 
1002     DeviceMemory<uint8> state_memory;
1003     if (state_allocator) {
1004       size_t state_sizes_in_bytes = 0;
1005       RETURN_IF_CUDNN_ERROR(
1006           cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
1007       SE_ASSIGN_OR_RETURN(state_memory,
1008                           state_allocator->AllocateBytes(state_sizes_in_bytes));
1009     }
1010     RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
1011         handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
1012         state_memory.size(), seed));
1013 
1014     return CudnnDropoutDescriptor(std::move(handle));
1015   }
1016 
1017   cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
1018 
1019  private:
1020   DropoutDescriptor handle_;  // Owned.
1021   SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
1022 };
1023 
1024 class CudnnRnnParamsDescriptor {
1025   typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1026 
1027   CudnnRnnParamsDescriptor(FilterDescriptor handle, int64 params_size_in_bytes,
1028                            ParamsRegions weights, ParamsRegions biases)
1029       : handle_(std::move(handle)),
1030         params_size_in_bytes_(params_size_in_bytes),
1031         weights_(std::move(weights)),
1032         biases_(std::move(biases)) {}
1033 
1034  public:
1035   CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default;
1036 
1037   static port::StatusOr<CudnnRnnParamsDescriptor> Create(
1038       const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1039       cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1040       cudnnDirectionMode_t direction_mode, int num_layers);
1041 
1042   cudnnFilterDescriptor_t handle() const { return handle_.get(); }
1043   int64 params_size_in_bytes() const { return params_size_in_bytes_; }
1044   ParamsRegions params_weights() const { return weights_; }
1045   ParamsRegions params_biases() const { return biases_; }
1046 
1047  private:
1048   FilterDescriptor handle_;
1049   int64 params_size_in_bytes_;
1050   ParamsRegions weights_;
1051   ParamsRegions biases_;
1052   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnParamsDescriptor);
1053 };
1054 
1055 }  // namespace
1056 
1057 class CudnnRnnDescriptor : public dnn::RnnDescriptor {
1058   CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
1059                      PersistentRnnPlan rnn_plan, int num_layers,
1060                      int hidden_size, int input_size, int cell_size,
1061                      int batch_size, cudnnRNNInputMode_t input_mode,
1062                      cudnnDirectionMode_t direction_mode,
1063                      cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
1064                      cudnnDataType_t compute_type,
1065                      const dnn::AlgorithmConfig& algorithm_config,
1066                      CudnnDropoutDescriptor dropout_desc,
1067                      CudnnRnnParamsDescriptor params_desc)
1068       : rnn_desc_(std::move(rnn_desc)),
1069         rnn_plan_(std::move(rnn_plan)),
1070         num_layers_(num_layers),
1071         hidden_size_(hidden_size),
1072         input_size_(input_size),
1073         cell_size_(cell_size),
1074         batch_size_(batch_size),
1075         rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
1076         input_mode_(input_mode),
1077         direction_mode_(direction_mode),
1078         rnn_mode_(rnn_mode),
1079         data_type_(data_type),
1080         compute_type_(compute_type),
1081         algorithm_config_(algorithm_config),
1082         dropout_desc_(std::move(dropout_desc)),
1083         params_desc_(std::move(params_desc)) {}
1084 
1085  public:
1086   CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default;
1087 
1088   static port::StatusOr<CudnnRnnDescriptor> Create(
1089       const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
1090       int cell_size, int batch_size, cudnnRNNInputMode_t input_mode,
1091       cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
1092       cudnnDataType_t data_type, cudnnDataType_t compute_type,
1093       const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
1094       ScratchAllocator* state_allocator, bool use_padded_io) {
1095     SE_ASSIGN_OR_RETURN(
1096         CudnnDropoutDescriptor dropout_desc,
1097         CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
1098 
1099     gpu::RnnDescriptor rnn_desc = CreateRnnDescriptor();
1100     cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
1101 
1102     // TODO: allow the user to choose an algorithm.
1103     auto proj_size = hidden_size;
1104     hidden_size = std::max(hidden_size, cell_size);
1105 
1106     // Require explicit algorithm config to enable tensor cores. Some configs
1107     // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against
1108     // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799).
1109     // We can only reasonably expect the user to handle the subsequent failure
1110     // in profile mode, which is run with algorithms returned from
1111     // GetRnnAlgorithms() (which are non-default and explicitly set whether to
1112     // use tensor ops). CuDNN 7.2.1 fixed this issue.
1113     // TODO(csigg): The minimum supported cuDNN version is 7.3; clean up.
1114     bool allow_tensor_ops = data_type == CUDNN_DATA_HALF;
1115     if (data_type == CUDNN_DATA_FLOAT)
1116       allow_tensor_ops = tensorflow::tensor_float_32_execution_enabled();
1117     bool use_tensor_ops =
1118         algorithm_config.algorithm().has_value()
1119             ? algorithm_config.algorithm()->tensor_ops_enabled()
1120             : allow_tensor_ops;
1121     if (use_tensor_ops && !allow_tensor_ops) {
1122       return port::Status(port::error::INVALID_ARGUMENT,
1123                           "Algo requests disallowed tensor op evaluation.");
1124     }
1125 
1126 #if CUDNN_VERSION >= 8000
1127     cudnnMathType_t math_type =
1128         use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
1129 #else
1130     cudnnMathType_t math_type =
1131         use_tensor_ops ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
1132 #endif
1133 
1134 #if CUDNN_VERSION >= 8000
1135     cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
1136     uint32_t aux_flags = 0;
1137     if (use_padded_io) aux_flags |= CUDNN_RNN_PADDED_IO_ENABLED;
1138     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v8(
1139         /*rnnDesc=*/rnn_desc.get(), /*algo=*/rnn_algo, /*cellMode=*/rnn_mode,
1140         /*biasMode=*/bias_mode, /*dirMode=*/direction_mode,
1141         /*inputMode=*/input_mode,
1142         /*dataType=*/data_type, /*mathPrec=*/compute_type,
1143         /*mathType=*/math_type,
1144         /*inputSize=*/input_size,
1145         /*hiddenSize=*/hidden_size, /*projSize=*/proj_size,
1146         /*numLayers=*/num_layers,
1147         /*dropoutDesc=*/dropout_desc.handle(),
1148         /*auxFlags=*/aux_flags));
1149 #else
1150     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
1151         cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1152         /*hiddenSize=*/hidden_size, /*numLayers=*/num_layers,
1153         /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode,
1154         /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo,
1155         /*dataType=*/compute_type));
1156     CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
1157 
1158     if (proj_size < hidden_size) {
1159       RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers(
1160           cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
1161           /*recProjSize=*/proj_size, /*outProjSize=*/0));
1162     }
1163 
1164     // TODO: For now, we only use the cudnnRNN**Ex APIs to process padded inputs.
1165     // If these APIs are ever used to process full-length arrays, we will need
1166     // to distinguish when to set this mode.
1167     if (use_padded_io) {
1168       RETURN_IF_CUDNN_ERROR(
1169           cudnnSetRNNPaddingMode(rnn_desc.get(), CUDNN_RNN_PADDED_IO_ENABLED));
1170     }
1171 #endif
1172 
1173     port::StatusOr<PersistentRnnPlan> rnn_plan_wrapper;
1174     PersistentRnnPlan rnn_plan;
1175     if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
1176       CHECK_GE(batch_size, 0);
1177       rnn_plan_wrapper =
1178           CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type);
1179       if (!rnn_plan_wrapper.ok()) {
1180         return port::StatusOr<CudnnRnnDescriptor>(rnn_plan_wrapper.status());
1181       } else {
1182         rnn_plan = rnn_plan_wrapper.ConsumeValueOrDie();
1183         RETURN_IF_CUDNN_ERROR(
1184             cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get()));
1185       }
1186     }
1187 
1188     // Create the params handle.
1189     SE_ASSIGN_OR_RETURN(auto params_desc,
1190                         CudnnRnnParamsDescriptor::Create(
1191                             cudnn, input_size, data_type, rnn_desc.get(),
1192                             rnn_mode, direction_mode, num_layers));
1193 
1194     return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
1195                               num_layers, hidden_size, input_size, cell_size,
1196                               batch_size, input_mode, direction_mode, rnn_mode,
1197                               data_type, compute_type, algorithm_config,
1198                               std::move(dropout_desc), std::move(params_desc));
1199   }
1200 
1201   cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); }
1202   int num_layers() const { return num_layers_; }
1203   int hidden_size() const { return hidden_size_; }
1204   int input_size() const { return input_size_; }
1205   int cell_size() const { return cell_size_; }
1206   int batch_size() const { return batch_size_; }
1207   cudnnRNNInputMode_t input_mode() const { return input_mode_; }
1208   cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
1209   cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
1210   cudnnDataType_t data_type() const { return data_type_; }
1211   cudnnDataType_t compute_type() const { return compute_type_; }
1212   const dnn::AlgorithmConfig& algorithm_config() const {
1213     return algorithm_config_;
1214   }
1215   int64 ParamsSizeInBytes() const override {
1216     return params_desc_.params_size_in_bytes();
1217   }
1218   cudnnFilterDescriptor_t params_handle() const {
1219     return params_desc_.handle();
1220   }
1221   ParamsRegions ParamsWeightRegions() const override {
1222     return params_desc_.params_weights();
1223   }
1224   ParamsRegions ParamsBiasRegions() const override {
1225     return params_desc_.params_biases();
1226   }
1227 
1228  private:
1229   gpu::RnnDescriptor rnn_desc_;
1230   PersistentRnnPlan rnn_plan_;
1231   int num_layers_;
1232   int hidden_size_;
1233   int input_size_;
1234   // cell_size_ is the size of the cell state, which will differ from
1235   // hidden_size_ if the projection is used.
1236   int cell_size_;
1237   // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
1238   // algorithm.
1239   int batch_size_;
1240   cudnnRNNAlgo_t rnn_algo_;
1241   cudnnRNNInputMode_t input_mode_;
1242   cudnnDirectionMode_t direction_mode_;
1243   cudnnRNNMode_t rnn_mode_;
1244   cudnnDataType_t data_type_;
1245   cudnnDataType_t compute_type_;
1246   dnn::AlgorithmConfig algorithm_config_;
1247   CudnnDropoutDescriptor dropout_desc_;
1248   CudnnRnnParamsDescriptor params_desc_;
1249   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
1250 };
1251 
1252 #if CUDNN_VERSION >= 7603
1253 class CudnnCtcLossDescriptor {
1254  public:
1255   explicit CudnnCtcLossDescriptor(cudnnDataType_t data_type)
1256       : handle_(CreateCtcLossDescriptor()) {
1257     CHECK_CUDNN_OK(cudnnSetCTCLossDescriptorEx(
1258         /*ctcLossDesc=*/handle_.get(),
1259         /*compType=*/data_type,
1260         /*normMode=*/CUDNN_LOSS_NORMALIZATION_SOFTMAX,
1261         /*gradMode=*/CUDNN_NOT_PROPAGATE_NAN));
1262   }
1263 
1264   cudnnCTCLossDescriptor_t handle() const { return handle_.get(); }
1265 
1266  private:
1267   CtcLossDescriptor handle_;  // Owned
1268 
1269   SE_DISALLOW_COPY_AND_ASSIGN(CudnnCtcLossDescriptor);
1270 };
1271 #else
1272 // Dummy class for cuDNN versions that do not support CTC loss.
1273 class CudnnCtcLossDescriptor {
1274  public:
1275   CudnnCtcLossDescriptor(cudnnDataType_t data_type) {}
1276 };
1277 #endif
1278 
1279 namespace {
1280 
1281 // Checks whether the LSTM projection is used. If so, the additional weight
1282 // matrix (the projection matrix) is appended to 'weights'. Otherwise, nothing
1283 // is done.
1284 port::Status CheckAndFetchProjectionWeights(
1285     const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, const int layer,
1286     const TensorDescriptor& input_desc, const FilterDescriptor& filter_desc,
1287     const FilterDescriptor& region_desc_handle,
1288     dnn::RnnDescriptor::ParamsRegions* weights) {
1289   int hidden_size_v;
1290   int num_layers_v;
1291   cudnnDropoutDescriptor_t dropout_desc;
1292   cudnnRNNInputMode_t input_mode;
1293   cudnnDirectionMode_t direction;
1294   cudnnRNNMode_t mode;
1295   cudnnRNNAlgo_t algo;
1296   cudnnDataType_t data_type;
1297 #if CUDNN_VERSION >= 8000
1298   RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor_v6(
1299       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1300       /*hiddenSize=*/&hidden_size_v,
1301       /*numLayers=*/&num_layers_v,
1302       /*dropoutDesc=*/&dropout_desc,
1303       /*inputMode=*/&input_mode,
1304       /*direction=*/&direction,
1305       /*mode=*/&mode,
1306       /*algo=*/&algo,
1307       /*mathPrec=*/&data_type));
1308 #else
1309   RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
1310       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1311       /*hiddenSize=*/&hidden_size_v,
1312       /*numLayers=*/&num_layers_v,
1313       /*dropoutDesc=*/&dropout_desc,
1314       /*inputMode=*/&input_mode,
1315       /*direction=*/&direction,
1316       /*mode=*/&mode,
1317       /*algo=*/&algo,
1318       /*mathPrec=*/&data_type));
1319 #endif
1320   int rec_proj_size_v;
1321   int out_proj_size_v;
1322   RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
1323       /*handle=*/cudnn.handle(),
1324       /*rnnDesc=*/rnn_desc,
1325       /*recProjSize*/ &rec_proj_size_v,
1326       /*outProjSize*/ &out_proj_size_v));
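  // A recurrent projection size that differs from the hidden size means the
  // projection layer is in use; its weight matrix is exposed by cuDNN as LSTM
  // linear-layer ID 8, which is queried below.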
1327   if (rec_proj_size_v != hidden_size_v) {
1328     void* offset = nullptr;
1329     int region_id = 8;
1330     RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams(
1331         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1332         /*layer=*/layer, /*xDesc=*/input_desc.get(),
1333         /*wDesc=*/filter_desc.get(),
1334         /*w=*/nullptr, /*linLayerID=*/region_id,
1335         /*linLayerMatDesc=*/region_desc_handle.get(),
1336         /*linLayerMat or linLayerBias=*/&offset));
1337     int dims[] = {1, 1, 1};
1338     cudnnDataType_t data_type;
1339     cudnnTensorFormat_t tensor_format;
1340     int n_dims;
1341     RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1342         /*filterDesc=*/region_desc_handle.get(),
1343         /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1344         /*dataType=*/&data_type, /*format=*/&tensor_format,
1345         /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1346     int64 size =
1347         dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1348     dnn::RnnDescriptor::ParamsRegion region = {reinterpret_cast<int64>(offset),
1349                                                size};
1350     weights->push_back(region);
1351   }
1352   return port::Status::OK();
1353 }
1354 
1355 port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
1356     const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
1357     cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
1358     cudnnDirectionMode_t direction_mode, int num_layers) {
1359   // Query the params size.
1360   TensorDescriptor input_desc = CreateTensorDescriptor();
1361   int tensor_dims[] = {1, input_size, 1};
1362   int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1};
1363   RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1364       /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type,
1365       /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]),
1366       /*dimA=*/tensor_dims,
1367       /*strideA=*/strides));
1368 
1369   size_t params_size = 0;
1370   RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1371       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1372       /*xDesc=*/input_desc.get(), /*sizeInBytes=*/&params_size,
1373       /*dataType=*/data_type));
1374   int64 params_size_in_bytes = static_cast<int64>(params_size);
1375 
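  // Describe the whole params buffer as a flat 1-D filter; cuDNN reports each
  // weight/bias region below as an offset into this buffer.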
1376   FilterDescriptor filter_desc = CreateFilterDescriptor();
1377   int64 filter_dim0 = params_size_in_bytes / CudnnDataTypeToByteSize(data_type);
1378   int filter_dims[] = {static_cast<int>(filter_dim0), 1, 1};
1379   RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
1380       /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type,
1381       /*format=*/CUDNN_TENSOR_NCHW,
1382       /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]),
1383       /*filterDimA=*/filter_dims));
1384 
1385   // Locate the weight and bias regions within the params buffer.
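  // In cuDNN's canonical layout each gate has one input and one recurrent
  // weight matrix (each with its own bias), so an LSTM (4 gates) has 8 regions
  // per layer, a GRU (3 gates) has 6, and a vanilla RNN has 2.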
1386   int region_count_per_layer = [&] {
1387     switch (rnn_mode) {
1388       case CUDNN_RNN_RELU:
1389       case CUDNN_RNN_TANH:
1390         return 2;
1391       case CUDNN_LSTM:
1392         return 8;
1393       case CUDNN_GRU:
1394         return 6;
1395       default:
1396         LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1397         return 0;
1398     }
1399   }();
1400 
1401   FilterDescriptor region_desc_handle = CreateFilterDescriptor();
1402   const int layer_count =
1403       direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers;
1404 
1405   ParamsRegions weights;
1406   ParamsRegions biases;
1407 
1408   for (int layer = 0; layer < layer_count; layer++) {
1409     for (int region = 0; region < region_count_per_layer; region++) {
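      // type 0 queries the weight matrix for this region; type 1 queries the
      // corresponding bias vector.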
1410       for (int type = 0; type < 2; type++) {
1411         void* offset = nullptr;
1412         RETURN_IF_CUDNN_ERROR(
1413             type == 0 ? cudnnGetRNNLinLayerMatrixParams(
1414                             /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1415                             /*layer=*/layer, /*xDesc=*/input_desc.get(),
1416                             /*wDesc=*/filter_desc.get(),
1417                             /*w=*/nullptr, /*linLayerID=*/region,
1418                             /*linLayerMatDesc=*/region_desc_handle.get(),
1419                             /*linLayerMat or linLayerBias=*/&offset)
1420                       : cudnnGetRNNLinLayerBiasParams(
1421                             /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
1422                             /*layer=*/layer, /*xDesc=*/input_desc.get(),
1423                             /*wDesc=*/filter_desc.get(),
1424                             /*w=*/nullptr, /*linLayerID=*/region,
1425                             /*linLayerMatDesc=*/region_desc_handle.get(),
1426                             /*linLayerMat or linLayerBias=*/&offset));
1427         int dims[] = {1, 1, 1};
1428         cudnnDataType_t data_type;
1429         cudnnTensorFormat_t tensor_format;
1430         int n_dims;
1431         RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
1432             /*filterDesc=*/region_desc_handle.get(),
1433             /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
1434             /*dataType=*/&data_type, /*format=*/&tensor_format,
1435             /*nbDims=*/&n_dims, /*filterDimA=*/dims));
1436         int64 size =
1437             dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
1438         dnn::RnnDescriptor::ParamsRegion region = {
1439             reinterpret_cast<int64>(offset), size};
1440         (type == 0 ? weights : biases).push_back(region);
1441       }
1442     }
1443     TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights(
1444         cudnn, rnn_desc, layer, input_desc, filter_desc, region_desc_handle,
1445         &weights));
1446   }
1447 
1448   return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
1449                                   weights, biases);
1450 }
1451 
1452 }  // namespace
1453 
1454 class CudnnRnnSequenceTensorDescriptor
1455     : public dnn::RnnSequenceTensorDescriptor {
1456   CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length,
1457                                    int batch_size, int data_size,
1458                                    cudnnDataType_t data_type,
1459                                    RNNDataDescriptor data_handle,
1460                                    TensorDescriptor handle)
1461       : max_seq_length_(max_seq_length),
1462         batch_size_(batch_size),
1463         data_size_(data_size),
1464         data_type_(data_type),
1465         handle_(std::move(handle)),
1466         rnn_data_handle_(std::move(data_handle)),
1467         handles_(max_seq_length, handle_.get()) {
1468   }
1469 
1470  public:
1471   CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) =
1472       default;
1473 
1474   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1475       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1476       cudnnDataType_t data_type) {
1477     if (max_seq_length <= 0) {
1478       return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
1479     }
1480     int dims[] = {batch_size, data_size, 1};
1481     int strides[] = {dims[1] * dims[2], dims[2], 1};
1482     TensorDescriptor tensor_desc = CreateTensorDescriptor();
1483     RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1484         /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1485         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1486         /*strideA=*/strides));
1487     return CudnnRnnSequenceTensorDescriptor(parent, max_seq_length, batch_size,
1488                                             data_size, data_type,
1489                                             nullptr,
1490                                             std::move(tensor_desc));
1491   }
1492 
1493   static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
1494       GpuExecutor* parent, int max_seq_length, int batch_size, int data_size,
1495       const absl::Span<const int>& seq_lengths, bool time_major,
1496       cudnnDataType_t data_type) {
1497     if (max_seq_length <= 0) {
1498       return port::Status(port::error::INVALID_ARGUMENT, "max_seq_length <= 0");
1499     }
1500     int dims[] = {batch_size, data_size, 1};
1501     int strides[] = {dims[1] * dims[2], dims[2], 1};
1502     TensorDescriptor tensor_desc = CreateTensorDescriptor();
1503     RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
1504         /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
1505         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1506         /*strideA=*/strides));
1507     const int* seq_lengths_array = seq_lengths.data();
1508     RNNDataDescriptor data_desc = CreateRNNDataDescriptor();
1509     float padding_fill = 0.0f;
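    // Sequence-major ("time major") layouts place the time step as the
    // outermost dimension; batch-major layouts place the batch outermost.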
1510     cudnnRNNDataLayout_t layout;
1511     if (time_major) {
1512       layout = CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED;
1513     } else {
1514       layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
1515     }
1516     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDataDescriptor(
1517         /*RNNDataDesc=*/data_desc.get(), /*dataType*/ data_type,
1518         /*layout=*/layout,
1519         /*maxSeqLength=*/max_seq_length,
1520         /*batchSize=*/batch_size, /*vectorSize=*/data_size,
1521         /*seqLengthArray=*/seq_lengths_array,
1522         /*paddingFill*/ (void*)&padding_fill));
1523     return CudnnRnnSequenceTensorDescriptor(
1524         parent, max_seq_length, batch_size, data_size, data_type,
1525         std::move(data_desc), std::move(tensor_desc));
1526   }
1527 
1528   const cudnnTensorDescriptor_t* handles() const { return handles_.data(); }
1529   const cudnnRNNDataDescriptor_t data_handle() const {
1530     return rnn_data_handle_.get();
1531   }
1532 
1533   int max_seq_length() const { return max_seq_length_; }
1534   int batch_size() const { return batch_size_; }
1535   int data_size() const { return data_size_; }
1536   bool is_var_seq_lengths() const {
1537     return rnn_data_handle_ != nullptr;
1538   }
1539 
1540  private:
1541   int max_seq_length_;
1542   int batch_size_;
1543   int data_size_;
1544   cudnnDataType_t data_type_;
1545   TensorDescriptor handle_;
1546   RNNDataDescriptor rnn_data_handle_;
1547   std::vector<cudnnTensorDescriptor_t> handles_;  // Copies of handle_.
1548   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
1549 };
1550 
1551 class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor {
1552  public:
1553   CudnnRnnStateTensorDescriptor(GpuExecutor* parent, int num_layers,
1554                                 int batch_size, int data_size,
1555                                 cudnnDataType_t data_type)
1556       : handle_(CreateTensorDescriptor()),
1557         num_layers_(num_layers),
1558         batch_size_(batch_size),
1559         data_size_(data_size),
1560         data_type_(data_type) {
1561     int dims[] = {num_layers, batch_size, data_size};
1562     int strides[] = {dims[1] * dims[2], dims[2], 1};
1563     CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(
1564         /*tensorDesc=*/handle_.get(), /*dataType=*/data_type,
1565         /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
1566         /*strideA=*/strides));
1567   }
1568 
1569   cudnnTensorDescriptor_t handle() const { return handle_.get(); }
1570 
1571   int num_layers() const { return num_layers_; }
1572   int batch_size() const { return batch_size_; }
1573   int data_size() const { return data_size_; }
1574 
1575  private:
1576   TensorDescriptor handle_;
1577   int num_layers_;
1578   int batch_size_;
1579   int data_size_;
1580   cudnnDataType_t data_type_;
1581   SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnStateTensorDescriptor);
1582 };
1583 
1584 namespace {
1585 
1586 struct RnnModelDims {
1587   int num_layers = 0;
1588   int batch_size = 0;
1589   int max_seq_length = 0;
1590   int hidden_size = 0;
1591   int input_size = 0;
1592   int cell_size = 0;
1593   int dir_count = 0;
1594 };
1595 
1596 template <class T>
1597 port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
1598     const CudnnRnnDescriptor& rnn_desc,
1599     const CudnnRnnSequenceTensorDescriptor& input_desc,
1600     const DeviceMemory<T>& input_data,
1601     const CudnnRnnStateTensorDescriptor& input_h_desc,
1602     const DeviceMemory<T>& input_h_data,
1603     const CudnnRnnStateTensorDescriptor& input_c_desc,
1604     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1605     const CudnnRnnSequenceTensorDescriptor& output_desc,
1606     const DeviceMemory<T>& output_data,
1607     const CudnnRnnStateTensorDescriptor& output_h_desc,
1608     const DeviceMemory<T>& output_h_data,
1609     const CudnnRnnStateTensorDescriptor& output_c_desc,
1610     const DeviceMemory<T>& output_c_data) {
1611   // extract model parameters
1612   RnnModelDims model_dims;
1613   model_dims.num_layers = rnn_desc.num_layers();
1614   model_dims.batch_size = input_desc.batch_size();
1615   model_dims.max_seq_length = input_desc.max_seq_length();
1616   model_dims.hidden_size = rnn_desc.hidden_size();
1617   model_dims.input_size = input_desc.data_size();
1618   model_dims.cell_size = rnn_desc.cell_size();
1619   model_dims.dir_count =
1620       (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
1621 
1622   // check parameters
1623   if (!(input_h_desc.num_layers() ==
1624             model_dims.num_layers * model_dims.dir_count &&
1625         input_h_desc.batch_size() == model_dims.batch_size &&
1626         input_h_desc.data_size() == model_dims.hidden_size)) {
1627     return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
1628   }
1629   // The LSTM projection will be used if input_h_desc.data_size() <
1630   // input_c_desc.data_size()
1631   if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
1632         input_h_desc.batch_size() == input_c_desc.batch_size() &&
1633         input_h_desc.data_size() <= input_c_desc.data_size())) {
1634     return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
1635   }
1636   if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
1637         output_desc.batch_size() == model_dims.batch_size &&
1638         output_desc.data_size() ==
1639             model_dims.hidden_size * model_dims.dir_count)) {
1640     return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape");
1641   }
1642   if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
1643         input_h_desc.batch_size() == output_h_desc.batch_size() &&
1644         input_h_desc.data_size() == output_h_desc.data_size())) {
1645     return port::Status(port::error::INVALID_ARGUMENT,
1646                         "Invalid output_h shape");
1647   }
1648   if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
1649         input_h_desc.batch_size() == output_c_desc.batch_size() &&
1650         input_h_desc.data_size() <= output_c_desc.data_size())) {
1651     return port::Status(port::error::INVALID_ARGUMENT,
1652                         "Invalid output_c shape");
1653   }
1654 
1655   return model_dims;
1656 }
1657 
1658 port::Status CheckRNNParameterSize(
1659     const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
1660     const CudnnRnnSequenceTensorDescriptor& input_desc) {
1661   size_t params_size_in_bytes = 0;
1662   RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
1663       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1664       /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/&params_size_in_bytes,
1665       /*dataType=*/rnn_desc.data_type()));
1666   if (static_cast<int64>(params_size_in_bytes) !=
1667       rnn_desc.ParamsSizeInBytes()) {
1668     return port::Status(port::error::INVALID_ARGUMENT,
1669                         "Mismatching RNN parameter size");
1670   }
1671   return port::Status::OK();
1672 }
1673 
1674 port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
1675     Stream* stream, const CudnnHandle& cudnn,
1676     const CudnnRnnDescriptor& rnn_desc,
1677     const CudnnRnnSequenceTensorDescriptor& input_desc,
1678     ScratchAllocator* workspace_allocator) {
1679   // Query the workspace size.
1680   size_t workspace_size_in_bytes = 0;
1681   RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize(
1682       /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1683       /*seqLength=*/input_desc.max_seq_length(), /*xDesc=*/input_desc.handles(),
1684       /*sizeInBytes=*/&workspace_size_in_bytes));
1685   // Allocate the workspace.
1686   if (workspace_size_in_bytes == 0) {
1687     return DeviceMemory<uint8>();
1688   }
1689   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1690 }
1691 
1692 #if CUDNN_VERSION >= 7402
1693 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormForwardWorkspace(
1694     Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1695     const cudnnBatchNormOps_t& bn_ops,
1696     const cudnnActivationDescriptor_t& activation_desc,
1697     const CudnnTensorDescriptor& x_descriptor,
1698     const CudnnTensorDescriptor& scale_offset_descriptor,
1699     ScratchAllocator* workspace_allocator) {
1700   // Query the workspace size.
1701   size_t workspace_size_in_bytes = 0;
1702   RETURN_IF_CUDNN_ERROR(
1703       cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
1704           /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1705           /*xDesc=*/x_descriptor.handle(), /*zDesc=*/x_descriptor.handle(),
1706           /*yDesc=*/x_descriptor.handle(),
1707           /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
1708           /*activationDesc=*/activation_desc,
1709           /*sizeInBytes=*/&workspace_size_in_bytes));
1710   // Allocate the workspace.
1711   if (workspace_size_in_bytes == 0) {
1712     return DeviceMemory<uint8>();
1713   }
1714   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1715 }
1716 
1717 port::StatusOr<DeviceMemory<uint8>> CreateBatchNormBackwardWorkspace(
1718     Stream* stream, const CudnnHandle& cudnn, const cudnnBatchNormMode_t& mode,
1719     const cudnnBatchNormOps_t& bn_ops,
1720     const CudnnTensorDescriptor& x_descriptor,
1721     const CudnnTensorDescriptor& scale_offset_descriptor,
1722     ScratchAllocator* workspace_allocator) {
1723   // Query the workspace size.
1724   size_t workspace_size_in_bytes = 0;
1725   RETURN_IF_CUDNN_ERROR(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
1726       /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
1727       /*xDesc=*/x_descriptor.handle(),
1728       /*yDesc=*/x_descriptor.handle(),
1729       /*dyDesc=*/x_descriptor.handle(),
1730       /*dzDesc=*/nullptr,
1731       /*dxDesc=*/x_descriptor.handle(),
1732       /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
1733       /*activationDesc=*/nullptr, /*sizeInBytes=*/&workspace_size_in_bytes));
1734   // Allocate the workspace.
1735   if (workspace_size_in_bytes == 0) {
1736     return DeviceMemory<uint8>();
1737   }
1738   return workspace_allocator->AllocateBytes(workspace_size_in_bytes);
1739 }
1740 
1741 #endif
1742 
1743 }  // namespace
1744 
1745 template <class T>
1746 port::Status CudnnSupport::DoRnnForwardImpl(
1747     Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1748     const CudnnRnnSequenceTensorDescriptor& input_desc,
1749     const DeviceMemory<T>& input_data,
1750     const CudnnRnnStateTensorDescriptor& input_h_desc,
1751     const DeviceMemory<T>& input_h_data,
1752     const CudnnRnnStateTensorDescriptor& input_c_desc,
1753     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1754     const CudnnRnnSequenceTensorDescriptor& output_desc,
1755     DeviceMemory<T>* output_data,
1756     const CudnnRnnStateTensorDescriptor& output_h_desc,
1757     DeviceMemory<T>* output_h_data,
1758     const CudnnRnnStateTensorDescriptor& output_c_desc,
1759     DeviceMemory<T>* output_c_data, bool is_training,
1760     ScratchAllocator* reserve_space_allocator,
1761     ScratchAllocator* workspace_allocator,
1762     dnn::ProfileResult* output_profile_result) {
1763   SE_ASSIGN_OR_RETURN(
1764       RnnModelDims model_dims,
1765       ExtractAndCheckRnnForward(
1766           rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
1767           input_c_desc, input_c_data, params, output_desc, *output_data,
1768           output_h_desc, *output_h_data, output_c_desc, *output_c_data));
1769 
1770   auto cudnn = cudnn_->GetHandle(parent_, stream);
1771 
1772   SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1773   SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1774                       CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1775                                          workspace_allocator))
1776 
1777   // query the reserve space size
1778   // allocate the reserve space
1779   DeviceMemory<uint8> reserve_space;
1780   if (is_training) {
1781     size_t reserve_space_size_in_bytes = 0;
1782     RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize(
1783         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1784         /*seqLength=*/model_dims.max_seq_length, /*xDesc=*/input_desc.handles(),
1785         /*sizeInBytes=*/&reserve_space_size_in_bytes));
1786 
1787     if (reserve_space_size_in_bytes > 0) {
1788       SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
1789                                              reserve_space_size_in_bytes));
1790     }
1791   }
1792 
1793   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1794   const bool is_profiling = output_profile_result != nullptr;
1795   if (is_profiling) {
1796     timer.reset(new GpuTimer(parent_));
1797     // The start and stop of the timer should be as close to the Cudnn call
1798     // as possible. Other threads may still issue work onto this stream, so
1799     // the caller may need to take multiple profiling measurements.
1800     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1801       return port::Status(port::error::INTERNAL, "Failed to start timer");
1802     }
1803   }
1804 
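  // When per-sequence lengths were provided (is_var_seq_lengths), dispatch to
  // the cudnnRNN*Ex entry points, which take padded RNN data descriptors;
  // otherwise use the legacy entry points that take arrays of per-time-step
  // tensor descriptors.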
1805   if (!is_training) {
1806     if (input_desc.is_var_seq_lengths()) {
1807       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInferenceEx(
1808           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1809           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1810           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1811           /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1812           /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1813           /*yDesc=*/output_desc.data_handle(),
1814           /*y=*/output_data->opaque(),
1815           /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
1816           /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
1817           nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
1818           nullptr,
1819           /*workspace=*/workspace.opaque(),
1820           /*workSpaceSizeInBytes=*/workspace.size()));
1821     } else {
1822       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference(
1823           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1824           /*seqLength=*/model_dims.max_seq_length,
1825           /*xDesc=*/input_desc.handles(),
1826           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1827           /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
1828           /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
1829           /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
1830           /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
1831           /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
1832           /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
1833           /*workSpaceSizeInBytes=*/workspace.size()));
1834     }
1835   } else {
1836     if (input_desc.is_var_seq_lengths()) {
1837       // cudnnSetRNNPaddingMode(rnn_desc.handle(), CUDNN_RNN_PADDED_IO_ENABLED);
1838       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTrainingEx(
1839           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1840           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1841           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1842           /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1843           /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1844           /*yDesc=*/output_desc.data_handle(),
1845           /*y=*/output_data->opaque(),
1846           /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(),
1847           /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(),
1848           nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
1849           nullptr,
1850           /*workspace=*/workspace.opaque(),
1851           /*workSpaceSizeInBytes=*/workspace.size(),
1852           /*reserveSpace=*/reserve_space.opaque(),
1853           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
1854     } else {
1855       RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining(
1856           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1857           /*seqLength=*/model_dims.max_seq_length,
1858           /*xDesc=*/input_desc.handles(),
1859           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
1860           /*hx=*/input_h_data.opaque(), /*cxDesc=*/input_c_desc.handle(),
1861           /*cx=*/input_c_data.opaque(), /*wDesc=*/rnn_desc.params_handle(),
1862           /*w=*/params.opaque(), /*yDesc=*/output_desc.handles(),
1863           /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
1864           /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
1865           /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
1866           /*workSpaceSizeInBytes=*/workspace.size(),
1867           /*reserveSpace=*/reserve_space.opaque(),
1868           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
1869     }
1870   }
1871 
1872   if (is_profiling) {
1873     if (!timer->Stop(AsGpuStream(stream))) {
1874       return port::Status(port::error::INTERNAL, "Failed to stop timer");
1875     }
1876     auto algo_desc = *rnn_desc.algorithm_config().algorithm();
1877     output_profile_result->set_algorithm(algo_desc);
1878     output_profile_result->set_elapsed_time_in_ms(
1879         timer->GetElapsedMilliseconds());
1880   }
1881 
1882   return port::Status::OK();
1883 }
1884 
1885 template <class T>
1886 port::Status CudnnSupport::DoRnnBackwardImpl(
1887     Stream* stream, const CudnnRnnDescriptor& rnn_desc,
1888     const CudnnRnnSequenceTensorDescriptor& input_desc,
1889     const DeviceMemory<T>& input_data,
1890     const CudnnRnnStateTensorDescriptor& input_h_desc,
1891     const DeviceMemory<T>& input_h_data,
1892     const CudnnRnnStateTensorDescriptor& input_c_desc,
1893     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1894     const CudnnRnnSequenceTensorDescriptor& output_desc,
1895     const DeviceMemory<T>& output_data,
1896     const CudnnRnnStateTensorDescriptor& output_h_desc,
1897     const DeviceMemory<T>& output_h_data,
1898     const CudnnRnnStateTensorDescriptor& output_c_desc,
1899     const DeviceMemory<T>& output_c_data,
1900     const DeviceMemory<T>& output_backprop_data,
1901     const DeviceMemory<T>& output_h_backprop_data,
1902     const DeviceMemory<T>& output_c_backprop_data,
1903     DeviceMemory<T>* input_backprop_data,
1904     DeviceMemory<T>* input_h_backprop_data,
1905     DeviceMemory<T>* input_c_backprop_data,
1906     DeviceMemory<T>* params_backprop_data,
1907     DeviceMemory<uint8>* reserve_space_data,
1908     ScratchAllocator* workspace_allocator,
1909     dnn::ProfileResult* output_profile_result) {
1910   SE_ASSIGN_OR_RETURN(
1911       RnnModelDims model_dims,
1912       ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
1913                                 input_h_data, input_c_desc, input_c_data,
1914                                 params, output_desc, output_data, output_h_desc,
1915                                 output_h_data, output_c_desc, output_c_data));
1916 
1917   auto cudnn = cudnn_->GetHandle(parent_, stream);
1918 
1919   SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
1920   SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
1921                       CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
1922                                          workspace_allocator));
1923 
1924   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
1925   const bool is_profiling = output_profile_result != nullptr;
1926   if (is_profiling) {
1927     timer.reset(new GpuTimer(parent_));
1928     // The start and stop of the timer should be as close to the Cudnn call
1929     // as possible. Other threads may still issue work onto this stream, so
1930     // the caller may need to take multiple profiling measurements.
1931     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
1932       return port::Status(port::error::INTERNAL, "Failed to start timer");
1933     }
1934   }
1935 
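  // As in the forward pass, use the *Ex entry points when per-sequence lengths
  // were provided; otherwise use the legacy per-time-step API.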
1936   if (input_desc.is_var_seq_lengths()) {
1937     RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardDataEx(
1938         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1939         /*yDesc=*/output_desc.data_handle(), /*y=*/output_data.opaque(),
1940         /*dyDesc=*/output_desc.data_handle(),
1941         /*dy=*/output_backprop_data.opaque(), nullptr, nullptr,
1942         /*dhyDesc=*/output_h_desc.handle(),
1943         /*dhy=*/output_h_backprop_data.opaque(),
1944         /*dcyDesc=*/output_c_desc.handle(),
1945         /*dcy=*/output_c_backprop_data.opaque(),
1946         /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1947         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1948         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1949         /*dxDesc=*/input_desc.data_handle(),
1950         /*dx=*/input_backprop_data->opaque(),
1951         /*dhxDesc=*/input_h_desc.handle(),
1952         /*dhx=*/input_h_backprop_data->opaque(),
1953         /*dcxDesc=*/input_c_desc.handle(),
1954         /*dcx=*/input_c_backprop_data->opaque(), nullptr, nullptr,
1955         /*workspace=*/workspace.opaque(),
1956         /*workSpaceSizeInBytes=*/workspace.size(),
1957         /*reserveSpace=*/reserve_space_data->opaque(),
1958         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1959   } else {
1960     RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData(
1961         /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1962         /*seqLength=*/model_dims.max_seq_length,
1963         /*yDesc=*/output_desc.handles(),
1964         /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
1965         /*dy=*/output_backprop_data.opaque(),
1966         /*dhyDesc=*/output_h_desc.handle(),
1967         /*dhy=*/output_h_backprop_data.opaque(),
1968         /*dcyDesc=*/output_c_desc.handle(),
1969         /*dcy=*/output_c_backprop_data.opaque(),
1970         /*wDesc=*/rnn_desc.params_handle(), /*w=*/params.opaque(),
1971         /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1972         /*cxDesc=*/input_c_desc.handle(), /*cx=*/input_c_data.opaque(),
1973         /*dxDesc=*/input_desc.handles(), /*dx=*/input_backprop_data->opaque(),
1974         /*dhxDesc=*/input_h_desc.handle(),
1975         /*dhx=*/input_h_backprop_data->opaque(),
1976         /*dcxDesc=*/input_c_desc.handle(),
1977         /*dcx=*/input_c_backprop_data->opaque(),
1978         /*workspace=*/workspace.opaque(),
1979         /*workSpaceSizeInBytes=*/workspace.size(),
1980         /*reserveSpace=*/reserve_space_data->opaque(),
1981         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
1982   }
1983 
1984   if (params_backprop_data != nullptr) {
1985     // Clear dw to zeros first; cudnnRNNBackwardWeights accumulates into it.
1986     stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
1987     if (input_desc.is_var_seq_lengths()) {
1988       RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeightsEx(
1989           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
1990           /*xDesc=*/input_desc.data_handle(), /*x=*/input_data.opaque(),
1991           /*hxDesc=*/input_h_desc.handle(), /*hx=*/input_h_data.opaque(),
1992           /*yDesc=*/output_desc.data_handle(),
1993           /*y=*/output_data.opaque(),
1994           /*workspace=*/workspace.opaque(),
1995           /*workSpaceSizeInBytes=*/workspace.size(),
1996           /*dwDesc=*/rnn_desc.params_handle(),
1997           /*dw=*/params_backprop_data->opaque(),
1998           /*reserveSpace=*/reserve_space_data->opaque(),
1999           /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2000     } else {
2001       // make the backward weight call
2002       RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights(
2003           /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
2004           /*seqLength=*/model_dims.max_seq_length,
2005           /*xDesc=*/input_desc.handles(),
2006           /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
2007           /*hx=*/input_h_data.opaque(), /*yDesc=*/output_desc.handles(),
2008           /*y=*/output_data.opaque(), /*workspace=*/workspace.opaque(),
2009           /*workSpaceSizeInBytes=*/workspace.size(),
2010           /*dwDesc=*/rnn_desc.params_handle(),
2011           /*dw=*/params_backprop_data->opaque(),
2012           /*reserveSpace=*/reserve_space_data->opaque(),
2013           /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
2014     }
2015   }
2016 
2017   if (is_profiling) {
2018     if (!timer->Stop(AsGpuStream(stream))) {
2019       return port::Status(port::error::INTERNAL, "Failed to stop timer");
2020     }
2021     auto algo_desc = *rnn_desc.algorithm_config().algorithm();
2022     output_profile_result->set_algorithm(algo_desc);
2023     output_profile_result->set_elapsed_time_in_ms(
2024         timer->GetElapsedMilliseconds());
2025   }
2026 
2027   return port::Status::OK();
2028 }
2029 
2030 port::Status CudnnSupport::DoCtcLossImpl(
2031     Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
2032     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2033     absl::Span<const int> labels_lengths_data,
2034     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2035     const CudnnRnnStateTensorDescriptor& grads_desc,
2036     DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
2037     DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
2038   auto cudnn = cudnn_->GetHandle(parent_, stream);
2039 
2040   int kNumTimestamps = probs_desc.num_layers();
2041   int kBatchSize = probs_desc.batch_size();
2042   int kNumLabels = probs_desc.data_size();
2043   int total_size = kNumLabels * kNumTimestamps * kBatchSize;
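  // total_size is computed only for reference; the cast below silences the
  // unused-variable warning.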
2044   (void)total_size;
2045 
2046 #if CUDNN_VERSION >= 7603
2047   cudnnCTCLossAlgo_t ctc_loss_algo =
2048       static_cast<cudnnCTCLossAlgo_t>(ctc_loss_algo_id);
2049   RETURN_IF_CUDNN_ERROR(cudnnCTCLoss(
2050       /*handle=*/cudnn.handle(), /*probsDesc=*/probs_desc.handle(),
2051       /*probs=*/probs_data.opaque(), /*labels=*/labels_data.data(),
2052       /*labelLengths=*/labels_lengths_data.data(),
2053       /*inputLengths=*/input_lengths_data.data(),
2054       /*costs=*/costs_data.opaque(), /*gradientsDesc=*/grads_desc.handle(),
2055       /*gradients=*/grads_data.opaque(),
2056       /*algo=*/ctc_loss_algo,
2057       /*ctcLossDesc=*/ctc_loss_desc.handle(),
2058       /*workspace=*/scratch_memory.opaque(),
2059       /*workSpaceSizeInBytes=*/scratch_memory.size()));
2060 #else
2061   return port::Status(port::error::INVALID_ARGUMENT,
2062                       "cudnnCTCLoss is not supported when "
2063                       "CUDNN_VERSION < 7.6.3");
2064 #endif
2065 
2066   return port::Status::OK();
2067 }
2068 
2069 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2070 CudnnSupport::createRnnDescriptor(
2071     int num_layers, int hidden_size, int input_size, int cell_size,
2072     int batch_size, dnn::RnnInputMode input_mode,
2073     dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
2074     dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
2075     float dropout, uint64 seed, ScratchAllocator* state_allocator,
2076     bool use_padded_io) {
2077   // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
2078   // not enqueueing anything into a stream, we pass in the null stream.
2079   auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
2080   SE_ASSIGN_OR_RETURN(
2081       CudnnRnnDescriptor rnn_desc,
2082       CudnnRnnDescriptor::Create(
2083           cudnn, num_layers, hidden_size, input_size, cell_size, batch_size,
2084           ToCudnnRnnInputMode(input_mode),
2085           ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
2086           ToCudnnDataType(data_type), GetRnnComputeType(data_type),
2087           algorithm_config, dropout, seed, state_allocator, use_padded_io));
2088   return std::unique_ptr<dnn::RnnDescriptor>(
2089       new CudnnRnnDescriptor(std::move(rnn_desc)));
2090 }
2091 
2092 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2093 CudnnSupport::createRnnSequenceTensorDescriptor(int max_seq_length,
2094                                                 int batch_size, int data_size,
2095                                                 dnn::DataType data_type) {
2096   SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2097                       CudnnRnnSequenceTensorDescriptor::Create(
2098                           parent_, max_seq_length, batch_size, data_size,
2099                           ToCudnnDataType(data_type)));
2100   return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2101       new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2102 }
2103 
2104 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2105 CudnnSupport::createRnnSequenceTensorDescriptor(
2106     int max_seq_length, int batch_size, int data_size,
2107     const absl::Span<const int>& seq_lengths, bool time_major,
2108     dnn::DataType data_type) {
2109   SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
2110                       CudnnRnnSequenceTensorDescriptor::Create(
2111                           parent_, max_seq_length, batch_size, data_size,
2112                           seq_lengths, time_major, ToCudnnDataType(data_type)));
2113   return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
2114       new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
2115 }
2116 
2117 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2118 CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
2119                                              int data_size,
2120                                              dnn::DataType data_type) {
2121   return std::unique_ptr<dnn::RnnStateTensorDescriptor>(
2122       new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
2123                                         data_size, ToCudnnDataType(data_type)));
2124 }
2125 
2126 bool CudnnSupport::DoRnnForward(
2127     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2128     const dnn::RnnSequenceTensorDescriptor& input_desc,
2129     const DeviceMemory<Eigen::half>& input_data,
2130     const dnn::RnnStateTensorDescriptor& input_h_desc,
2131     const DeviceMemory<Eigen::half>& input_h_data,
2132     const dnn::RnnStateTensorDescriptor& input_c_desc,
2133     const DeviceMemory<Eigen::half>& input_c_data,
2134     const DeviceMemory<Eigen::half>& params,
2135     const dnn::RnnSequenceTensorDescriptor& output_desc,
2136     DeviceMemory<Eigen::half>* output_data,
2137     const dnn::RnnStateTensorDescriptor& output_h_desc,
2138     DeviceMemory<Eigen::half>* output_h_data,
2139     const dnn::RnnStateTensorDescriptor& output_c_desc,
2140     DeviceMemory<Eigen::half>* output_c_data, bool is_training,
2141     ScratchAllocator* reserve_space_allocator,
2142     ScratchAllocator* workspace_allocator,
2143     dnn::ProfileResult* output_profile_result) {
2144   const CudnnRnnDescriptor& cudnn_rnn_desc =
2145       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2146   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2147       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2148   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2149       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2150   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2151       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2152   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2153       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2154   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2155       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2156   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2157       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2158   return IsStatusOk(
2159       DoRnnForwardImpl<Eigen::half>(
2160           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2161           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2162           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2163           output_h_data, cudnn_output_c_desc, output_c_data, is_training,
2164           reserve_space_allocator, workspace_allocator, output_profile_result),
2165       /*report_error=*/!output_profile_result);
2166 }
2167 
2168 bool CudnnSupport::DoRnnForward(
2169     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2170     const dnn::RnnSequenceTensorDescriptor& input_desc,
2171     const DeviceMemory<float>& input_data,
2172     const dnn::RnnStateTensorDescriptor& input_h_desc,
2173     const DeviceMemory<float>& input_h_data,
2174     const dnn::RnnStateTensorDescriptor& input_c_desc,
2175     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2176     const dnn::RnnSequenceTensorDescriptor& output_desc,
2177     DeviceMemory<float>* output_data,
2178     const dnn::RnnStateTensorDescriptor& output_h_desc,
2179     DeviceMemory<float>* output_h_data,
2180     const dnn::RnnStateTensorDescriptor& output_c_desc,
2181     DeviceMemory<float>* output_c_data, bool is_training,
2182     ScratchAllocator* reserve_space_allocator,
2183     ScratchAllocator* workspace_allocator,
2184     dnn::ProfileResult* output_profile_result) {
2185   const CudnnRnnDescriptor& cudnn_rnn_desc =
2186       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2187   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2188       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2189   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2190       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2191   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2192       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2193   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2194       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2195   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2196       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2197   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2198       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2199   return IsStatusOk(
2200       DoRnnForwardImpl<float>(
2201           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2202           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2203           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2204           output_h_data, cudnn_output_c_desc, output_c_data, is_training,
2205           reserve_space_allocator, workspace_allocator, output_profile_result),
2206       /*report_error=*/!output_profile_result);
2207 }
2208 
2209 bool CudnnSupport::DoRnnForward(
2210     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2211     const dnn::RnnSequenceTensorDescriptor& input_desc,
2212     const DeviceMemory<double>& input_data,
2213     const dnn::RnnStateTensorDescriptor& input_h_desc,
2214     const DeviceMemory<double>& input_h_data,
2215     const dnn::RnnStateTensorDescriptor& input_c_desc,
2216     const DeviceMemory<double>& input_c_data,
2217     const DeviceMemory<double>& params,
2218     const dnn::RnnSequenceTensorDescriptor& output_desc,
2219     DeviceMemory<double>* output_data,
2220     const dnn::RnnStateTensorDescriptor& output_h_desc,
2221     DeviceMemory<double>* output_h_data,
2222     const dnn::RnnStateTensorDescriptor& output_c_desc,
2223     DeviceMemory<double>* output_c_data, bool is_training,
2224     ScratchAllocator* reserve_space_allocator,
2225     ScratchAllocator* workspace_allocator,
2226     dnn::ProfileResult* output_profile_result) {
2227   const CudnnRnnDescriptor& cudnn_rnn_desc =
2228       static_cast<const CudnnRnnDescriptor&>(rnn_desc);
2229   const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
2230       static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
2231   const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
2232       static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
2233   const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
2234       static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
2235   const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
2236       static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
2237   const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
2238       static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
2239   const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
2240       static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
2241   return IsStatusOk(
2242       DoRnnForwardImpl<double>(
2243           stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
2244           cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
2245           params, cudnn_output_desc, output_data, cudnn_output_h_desc,
2246           output_h_data, cudnn_output_c_desc, output_c_data, is_training,
2247           reserve_space_allocator, workspace_allocator, output_profile_result),
2248       /*report_error=*/!output_profile_result);
2249 }
2250 
bool CudnnSupport::DoRnnBackward(
    Stream* stream, const dnn::RnnDescriptor& rnn_desc,
    const dnn::RnnSequenceTensorDescriptor& input_desc,
    const DeviceMemory<Eigen::half>& input_data,
    const dnn::RnnStateTensorDescriptor& input_h_desc,
    const DeviceMemory<Eigen::half>& input_h_data,
    const dnn::RnnStateTensorDescriptor& input_c_desc,
    const DeviceMemory<Eigen::half>& input_c_data,
    const DeviceMemory<Eigen::half>& params,
    const dnn::RnnSequenceTensorDescriptor& output_desc,
    const DeviceMemory<Eigen::half>& output_data,
    const dnn::RnnStateTensorDescriptor& output_h_desc,
    const DeviceMemory<Eigen::half>& output_h_data,
    const dnn::RnnStateTensorDescriptor& output_c_desc,
    const DeviceMemory<Eigen::half>& output_c_data,
    const DeviceMemory<Eigen::half>& output_backprop_data,
    const DeviceMemory<Eigen::half>& output_h_backprop_data,
    const DeviceMemory<Eigen::half>& output_c_backprop_data,
    DeviceMemory<Eigen::half>* input_backprop_data,
    DeviceMemory<Eigen::half>* input_h_backprop_data,
    DeviceMemory<Eigen::half>* input_c_backprop_data,
    DeviceMemory<Eigen::half>* params_backprop_data,
    DeviceMemory<uint8>* reserve_space_data,
    ScratchAllocator* workspace_allocator,
    dnn::ProfileResult* output_profile_result) {
  const CudnnRnnDescriptor& cudnn_rnn_desc =
      static_cast<const CudnnRnnDescriptor&>(rnn_desc);
  const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
      static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
  const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
      static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
  const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
      static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
  const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
      static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
  const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
      static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
  const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
      static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
  return IsStatusOk(
      DoRnnBackwardImpl<Eigen::half>(
          stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
          cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
          params, cudnn_output_desc, output_data, cudnn_output_h_desc,
          output_h_data, cudnn_output_c_desc, output_c_data,
          output_backprop_data, output_h_backprop_data, output_c_backprop_data,
          input_backprop_data, input_h_backprop_data, input_c_backprop_data,
          params_backprop_data, reserve_space_data, workspace_allocator,
          output_profile_result),
      /*report_error=*/!output_profile_result);
}

bool CudnnSupport::DoRnnBackward(
    Stream* stream, const dnn::RnnDescriptor& rnn_desc,
    const dnn::RnnSequenceTensorDescriptor& input_desc,
    const DeviceMemory<float>& input_data,
    const dnn::RnnStateTensorDescriptor& input_h_desc,
    const DeviceMemory<float>& input_h_data,
    const dnn::RnnStateTensorDescriptor& input_c_desc,
    const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
    const dnn::RnnSequenceTensorDescriptor& output_desc,
    const DeviceMemory<float>& output_data,
    const dnn::RnnStateTensorDescriptor& output_h_desc,
    const DeviceMemory<float>& output_h_data,
    const dnn::RnnStateTensorDescriptor& output_c_desc,
    const DeviceMemory<float>& output_c_data,
    const DeviceMemory<float>& output_backprop_data,
    const DeviceMemory<float>& output_h_backprop_data,
    const DeviceMemory<float>& output_c_backprop_data,
    DeviceMemory<float>* input_backprop_data,
    DeviceMemory<float>* input_h_backprop_data,
    DeviceMemory<float>* input_c_backprop_data,
    DeviceMemory<float>* params_backprop_data,
    DeviceMemory<uint8>* reserve_space_data,
    ScratchAllocator* workspace_allocator,
    dnn::ProfileResult* output_profile_result) {
  const CudnnRnnDescriptor& cudnn_rnn_desc =
      static_cast<const CudnnRnnDescriptor&>(rnn_desc);
  const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
      static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
  const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
      static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
  const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
      static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
  const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
      static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
  const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
      static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
  const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
      static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
  return IsStatusOk(
      DoRnnBackwardImpl<float>(
          stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
          cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
          params, cudnn_output_desc, output_data, cudnn_output_h_desc,
          output_h_data, cudnn_output_c_desc, output_c_data,
          output_backprop_data, output_h_backprop_data, output_c_backprop_data,
          input_backprop_data, input_h_backprop_data, input_c_backprop_data,
          params_backprop_data, reserve_space_data, workspace_allocator,
          output_profile_result),
      /*report_error=*/!output_profile_result);
}

bool CudnnSupport::DoRnnBackward(
    Stream* stream, const dnn::RnnDescriptor& rnn_desc,
    const dnn::RnnSequenceTensorDescriptor& input_desc,
    const DeviceMemory<double>& input_data,
    const dnn::RnnStateTensorDescriptor& input_h_desc,
    const DeviceMemory<double>& input_h_data,
    const dnn::RnnStateTensorDescriptor& input_c_desc,
    const DeviceMemory<double>& input_c_data,
    const DeviceMemory<double>& params,
    const dnn::RnnSequenceTensorDescriptor& output_desc,
    const DeviceMemory<double>& output_data,
    const dnn::RnnStateTensorDescriptor& output_h_desc,
    const DeviceMemory<double>& output_h_data,
    const dnn::RnnStateTensorDescriptor& output_c_desc,
    const DeviceMemory<double>& output_c_data,
    const DeviceMemory<double>& output_backprop_data,
    const DeviceMemory<double>& output_h_backprop_data,
    const DeviceMemory<double>& output_c_backprop_data,
    DeviceMemory<double>* input_backprop_data,
    DeviceMemory<double>* input_h_backprop_data,
    DeviceMemory<double>* input_c_backprop_data,
    DeviceMemory<double>* params_backprop_data,
    DeviceMemory<uint8>* reserve_space_data,
    ScratchAllocator* workspace_allocator,
    dnn::ProfileResult* output_profile_result) {
  const CudnnRnnDescriptor& cudnn_rnn_desc =
      static_cast<const CudnnRnnDescriptor&>(rnn_desc);
  const CudnnRnnSequenceTensorDescriptor& cudnn_input_desc =
      static_cast<const CudnnRnnSequenceTensorDescriptor&>(input_desc);
  const CudnnRnnStateTensorDescriptor& cudnn_input_h_desc =
      static_cast<const CudnnRnnStateTensorDescriptor&>(input_h_desc);
  const CudnnRnnStateTensorDescriptor& cudnn_input_c_desc =
      static_cast<const CudnnRnnStateTensorDescriptor&>(input_c_desc);
  const CudnnRnnSequenceTensorDescriptor& cudnn_output_desc =
      static_cast<const CudnnRnnSequenceTensorDescriptor&>(output_desc);
  const CudnnRnnStateTensorDescriptor& cudnn_output_h_desc =
      static_cast<const CudnnRnnStateTensorDescriptor&>(output_h_desc);
  const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
      static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
  return IsStatusOk(
      DoRnnBackwardImpl<double>(
          stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
          cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
          params, cudnn_output_desc, output_data, cudnn_output_h_desc,
          output_h_data, cudnn_output_c_desc, output_c_data,
          output_backprop_data, output_h_backprop_data, output_c_backprop_data,
          input_backprop_data, input_h_backprop_data, input_c_backprop_data,
          params_backprop_data, reserve_space_data, workspace_allocator,
          output_profile_result),
      /*report_error=*/!output_profile_result);
}

namespace {

// TODO(csigg): Merge a lot of duplicate code below for forward, backward data,
// and backward filter.

port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
    const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
    const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
    const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
    size_t memory_limit_bytes) {
#if CUDNN_VERSION >= 8000
  const int num_requested_algos = 5;
  int num_returned_algos = 0;
  cudnnConvolutionFwdAlgoPerf_t perf_results[num_requested_algos];

  RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
      cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
      output_nd.handle(), num_requested_algos, &num_returned_algos,
      perf_results));

  size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
  for (int r = 0; r < num_returned_algos; r++) {
    if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
        perf_results[r].algo != CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
        perf_results[r].memory <= mem_limit) {
      return perf_results[r].algo;
    }
  }
  return port::Status(port::error::INTERNAL,
                      "cudnnGetConvolutionForwardAlgorithm_v7 returned "
                      "no suitable algorithms. This could be a cudnn bug.");
#else
  cudnnConvolutionFwdPreference_t preference =
      specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
                              : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
  cudnnConvolutionFwdAlgo_t algo_to_use;
  RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm(
      cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
      output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
  return algo_to_use;
#endif
}
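
// On cuDNN 8+, the *_v7 query above returns a heuristically ranked list of
// candidates; the helper picks the first one that reported success, is not
// WINOGRAD_NONFUSED, and fits within the workspace limit. Note that when
// specify_workspace_limit is false the limit is 0, so only algorithms that
// need no scratch space can be selected. The backward-data and
// backward-filter variants below use the same selection logic.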

port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
                                    const CudnnTensorDescriptor& input_nd,
                                    const CudnnFilterDescriptor& filter,
                                    const CudnnConvolutionDescriptor& conv,
                                    const CudnnTensorDescriptor& output_nd,
                                    bool specify_workspace_limit,
                                    size_t memory_limit_bytes) {
#if CUDNN_VERSION >= 8000
  const int num_requested_algos = 5;
  int num_returned_algos = 0;
  cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_requested_algos];

  RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(
      cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
      input_nd.handle(), num_requested_algos, &num_returned_algos,
      perf_results));

  size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
  for (int r = 0; r < num_returned_algos; r++) {
    if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
        perf_results[r].algo !=
            CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED &&
        perf_results[r].memory <= mem_limit) {
      return perf_results[r].algo;
    }
  }
  return port::Status(port::error::INTERNAL,
                      "cudnnGetConvolutionBackwardDataAlgorithm_v7 returned "
                      "no suitable algorithms. This could be a cudnn bug.");
#else
  cudnnConvolutionBwdDataPreference_t preference =
      specify_workspace_limit
          ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
          : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
  cudnnConvolutionBwdDataAlgo_t algo_to_use;
  RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm(
      cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
      input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
  return algo_to_use;
#endif
}

port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
                                      const CudnnTensorDescriptor& input_nd,
                                      const CudnnFilterDescriptor& filter,
                                      const CudnnConvolutionDescriptor& conv,
                                      const CudnnTensorDescriptor& output_nd,
                                      bool specify_workspace_limit,
                                      size_t memory_limit_bytes) {
#if CUDNN_VERSION >= 8000
  const int num_requested_algos = 5;
  int num_returned_algos = 0;
  cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_requested_algos];

  RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
      cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
      filter.handle(), num_requested_algos, &num_returned_algos, perf_results));

  size_t mem_limit = specify_workspace_limit ? memory_limit_bytes : 0ULL;
  for (int r = 0; r < num_returned_algos; r++) {
    if (perf_results[r].status == CUDNN_STATUS_SUCCESS &&
        perf_results[r].algo !=
            CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED &&
        perf_results[r].memory <= mem_limit) {
      return perf_results[r].algo;
    }
  }
  return port::Status(port::error::INTERNAL,
                      "cudnnGetConvolutionBackwardFilterAlgorithm_v7 returned "
                      "no suitable algorithms. This could be a cudnn bug.");
#else
  cudnnConvolutionBwdFilterPreference_t preference =
      specify_workspace_limit
          ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
          : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
  cudnnConvolutionBwdFilterAlgo_t algo_to_use;
  RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm(
      cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
      filter.handle(), preference, memory_limit_bytes, &algo_to_use));
  return algo_to_use;
#endif
}

port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
    Stream* stream, const CudnnHandle& cudnn,
    const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
    const CudnnConvolutionDescriptor& conv,
    const CudnnTensorDescriptor& output_nd,
    const dnn::AlgorithmDesc& algorithm_desc,
    ScratchAllocator* scratch_allocator) {
  if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
    return port::Status(
        port::error::INTERNAL,
        "Mismatch between cudnn conv and algorithm descriptors.");
  }

  // Query the size of the workspace and allocate it.
  size_t size_in_bytes;
  RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
      cudnn.handle(),
      /*xDesc=*/input_nd.handle(),
      /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
      /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
      /*sizeInBytes=*/&size_in_bytes));

  int64 size_in_bytes_int64 = size_in_bytes;

  if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
    return port::Status(
        port::error::INTERNAL,
        "cudnnGetConvolutionForwardWorkspaceSize() returned "
        "negative sizeInBytes value. This could be a cudnn bug.");
  }

  if (size_in_bytes_int64 == 0) {
    return DeviceMemory<uint8>();
  }

  if (TF_PREDICT_FALSE(!scratch_allocator)) {
    return port::Status(port::error::INVALID_ARGUMENT,
                        "No scratch allocator provided");
  }

  return scratch_allocator->AllocateBytes(size_in_bytes);
}

port::StatusOr<DeviceMemory<uint8>>
AllocateCudnnConvolutionBackwardDataWorkspace(
    Stream* stream, const CudnnHandle& cudnn,
    const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
    const CudnnConvolutionDescriptor& conv,
    const CudnnTensorDescriptor& output_nd,
    const dnn::AlgorithmDesc& algorithm_desc,
    ScratchAllocator* scratch_allocator) {
  if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
    return port::Status(
        port::error::INTERNAL,
        "Mismatch between cudnn conv and algorithm descriptors.");
  }

  // Query the size of the workspace and allocate it.
  size_t size_in_bytes;
  RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
      cudnn.handle(),
      /*wDesc=*/filter.handle(),
      /*dyDesc=*/output_nd.handle(),
      /*convDesc=*/conv.handle(),
      /*dxDesc=*/input_nd.handle(),
      /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
      /*sizeInBytes=*/&size_in_bytes));

  int64 size_in_bytes_int64 = size_in_bytes;

  if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
    return port::Status(
        port::error::INTERNAL,
        "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
        "negative sizeInBytes value. This could be a cudnn bug.");
  }

  if (size_in_bytes_int64 == 0) {
    return DeviceMemory<uint8>();
  }

  if (TF_PREDICT_FALSE(!scratch_allocator)) {
    return port::Status(port::error::INVALID_ARGUMENT,
                        "No scratch allocator provided");
  }

  return scratch_allocator->AllocateBytes(size_in_bytes);
}

port::StatusOr<DeviceMemory<uint8>>
AllocateCudnnConvolutionBackwardFilterWorkspace(
    Stream* stream, const CudnnHandle& cudnn,
    const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
    const CudnnConvolutionDescriptor& conv,
    const CudnnTensorDescriptor& output_nd,
    const dnn::AlgorithmDesc& algorithm_desc,
    ScratchAllocator* scratch_allocator) {
  if (IsTensorMathOpSet(conv) != algorithm_desc.tensor_ops_enabled()) {
    return port::Status(
        port::error::INTERNAL,
        "Mismatch between cudnn conv and algorithm descriptors.");
  }

  // Query the size of the workspace and allocate it.
  size_t size_in_bytes;
  RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
      cudnn.handle(),
      /*xDesc=*/input_nd.handle(),
      /*dyDesc=*/output_nd.handle(),
      /*convDesc=*/conv.handle(),
      /*gradDesc=*/filter.handle(),
      /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
      /*sizeInBytes=*/&size_in_bytes));

  int64 size_in_bytes_int64 = size_in_bytes;

  if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
    return port::Status(
        port::error::INTERNAL,
        "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
        "negative sizeInBytes value. This could be a cudnn bug.");
  }

  if (size_in_bytes_int64 == 0) {
    return DeviceMemory<uint8>();
  }

  if (TF_PREDICT_FALSE(!scratch_allocator)) {
    return port::Status(port::error::INVALID_ARGUMENT,
                        "No scratch allocator provided");
  }

  return scratch_allocator->AllocateBytes(size_in_bytes);
}

port::StatusOr<bool> UseTensorOps(Stream* stream, dnn::DataType type,
                                  absl::optional<dnn::AlgorithmDesc> desc) {
  bool use_tensor_ops;
  if (desc.has_value()) {
    use_tensor_ops = desc->tensor_ops_enabled();
    if (use_tensor_ops && !IsTensorMathEnabled(stream, type)) {
      return port::Status(port::error::INVALID_ARGUMENT,
                          "Algo requests disabled tensor op evaluation.");
    }
  } else {
    use_tensor_ops = IsTensorMathEnabled(stream, type);
  }
  return use_tensor_ops;
}
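
// In short: if the caller pinned a specific algorithm, its tensor-op flag is
// honored, but the call fails fast when that flag requests tensor ops while
// tensor-op math is disabled for this stream and data type; otherwise the
// decision defers entirely to IsTensorMathEnabled().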

cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
dnn::DataType GetConvAccumulatorType(dnn::DataType data_type);

port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
    Stream* stream, const CudnnHandle& cudnn,
    const dnn::AlgorithmConfig& algorithm_config,
    const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
    dnn::DataType element_type,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
    DeviceMemory<uint8>* scratch) {
  absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();

  CudnnConvolutionDescriptor conv(
      convolution_descriptor,
      ToCudnnDataType(GetConvAccumulatorType(element_type)));
  bool use_tensor_ops;
  SE_ASSIGN_OR_RETURN(use_tensor_ops,
                      UseTensorOps(stream, element_type, algo_desc));
  conv.set_use_tensor_op_math(use_tensor_ops);

  if (!algo_desc.has_value()) {
    // Pick fastest algorithm within memory limit according to cuDNN's
    // heuristics.
    bool specify_workspace_limit = scratch_allocator != nullptr;
    auto memory_limit_bytes =
        specify_workspace_limit
            ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
            : int64{0};
    SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
                        GetCudnnConvolutionForwardAlgo(
                            cudnn, input_nd, filter, conv, output_nd,
                            specify_workspace_limit, memory_limit_bytes));
    algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
  }

  const auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
      stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
      scratch_allocator);

  if (scratch_or.ok()) {
    *scratch = scratch_or.ValueOrDie();
    return *algo_desc;
  }

  algo_desc = algorithm_config.algorithm_no_scratch();

  // Failed to allocate workspace for the first algorithm, fall back to the
  // no_scratch algorithm.
  if (!algo_desc.has_value()) {
    return port::Status(
        scratch_or.status().code(),
        absl::StrCat("The primary convolution algorithm failed, ",
                     "while a secondary algorithm is not provided. ",
                     "Returned status: ", scratch_or.status().ToString()));
  }

  SE_ASSIGN_OR_RETURN(use_tensor_ops,
                      UseTensorOps(stream, element_type, algo_desc));
  conv.set_use_tensor_op_math(use_tensor_ops);
  SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionForwardWorkspace(
                                    stream, cudnn, input_nd, filter, conv,
                                    output_nd, *algo_desc, scratch_allocator));
  return *algo_desc;
}

port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
    Stream* stream, const CudnnHandle& cudnn,
    const dnn::AlgorithmConfig& algorithm_config,
    const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
    dnn::DataType element_type,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
    DeviceMemory<uint8>* scratch) {
  absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
  CudnnConvolutionDescriptor conv(
      convolution_descriptor,
      ToCudnnDataType(GetConvAccumulatorType(element_type)));
  bool use_tensor_ops;
  SE_ASSIGN_OR_RETURN(use_tensor_ops,
                      UseTensorOps(stream, element_type, algo_desc));
  conv.set_use_tensor_op_math(use_tensor_ops);

  if (!algo_desc.has_value()) {
    // Pick fastest algorithm within memory limit according to cuDNN's
    // heuristics.
    bool specify_workspace_limit = scratch_allocator != nullptr;
    auto memory_limit_bytes =
        specify_workspace_limit
            ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
            : int64{0};
    SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
                        GetCudnnConvolutionBackwardDataAlgo(
                            cudnn, input_nd, filter, conv, output_nd,
                            specify_workspace_limit, memory_limit_bytes));
    algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
  }

  const auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
      stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
      scratch_allocator);

  if (scratch_or.ok()) {
    *scratch = scratch_or.ValueOrDie();
    return *algo_desc;
  }

  algo_desc = algorithm_config.algorithm_no_scratch();

  // Failed to allocate workspace for the first algorithm, fall back to the
  // no_scratch algorithm.
  if (!algo_desc.has_value()) {
    return port::Status(
        port::error::INVALID_ARGUMENT,
        "The primary convolution algorithm failed memory allocation, "
        "while a secondary algorithm is not provided.");
  }

  SE_ASSIGN_OR_RETURN(use_tensor_ops,
                      UseTensorOps(stream, element_type, algo_desc));
  conv.set_use_tensor_op_math(use_tensor_ops);
  SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
                                    stream, cudnn, input_nd, filter, conv,
                                    output_nd, *algo_desc, scratch_allocator));
  return *algo_desc;
}

port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
    Stream* stream, const CudnnHandle& cudnn,
    const dnn::AlgorithmConfig& algorithm_config,
    const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
    dnn::DataType element_type,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
    DeviceMemory<uint8>* scratch) {
  absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
  CudnnConvolutionDescriptor conv(
      convolution_descriptor,
      ToCudnnDataType(GetConvAccumulatorType(element_type)));
  bool use_tensor_ops;
  SE_ASSIGN_OR_RETURN(use_tensor_ops,
                      UseTensorOps(stream, element_type, algo_desc));
  conv.set_use_tensor_op_math(use_tensor_ops);

  if (!algo_desc.has_value()) {
    // Pick fastest algorithm within memory limit according to cuDNN's
    // heuristics.
    bool specify_workspace_limit = scratch_allocator != nullptr;
    auto memory_limit_bytes =
        specify_workspace_limit
            ? std::max(scratch_allocator->GetMemoryLimitInBytes(), int64{0})
            : int64{0};
    SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
                        GetCudnnConvolutionBackwardFilterAlgo(
                            cudnn, input_nd, filter, conv, output_nd,
                            specify_workspace_limit, memory_limit_bytes));
    algo_desc = dnn::AlgorithmDesc(algo, use_tensor_ops);
  }

  auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace(
      stream, cudnn, input_nd, filter, conv, output_nd, *algo_desc,
      scratch_allocator);

  if (scratch_or.ok()) {
    *scratch = scratch_or.ValueOrDie();
    return *algo_desc;
  }

  algo_desc = algorithm_config.algorithm_no_scratch();

  // Failed to allocate workspace for the first algorithm, fall back to the
  // no_scratch algorithm.
  if (!algo_desc.has_value()) {
    return port::Status(
        port::error::INVALID_ARGUMENT,
        "The primary convolution algorithm failed memory allocation, "
        "while a secondary algorithm is not provided.");
  }

  SE_ASSIGN_OR_RETURN(use_tensor_ops,
                      UseTensorOps(stream, element_type, algo_desc));
  conv.set_use_tensor_op_math(use_tensor_ops);
  SE_ASSIGN_OR_RETURN(*scratch, AllocateCudnnConvolutionBackwardFilterWorkspace(
                                    stream, cudnn, input_nd, filter, conv,
                                    output_nd, *algo_desc, scratch_allocator));
  return *algo_desc;
}
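
// All three Get*Algorithm() helpers above share the same two-stage scheme:
// try the configured (or heuristically chosen) primary algorithm, and if its
// workspace cannot be allocated, retry with algorithm_no_scratch() from the
// dnn::AlgorithmConfig. This duplication is what the TODO at the top of this
// namespace refers to.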

// A helper class that reads an env-var (defined by the EnvVar policy struct)
// to choose options for cudnn-related algorithms.
template <typename EnvVar>
class CudnnEnvVar {
 public:
  static bool IsEnabled() {
    static bool is_enabled = IsEnabledImpl();
    return is_enabled;
  }

 private:
  static bool IsEnabledImpl() {
    const char* tf_env_var_val = getenv(EnvVar::kName);
    if (tf_env_var_val != nullptr) {
      absl::string_view tf_env_var_val_str(tf_env_var_val);
      if (tf_env_var_val_str == "0") {
        return false;
      }
      return true;
    }
    return EnvVar::kDefaultFlag;
  }
};
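
// Usage sketch: a policy struct supplies the env-var name and the default,
// and the cached result is queried as, e.g.,
//   CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
// Any value other than "0" in the environment enables the flag; an unset
// variable falls back to kDefaultFlag. The policy structs follow below.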

// A helper struct to decide whether to enable the FFT_TILING algorithms for
// forward convolution. It is disabled for cuDNN < 7 due to memory corruption
// caused by some shapes with this algorithm. Users can explicitly enable the
// algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
struct FftTilingForward {
  static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
  static constexpr bool kDefaultFlag = true;
};

// A helper struct to decide whether to enable the WINOGRAD_NONFUSED
// algorithms. By default they are turned on; users can explicitly disable
// them through the env-var "TF_ENABLE_WINOGRAD_NONFUSED=0".
// https://github.com/tensorflow/tensorflow/pull/4901
struct WinogradNonfused {
  static constexpr const char* kName = "TF_ENABLE_WINOGRAD_NONFUSED";
  // NVIDIA has fixed the winograd nonfused bug for cudnn v>=7. For older
  // versions, we have a workaround.
  static constexpr bool kDefaultFlag = true;
};

// A helper struct to decide whether to use FP32 as the internal compute type
// for convolution when the input data type is FP16. By default it is turned
// on; users can explicitly disable it (i.e. choose FP16 as the internal
// compute type) through the env-var "TF_FP16_CONV_USE_FP32_COMPUTE=0".
struct ConvDoFP32ComputationFP16Input {
  static constexpr const char* kName = "TF_FP16_CONV_USE_FP32_COMPUTE";
  // Using FP16 as the internal compute type for convolution when the input
  // data type is FP16 is only supported on architectures with true fp16
  // support (compute capability 5.3 and 6.0). Setting this to false on an
  // unsupported architecture will cause internal errors.
  static constexpr bool kDefaultFlag = true;
};

// A helper struct to decide whether to use FP32 as the internal compute type
// for RNNs when the input data type is FP16. It is off by default for cuDNN
// older than 7.5.0; users can explicitly control it through the env-var
// TF_FP16_RNN_USE_FP32_COMPUTE.
// After the TODO below is fixed, users should almost always use fp32 compute
// type for training. Using fp16 might suffer suboptimal accuracy due to loss
// in precision.
struct RnnDoFP32ComputationFP16Input {
  static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
  // TODO(jamesqin): b/78182362 flip to true when cudnn 7.1.4 fixes the bug.
  // Before cudnn 7.1.4, RNNs are always done in fp32 no matter what math
  // precision is set.
  // Set it temporarily to false so that no error is raised when using fp16
  // inputs with fp32 math precision.
  //
  // cuDNN == 7.5.0 is verified to have this fixed.
  static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7500;
};

cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
  switch (data_type) {
    case dnn::DataType::kFloat:
      return CUDNN_DATA_FLOAT;
    case dnn::DataType::kDouble:
      return CUDNN_DATA_DOUBLE;
    case dnn::DataType::kHalf:
      if (CudnnEnvVar<RnnDoFP32ComputationFP16Input>::IsEnabled()) {
        return CUDNN_DATA_FLOAT;
      } else {
        return CUDNN_DATA_HALF;
      }
    default:
      LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type);
  }
}
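
// For example, with the default flags on cuDNN >= 7.5.0, kHalf RNN inputs
// compute in CUDNN_DATA_FLOAT; setting TF_FP16_RNN_USE_FP32_COMPUTE=0 keeps
// the compute type at CUDNN_DATA_HALF.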

dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) {
  switch (data_type) {
    case dnn::DataType::kFloat:
    case dnn::DataType::kDouble:
      return data_type;
    case dnn::DataType::kHalf:
      return CudnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled()
                 ? dnn::DataType::kFloat
                 : dnn::DataType::kHalf;
    case dnn::DataType::kInt8:
    case dnn::DataType::kInt32:
      return dnn::DataType::kInt32;
    default:
      LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
  }
}
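
// For example, by default (TF_FP16_CONV_USE_FP32_COMPUTE unset), kHalf
// convolutions accumulate in kFloat, while kInt8 and kInt32 inputs always
// accumulate in kInt32.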
}  // namespace

port::Status CudnnSupport::DoPrepareForConvolution(
    dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
    const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
    const dnn::FilterDescriptor& filter_descriptor,
    DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
    DeviceMemoryBase output_data,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    const dnn::AlgorithmConfig& algorithm_config,
    ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
    DeviceMemory<uint8>* scratch_memory) {
  CudnnTensorDescriptor input_nd(
      input_descriptor,
      ToCudnnDataType(element_type, input_descriptor.layout()));
  CudnnFilterDescriptor filter_nd(
      filter_descriptor,
      ToCudnnDataType(element_type, filter_descriptor.layout()));
  CudnnTensorDescriptor output_nd(
      output_descriptor,
      ToCudnnDataType(element_type, output_descriptor.layout()));

  auto cudnn = cudnn_->GetHandle(parent_, stream);

  switch (kind) {
    case dnn::ConvolutionKind::FORWARD: {
      SE_ASSIGN_OR_RETURN(*algorithm_desc,
                          GetCudnnConvolutionForwardAlgorithm(
                              stream, cudnn, algorithm_config, input_nd,
                              filter_nd, element_type, convolution_descriptor,
                              output_nd, scratch_allocator, scratch_memory));
      break;
    }
    case dnn::ConvolutionKind::BACKWARD_DATA: {
      SE_ASSIGN_OR_RETURN(*algorithm_desc,
                          GetCudnnConvolutionBackwardDataAlgorithm(
                              stream, cudnn, algorithm_config, input_nd,
                              filter_nd, element_type, convolution_descriptor,
                              output_nd, scratch_allocator, scratch_memory));
      break;
    }
    case dnn::ConvolutionKind::BACKWARD_FILTER: {
      SE_ASSIGN_OR_RETURN(*algorithm_desc,
                          GetCudnnConvolutionBackwardFilterAlgorithm(
                              stream, cudnn, algorithm_config, input_nd,
                              filter_nd, element_type, convolution_descriptor,
                              output_nd, scratch_allocator, scratch_memory));
      break;
    }
    default:
      return port::InternalError(
          absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
  }

  return port::Status::OK();
}
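
// DoPrepareForConvolution() and DoConvolve() below form a two-step protocol:
// the prepare step resolves the dnn::AlgorithmConfig into a concrete
// dnn::AlgorithmDesc and allocates any scratch memory that algorithm needs,
// and the convolve step then runs that algorithm with the provided scratch
// buffer.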

port::Status CudnnSupport::DoConvolve(
    dnn::ConvolutionKind kind, dnn::DataType element_type,
    dnn::DataType output_type, Stream* stream,
    const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
    const dnn::FilterDescriptor& filter_descriptor,
    DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
    DeviceMemoryBase output_data,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
    dnn::ProfileResult* output_profile_result) {
  cudnnDataType_t cudnn_type = ToCudnnDataType(element_type);
  CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
  CudnnTensorDescriptor output_nd(output_descriptor,
                                  ToCudnnDataType(output_type));
  CudnnFilterDescriptor filter_nd(filter_descriptor, cudnn_type);
  auto accumulator_type = GetConvAccumulatorType(element_type);
  CudnnConvolutionDescriptor conv(convolution_descriptor,
                                  ToCudnnDataType(accumulator_type));
  SE_ASSIGN_OR_RETURN(bool use_tensor_ops,
                      UseTensorOps(stream, element_type, algorithm_desc));
  conv.set_use_tensor_op_math(use_tensor_ops);

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  // Alpha is the scaling factor for input.
  float falpha = 1.0;
  double dalpha = 1.0;
  void* alpha = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dalpha)
                                                : static_cast<void*>(&falpha);
  // Beta is the scaling factor for output.
  float fbeta = 0.0;
  double dbeta = 0.0;
  void* beta = cudnn_type == CUDNN_DATA_DOUBLE ? static_cast<void*>(&dbeta)
                                               : static_cast<void*>(&fbeta);

  const bool is_profiling = output_profile_result != nullptr;

  std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
  if (is_profiling) {
    timer.reset(new GpuTimer(parent_));  // NOLINT
    // The start and stop of the timer should be as close to the cuDNN call as
    // possible. It is still possible for other threads to issue work onto
    // this stream, so the caller may need to take multiple profiling
    // measurements.
    if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
      return port::Status(port::error::INTERNAL, "Failed to start timer");
    }
  }

  const auto get_fwd_bugs = [&]() -> port::Status {
    if (CUDNN_VERSION < 8000) {
      if (algorithm_desc.algo_id() ==
              CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM &&
          ToCudnnDataType(element_type) == CUDNN_DATA_INT8 &&
          ToCudnnDataType(output_type) == CUDNN_DATA_FLOAT) {
        return port::Status(
            port::error::FAILED_PRECONDITION,
            "This configuration potentially produces incorrect results.");
      }
    }
    return port::Status::OK();
  };

  auto get_bwd_data_bugs = [&]() -> port::Status {
    return port::Status::OK();
  };

  const auto get_bwd_filter_bugs = [&]() -> port::Status {
    return port::Status::OK();
  };

  switch (kind) {
    case dnn::ConvolutionKind::FORWARD: {
      SE_RETURN_IF_ERROR(get_fwd_bugs());
      RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
          cudnn.handle(),
          /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
          /*srcData=*/input_data.opaque(), /*filterDesc=*/filter_nd.handle(),
          /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
          /*algo=*/ToConvForwardAlgo(algorithm_desc),
          /*workSpace=*/scratch_memory.opaque(),
          /*workSpaceSizeInBytes=*/scratch_memory.size(), /*beta=*/beta,
          /*yDesc=*/output_nd.handle(), /*y=*/output_data.opaque()));
      break;
    }
    case dnn::ConvolutionKind::BACKWARD_DATA: {
      SE_RETURN_IF_ERROR(get_bwd_data_bugs());
      RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardData(
          cudnn.handle(),
          /*alpha=*/alpha,
          /*wDesc=*/filter_nd.handle(),
          /*w=*/filter_data.opaque(),
          /*dyDesc=*/output_nd.handle(),
          /*dy=*/output_data.opaque(),
          /*convDesc=*/conv.handle(),
          /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
          /*workSpace=*/scratch_memory.opaque(),
          /*workSpaceSizeInBytes=*/scratch_memory.size(),
          /*beta=*/beta,
          /*dxDesc=*/input_nd.handle(),
          /*dx=*/input_data.opaque()));
      break;
    }
    case dnn::ConvolutionKind::BACKWARD_FILTER: {
      SE_RETURN_IF_ERROR(get_bwd_filter_bugs());
      RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
          cudnn.handle(),
          /*alpha=*/alpha,
          /*srcDesc=*/input_nd.handle(),
          /*srcData=*/input_data.opaque(),
          /*diffDesc=*/output_nd.handle(),
          /*diffData=*/output_data.opaque(),
          /*convDesc=*/conv.handle(),
          /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
          /*workSpace=*/scratch_memory.opaque(),
          /*workSpaceSizeInBytes=*/scratch_memory.size(),
          /*beta=*/beta,
          /*gradDesc=*/filter_nd.handle(),
          /*dw=*/filter_data.opaque()));
      break;
    }
    default:
      return port::InternalError(
          absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
  }

  if (is_profiling) {
    if (!timer->Stop(AsGpuStream(stream))) {
      return port::Status(port::error::INTERNAL, "Failed to stop timer");
    }
    output_profile_result->set_algorithm(algorithm_desc);
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
    output_profile_result->set_scratch_size(scratch_memory.size());
  }

  return port::Status::OK();
}
3174 
3175 template <typename ElementType, typename BiasType, typename ScaleType,
3176           typename OutputType>
DoFusedConvolveImpl(Stream * stream,const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<ElementType> & conv_input_data,ScaleType conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<ElementType> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<OutputType> & side_input_data,ScaleType side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<BiasType> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<OutputType> * output_data,dnn::DataType accumulator_type,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)3177 port::Status CudnnSupport::DoFusedConvolveImpl(
3178     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3179     const DeviceMemory<ElementType>& conv_input_data,
3180     ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor,
3181     const DeviceMemory<ElementType>& filter_data,
3182     const dnn::ConvolutionDescriptor& convolution_descriptor,
3183     const DeviceMemory<OutputType>& side_input_data, ScaleType side_input_scale,
3184     const dnn::BatchDescriptor& bias_descriptor,
3185     const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
3186     const dnn::BatchDescriptor& output_descriptor,
3187     DeviceMemory<OutputType>* output_data, dnn::DataType accumulator_type,
3188     ScratchAllocator* scratch_allocator,
3189     const dnn::AlgorithmConfig& algorithm_config,
3190     dnn::ProfileResult* output_profile_result) {
3191   if (activation_mode != dnn::ActivationMode::kRelu &&
3192       activation_mode != dnn::ActivationMode::kNone) {
3193     return port::Status(port::error::INVALID_ARGUMENT,
3194                         "cudnnConvolutionBiasActivationForward() only supports "
3195                         "Relu or None activation.");
3196   }
3197 
3198   CudnnTensorDescriptor conv_input_nd(
3199       conv_input_descriptor,
3200       GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
3201   CudnnTensorDescriptor output_nd(
3202       output_descriptor,
3203       GetCudnnDataType<OutputType>(conv_input_descriptor.layout()));
3204   CudnnFilterDescriptor filter(
3205       filter_descriptor,
3206       GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
3207   CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
3208 
3209   auto cudnn = cudnn_->GetHandle(parent_, stream);
3210 
3211   const bool is_profiling = output_profile_result != nullptr;
3212 
3213   DeviceMemory<uint8> scratch;
3214   SE_ASSIGN_OR_RETURN(
3215       dnn::AlgorithmDesc algo_desc,
3216       GetCudnnConvolutionForwardAlgorithm(
3217           stream, cudnn, algorithm_config, conv_input_nd, filter,
3218           dnn::ToDataType<ElementType>::value, convolution_descriptor,
3219           output_nd, scratch_allocator, &scratch));
3220 
3221   CudnnConvolutionDescriptor conv(convolution_descriptor,
3222                                   ToCudnnDataType(accumulator_type));
3223   conv.set_use_tensor_op_math(algo_desc.tensor_ops_enabled());
3224 
3225   std::unique_ptr<GpuTimer, GpuTimerDeleter> timer;
3226   if (is_profiling) {
3227     timer.reset(new GpuTimer(parent_));  // NOLINT
3228     // The start and stop of the timer should be as close to the cuDNN call as
3229     // possible. Other threads may still issue work onto this stream, so it can
3230     // take multiple profiling measurements to obtain a clean timing.
3231     if (!timer->Init() || !timer->Start(AsGpuStream(stream))) {
3232       return port::Status(port::error::INTERNAL, "Failed to start timer");
3233     }
3234   }
3235   // cuDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for the
3236   // activation descriptor. Note that this changes the NaN propagation behavior
3237   // relative to separate conv, bias, and relu ops (which default to
3238   // CUDNN_PROPAGATE_NAN).
3239   CudnnActivationDescriptor activation_desc(
3240       activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
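  // Note: when side_input_scale == 0 there is no real side input, but cuDNN
  // still expects a valid pointer for the z argument, so the output buffer is
  // passed instead; with alpha2 == 0 it does not contribute to the result.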
3241   auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
3242                                                      : side_input_data.opaque();
3243 
3244   VLOG(2) << "\nconv_input_scale = " << conv_input_scale
3245           << "\nconv_input_nd.handle() = " << conv_input_nd.handle()
3246           << "\nconv_input_data.opaque() = " << conv_input_data.opaque()
3247           << "\nfilter.handle() = " << filter.handle()
3248           << "\nfilter_data.opaque() = " << filter_data.opaque()
3249           << "\nconv.handle() = " << conv.handle()
3250           << "\nalgo = " << algo_desc.algo_id()
3251           << "\nscratch.opaque() = " << scratch.opaque()
3252           << "\nscratch.size() = " << scratch.size()
3253           << "\nside_input_scale = " << side_input_scale
3254           << "\noutput_nd.handle() = " << output_nd.handle()
3255           << "\nside_input_data_ptr = " << side_input_data_ptr
3256           << "\nbias_nd.handle() = " << bias_nd.handle()
3257           << "\nbiases.opaque() = " << biases.opaque()
3258           << "\nactivation_desc.handle() = " << activation_desc.handle()
3259           << "\noutput_nd.handle() = " << output_nd.handle()
3260           << "\noutput_data->opaque() = " << output_data->opaque();
3261 
3262   if (IsTensorMathOpSet(conv) != algo_desc.tensor_ops_enabled()) {
3263     return port::Status(port::error::FAILED_PRECONDITION,
3264                         "Tensor op math type in dnn::AlgorithmDesc does not "
3265                         "match that of the CudnnConvolutionDescriptor");
3266   }
3267 
3268   RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward(
3269       cudnn.handle(),
3270       /*alpha1=*/&conv_input_scale,
3271       /*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
3272       /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
3273       /*convDesc=*/conv.handle(), ToConvForwardAlgo(algo_desc),
3274       /*workSpace=*/scratch.opaque(),
3275       /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale,
3276       /*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
3277       /*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
3278       /*activationDesc=*/activation_desc.handle(),
3279       /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque()));
3280 
3281   if (is_profiling) {
3282     if (!timer->Stop(AsGpuStream(stream))) {
3283       return port::Status(port::error::INTERNAL, "Failed to stop timer");
3284     }
3285     output_profile_result->set_algorithm(algo_desc);
3286     output_profile_result->set_elapsed_time_in_ms(
3287         timer->GetElapsedMilliseconds());
3288     output_profile_result->set_scratch_size(scratch.size());
3289   }
3290 
3291   return port::Status::OK();
3292 }
3293 
3294 bool CudnnSupport::GetConvolveAlgorithms(
3295     bool with_winograd_nonfused, int cc_major, int cc_minor,
3296     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3297   // Preload sub libs for cudnn 8.0.4+
3298 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
3299   cudnnOpsInferVersionCheck();
3300   cudnnCnnInferVersionCheck();
3301 #endif
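  // (The version-check calls above force cuDNN 8's lazily loaded sub-libraries
  // to be loaded now, presumably so that the load cost is not attributed to
  // the first convolution, e.g. during autotuning.)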
3302   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3303   out_algorithms->clear();
3304 
3305   std::vector<dnn::AlgorithmDesc::Index> algo_types;
3306   if (ConvUseDefaultAlgorithm()) {
3307     // Force a fallback algorithm.
3308     algo_types = {CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM};
3309   } else {
3310     algo_types = {
3311         // clang-format off
3312     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
3313     CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
3314     CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
3315     CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
3316     CUDNN_CONVOLUTION_FWD_ALGO_FFT,
3317     CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
3318         // clang-format on
3319     };
3320     if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
3321       algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
3322     }
3323     if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3324       algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
3325     }
3326   }
3327 
3328   // The algorithms are intentionally ordered for deterministic operation
3329   for (auto i : algo_types) {
3330     if (tensor_op_math_available) {
3331       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3332     }
3333     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3334   }
3335 
3336   return true;
3337 }
3338 
3339 bool CudnnSupport::GetRnnAlgorithms(
3340     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3341   // Preload sub libs for cudnn 8.0.4+
3342 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
3343   cudnnOpsInferVersionCheck();
3344   cudnnOpsTrainVersionCheck();
3345   cudnnAdvInferVersionCheck();
3346   cudnnAdvTrainVersionCheck();
3347 #endif
3348   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3349       // clang-format off
3350     CUDNN_RNN_ALGO_STANDARD,
3351     CUDNN_RNN_ALGO_PERSIST_STATIC,
3352     CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
3353       // clang-format on
3354   };
3355 
3356   out_algorithms->clear();
3357   for (auto i : algo_types) {
3358     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3359     out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3360   }
3361   return true;
3362 }
3363 
3364 bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
3365     bool with_winograd_nonfused, int cc_major, int cc_minor,
3366     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3367   // Preload sub libs for cudnn 8.0.4+
3368 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
3369   cudnnOpsInferVersionCheck();
3370   cudnnOpsTrainVersionCheck();
3371   cudnnCnnInferVersionCheck();
3372   cudnnCnnTrainVersionCheck();
3373 #endif
3374   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3375   out_algorithms->clear();
3376 
3377   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3378       // clang-format off
3379     CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
3380     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
3381     CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
3382     CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
3383       // clang-format on
3384   };
3385   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3386     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
3387   }
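  // Per the cuDNN documentation, BWD_DATA_ALGO_0 is non-deterministic, so it
  // is only offered when determinism is not required.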
3388   if (!RequireCudnnDeterminism()) {
3389     algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0);
3390   }
3391 
3392   // The algorithms are intentionally ordered for deterministic operation
3393   for (auto i : algo_types) {
3394     if (tensor_op_math_available) {
3395       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3396     }
3397     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3398   }
3399 
3400   return true;
3401 }
3402 
3403 bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
3404     bool with_winograd_nonfused, int cc_major, int cc_minor,
3405     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3406   // Preload sub libs for cudnn 8.0.4+
3407 #if CUDNN_MAJOR >= 8 && (CUDNN_MINOR > 0 || CUDNN_PATCHLEVEL >= 4)
3408   cudnnOpsInferVersionCheck();
3409   cudnnOpsTrainVersionCheck();
3410   cudnnCnnInferVersionCheck();
3411   cudnnCnnTrainVersionCheck();
3412 #endif
3413   bool tensor_op_math_available = TensorOpMathAvailable(cc_major);
3414   out_algorithms->clear();
3415 
3416   std::vector<dnn::AlgorithmDesc::Index> algo_types = {
3417       // clang-format off
3418       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
3419       CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
3420       // Based on cudnn.h, the following is not implemented.
3421       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
3422 
3423       // Produces incorrect results for some shapes. Disabled for now, see
3424       // NVIDIA bug 2072856. TODO(csigg): Only disable for subset of shapes.
3425       // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
3426       // clang-format on
3427   };
3428   if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
3429     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
3430   }
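  // Likewise, BWD_FILTER_ALGO_0 and ALGO_3 are non-deterministic and are only
  // offered when determinism is not required.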
3431   if (!RequireCudnnDeterminism()) {
3432     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0);
3433     algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3);
3434   }
3435 
3436   // The algorithms are intentionally ordered for deterministic operation
3437   for (auto i : algo_types) {
3438     if (tensor_op_math_available) {
3439       out_algorithms->push_back({i, /*use_tensor_ops=*/true});
3440     }
3441     out_algorithms->push_back({i, /*use_tensor_ops=*/false});
3442   }
3443 
3444   return true;
3445 }
3446 
3447 bool CudnnSupport::DoBatchNormalizationForward(
3448     Stream* stream, const DeviceMemory<float>& x,
3449     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3450     const DeviceMemory<float>& estimated_mean,
3451     const DeviceMemory<float>& estimated_variance,
3452     const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
3453     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3454     const double exponential_average_factor,
3455     dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
3456     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3457     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3458     bool is_training, ScratchAllocator* reserve_space_allocator,
3459     ScratchAllocator* workspace_allocator) {
3460   return IsStatusOk(
3461       DoBatchNormalizationForwardImpl<float, float>(
3462           stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
3463           offset, estimated_mean, estimated_variance, side_input, x_desc,
3464           scale_offset_desc, epsilon, exponential_average_factor,
3465           activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
3466           is_training, reserve_space_allocator, workspace_allocator),
3467       /*report_error=*/true);
3468 }
3469 
3470 bool CudnnSupport::DoBatchNormalizationForward(
3471     Stream* stream, const DeviceMemory<Eigen::half>& x,
3472     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3473     const DeviceMemory<float>& estimated_mean,
3474     const DeviceMemory<float>& estimated_variance,
3475     const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
3476     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3477     const double exponential_average_factor,
3478     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
3479     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3480     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3481     bool is_training, ScratchAllocator* reserve_space_allocator,
3482     ScratchAllocator* workspace_allocator) {
3483   return IsStatusOk(
3484       DoBatchNormalizationForwardImpl<Eigen::half, float>(
3485           stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
3486           estimated_mean, estimated_variance, side_input, x_desc,
3487           scale_offset_desc, epsilon, exponential_average_factor,
3488           activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
3489           is_training, reserve_space_allocator, workspace_allocator),
3490       /*report_error=*/true);
3491 }
3492 
3493 template <class T, class U>
3494 port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
3495     Stream* stream, dnn::DataType input_data_type,
3496     dnn::DataType scale_data_type, const DeviceMemory<T>& x,
3497     const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
3498     const DeviceMemory<U>& estimated_mean,
3499     const DeviceMemory<U>& estimated_variance,
3500     const DeviceMemory<U>& side_input, const dnn::BatchDescriptor& x_desc,
3501     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3502     const double exponential_average_factor,
3503     dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
3504     DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
3505     DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
3506     bool is_training, ScratchAllocator* reserve_space_allocator,
3507     ScratchAllocator* workspace_allocator) {
3508   CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
3509   CudnnTensorDescriptor scale_offset_descriptor(
3510       scale_offset_desc, ToCudnnDataType(scale_data_type));
3511   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
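  // CUDNN_BATCHNORM_SPATIAL_PERSISTENT is an optimized mode that can be
  // faster but, per the cuDNN documentation, may overflow for some inputs, so
  // it is opt-in via BatchnormSpatialPersistentEnabled() and only used on the
  // training path here.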
3512   if (BatchnormSpatialPersistentEnabled() && is_training) {
3513     mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
3514   }
3515   float one = 1.0;
3516   float zero = 0.0;
3517   auto cudnn = cudnn_->GetHandle(parent_, stream);
3518 
3519   DeviceMemory<uint8> workspace;
3520   DeviceMemory<uint8> reserve_space;
3521 
3522 #if CUDNN_VERSION >= 7402
3523   const auto get_bn_ops = [&]() -> cudnnBatchNormOps_t {
3524     if (side_input.is_null()) {
3525       return activation_mode == dnn::ActivationMode::kNone
3526                  ? CUDNN_BATCHNORM_OPS_BN
3527                  : CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
3528     } else {
3529       return CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
3530     }
3531   };
3532   const cudnnBatchNormOps_t bn_ops = get_bn_ops();
3533 
3534   // We use NaN propagation to be consistent with CudnnSupport::DoActivate(...).
3535   CudnnActivationDescriptor activation_desc(
3536       activation_mode, CUDNN_PROPAGATE_NAN, x_desc.value_max());
3537 
3538   if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
3539     SE_ASSIGN_OR_RETURN(
3540         workspace,
3541         CreateBatchNormForwardWorkspace(
3542             stream, cudnn, mode, bn_ops, activation_desc.handle(), x_descriptor,
3543             scale_offset_descriptor, workspace_allocator))
3544     if (is_training) {
3545       size_t reserve_space_size_in_bytes = 0;
3546       RETURN_IF_CUDNN_ERROR(
3547           cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
3548               /*handle=*/cudnn.handle(), /*mode=*/mode, /*bnOps=*/bn_ops,
3549               /*activationDesc=*/activation_desc.handle(),
3550               /*xDesc=*/x_descriptor.handle(),
3551               /*sizeInBytes=*/&reserve_space_size_in_bytes));
3552       SE_ASSIGN_OR_RETURN(reserve_space, reserve_space_allocator->AllocateBytes(
3553                                              reserve_space_size_in_bytes));
3554     }
3555   }
3556 #endif
3557 
3558   auto check_no_side_input_or_activation = [&]() -> port::Status {
3559     if (activation_mode != dnn::ActivationMode::kNone ||
3560         !side_input.is_null()) {
3561       return port::Status(
3562           port::error::INTERNAL,
3563           absl::StrCat(
3564               "Side input and activation are not supported by cuDNN version: ",
3565               CUDNN_VERSION));
3566     } else {
3567       return port::Status::OK();
3568     }
3569   };
3570 
3571   if (is_training) {
3572     CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
3573         << "batch_mean and batch_var must both be null or both be non-null";
3574 
3575     void* batch_mean_opaque;
3576     void* batch_var_opaque;
3577     if (!batch_mean->is_null() && !batch_var->is_null()) {
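      // With an exponential average factor of 1.0 the running statistics are
      // replaced entirely by the batch statistics, yet cuDNN still computes
      // (1 - factor) * running + factor * batch; zeroing the buffers first
      // keeps uninitialized values (e.g. NaN) from leaking through the
      // multiply-by-zero.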
3578       if (exponential_average_factor == 1.0) {
3579         stream->ThenMemZero(batch_mean, batch_mean->size());
3580         stream->ThenMemZero(batch_var, batch_var->size());
3581       }
3582       batch_mean_opaque = batch_mean->opaque();
3583       batch_var_opaque = batch_var->opaque();
3584     } else {
3585       batch_mean_opaque = nullptr;
3586       batch_var_opaque = nullptr;
3587     }
3588 
3589     bool called = false;
3590 #if CUDNN_VERSION >= 7402
3591     if (reserve_space_allocator != nullptr && workspace_allocator != nullptr) {
3592       called = true;
3593       RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTrainingEx(
3594           /*handle=*/cudnn.handle(),
3595           /*mode=*/mode,
3596           /*bnOps=*/bn_ops,
3597           /*alpha=*/&one,
3598           /*beta=*/&zero,
3599           /*xDesc=*/x_descriptor.handle(),
3600           /*xData=*/x.opaque(),
3601           /*zDesc=*/x_descriptor.handle(),
3602           /*zData=*/side_input.opaque(),
3603           /*yDesc=*/x_descriptor.handle(),
3604           /*yData=*/y->opaque(),
3605           /*bnScaleBiasMeanVarDesc=*/scale_offset_descriptor.handle(),
3606           /*bnScale=*/scale.opaque(),
3607           /*bnBias=*/offset.opaque(),
3608           /*exponentialAverageFactor=*/exponential_average_factor,
3609           /*resultRunningMean=*/batch_mean_opaque,
3610           /*resultRunningVariance=*/batch_var_opaque,
3611           /*epsilon=*/epsilon,
3612           /*resultSaveMean=*/saved_mean->opaque(),
3613           /*resultSaveInvVariance=*/saved_inv_var->opaque(),
3614           /*activationDesc=*/activation_desc.handle(),
3615           /*workspace=*/workspace.opaque(),
3616           /*workSpaceSizeInBytes=*/workspace.size(),
3617           /*reserveSpace=*/reserve_space.opaque(),
3618           /*reserveSpaceSizeInBytes=*/reserve_space.size()));
3619     }
3620 #endif
3621     if (!called) {
3622       SE_RETURN_IF_ERROR(check_no_side_input_or_activation());
3623       RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining(
3624           cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3625           x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3626           scale.opaque(), offset.opaque(), exponential_average_factor,
3627           batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(),
3628           saved_inv_var->opaque()));
3629     }
3630   } else {
3631     const void* maybe_inv_var = estimated_variance.opaque();
3632     SE_RETURN_IF_ERROR(check_no_side_input_or_activation());
3633     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference(
3634         cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3635         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3636         scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
3637         epsilon));
3638   }
3639   return port::Status::OK();
3640 }
3641 
3642 bool CudnnSupport::DoBatchNormalizationBackward(
3643     Stream* stream, const DeviceMemory<float>& y_backprop,
3644     const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
3645     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3646     const dnn::BatchDescriptor& x_desc,
3647     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3648     DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
3649     DeviceMemory<float>* offset_backprop,
3650     DeviceMemory<uint8>* reserve_space_data,
3651     ScratchAllocator* workspace_allocator) {
3652   return IsStatusOk(DoBatchNormalizationBackwardImpl(
3653                         stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop,
3654                         x, scale, mean, inv_var, x_desc, scale_offset_desc,
3655                         epsilon, x_backprop, scale_backprop, offset_backprop,
3656                         reserve_space_data, workspace_allocator),
3657                     /*report_error=*/true);
3658 }
3659 
3660 bool CudnnSupport::DoBatchNormalizationBackward(
3661     Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
3662     const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
3663     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3664     const dnn::BatchDescriptor& x_desc,
3665     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3666     DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
3667     DeviceMemory<float>* offset_backprop,
3668     DeviceMemory<uint8>* reserve_space_data,
3669     ScratchAllocator* workspace_allocator) {
3670   return IsStatusOk(DoBatchNormalizationBackwardImpl(
3671                         stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop,
3672                         x, scale, mean, inv_var, x_desc, scale_offset_desc,
3673                         epsilon, x_backprop, scale_backprop, offset_backprop,
3674                         reserve_space_data, workspace_allocator),
3675                     /*report_error=*/true);
3676 }
3677 
3678 template <class T, class U>
3679 port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
3680     Stream* stream, int cudnn_input_type, int cudnn_scale_type,
3681     const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
3682     const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
3683     const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
3684     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3685     DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
3686     DeviceMemory<U>* offset_backprop, DeviceMemory<uint8>* reserve_space_data,
3687     ScratchAllocator* workspace_allocator) {
3688   CudnnTensorDescriptor x_descriptor(
3689       x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
3690   CudnnTensorDescriptor scale_offset_descriptor(
3691       scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
3692   cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
3693   if (BatchnormSpatialPersistentEnabled()) {
3694     mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
3695   }
3696   float one = 1.0;
3697   float zero = 0.0;
3698 
3699   auto cudnn = cudnn_->GetHandle(parent_, stream);
3700 
3701   bool called = false;
3702 #if CUDNN_VERSION >= 7402
3703   if (reserve_space_data != nullptr && workspace_allocator != nullptr) {
3704     called = true;
3705     const cudnnBatchNormOps_t bn_ops = CUDNN_BATCHNORM_OPS_BN;
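    // Only plain batch norm is fused on the backward path (no activation or
    // side input), which is why the y/dz/activation arguments below are null.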
3706     SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
3707                         CreateBatchNormBackwardWorkspace(
3708                             stream, cudnn, mode, bn_ops, x_descriptor,
3709                             scale_offset_descriptor, workspace_allocator))
3710     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackwardEx(
3711         /*handle=*/cudnn.handle(),
3712         /*mode=*/mode,
3713         /*bnOps=*/bn_ops,
3714         /*alphaDataDiff=*/&one,
3715         /*betaDataDiff=*/&zero,
3716         /*alphaParamDiff=*/&one,
3717         /*betaParamDiff=*/&zero,
3718         /*xDesc=*/x_descriptor.handle(),
3719         /*xData=*/x.opaque(),
3720         /*yDesc=*/nullptr,
3721         /*yData=*/nullptr,
3722         /*dyDesc=*/x_descriptor.handle(),
3723         /*dyData=*/y_backprop.opaque(),
3724         /*dzDesc=*/nullptr,
3725         /*dzData=*/nullptr,
3726         /*dxDesc=*/x_descriptor.handle(),
3727         /*dxData=*/x_backprop->opaque(),
3728         /*dBnScaleBiasDesc=*/scale_offset_descriptor.handle(),
3729         /*bnScaleData=*/scale.opaque(),
3730         /*bnBiasData=*/nullptr,
3731         /*dBnScaleData=*/scale_backprop->opaque(),
3732         /*dBnBiasData=*/offset_backprop->opaque(),
3733         /*epsilon=*/epsilon,
3734         /*savedMean=*/mean.opaque(),
3735         /*savedInvVariance=*/inv_var.opaque(),
3736         /*activationDesc=*/nullptr,
3737         /*workspace=*/workspace.opaque(),
3738         /*workSpaceSizeInBytes=*/workspace.size(),
3739         /*reserveSpace=*/reserve_space_data->opaque(),
3740         /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
3741   }
3742 #endif
3743   if (!called) {
3744     RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward(
3745         cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
3746         x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
3747         x_descriptor.handle(), x_backprop->opaque(),
3748         scale_offset_descriptor.handle(), scale.opaque(),
3749         scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
3750         mean.opaque(), inv_var.opaque()));
3751   }
3752 
3753   return port::Status::OK();
3754 }
3755 
3756 port::Status CudnnSupport::DoFusedConvolve(
3757     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3758     const DeviceMemory<double>& conv_input_data, double conv_input_scale,
3759     const dnn::FilterDescriptor& filter_descriptor,
3760     const DeviceMemory<double>& filter_data,
3761     const dnn::ConvolutionDescriptor& convolution_descriptor,
3762     const DeviceMemory<double>& side_input_data, double side_input_scale,
3763     const dnn::BatchDescriptor& bias_descriptor,
3764     const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
3765     const dnn::BatchDescriptor& output_descriptor,
3766     DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
3767     const dnn::AlgorithmConfig& algorithm_config,
3768     dnn::ProfileResult* output_profile_result) {
3769   return DoFusedConvolveImpl(
3770       stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3771       filter_descriptor, filter_data, convolution_descriptor, side_input_data,
3772       side_input_scale, bias_descriptor, biases, activation_mode,
3773       output_descriptor, output_data,
3774       GetConvAccumulatorType(dnn::DataType::kDouble), scratch_allocator,
3775       algorithm_config, output_profile_result);
3776 }
3777 
3778 port::Status CudnnSupport::DoFusedConvolve(
3779     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3780     const DeviceMemory<float>& conv_input_data, float conv_input_scale,
3781     const dnn::FilterDescriptor& filter_descriptor,
3782     const DeviceMemory<float>& filter_data,
3783     const dnn::ConvolutionDescriptor& convolution_descriptor,
3784     const DeviceMemory<float>& side_input_data, float side_input_scale,
3785     const dnn::BatchDescriptor& bias_descriptor,
3786     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3787     const dnn::BatchDescriptor& output_descriptor,
3788     DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
3789     const dnn::AlgorithmConfig& algorithm_config,
3790     dnn::ProfileResult* output_profile_result) {
3791   return DoFusedConvolveImpl(
3792       stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3793       filter_descriptor, filter_data, convolution_descriptor, side_input_data,
3794       side_input_scale, bias_descriptor, biases, activation_mode,
3795       output_descriptor, output_data,
3796       GetConvAccumulatorType(dnn::DataType::kFloat), scratch_allocator,
3797       algorithm_config, output_profile_result);
3798 }
3799 
3800 port::Status CudnnSupport::DoFusedConvolve(
3801     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3802     const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
3803     const dnn::FilterDescriptor& filter_descriptor,
3804     const DeviceMemory<Eigen::half>& filter_data,
3805     const dnn::ConvolutionDescriptor& convolution_descriptor,
3806     const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
3807     const dnn::BatchDescriptor& bias_descriptor,
3808     const DeviceMemory<Eigen::half>& biases,
3809     dnn::ActivationMode activation_mode,
3810     const dnn::BatchDescriptor& output_descriptor,
3811     DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
3812     const dnn::AlgorithmConfig& algorithm_config,
3813     dnn::ProfileResult* output_profile_result) {
3814   return DoFusedConvolveImpl(
3815       stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3816       filter_descriptor, filter_data, convolution_descriptor, side_input_data,
3817       side_input_scale, bias_descriptor, biases, activation_mode,
3818       output_descriptor, output_data,
3819       GetConvAccumulatorType(dnn::DataType::kHalf), scratch_allocator,
3820       algorithm_config, output_profile_result);
3821 }
3822 
3823 port::Status CudnnSupport::DoFusedConvolve(
3824     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3825     const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
3826     const dnn::FilterDescriptor& filter_descriptor,
3827     const DeviceMemory<int8>& filter_data,
3828     const dnn::ConvolutionDescriptor& convolution_descriptor,
3829     const DeviceMemory<int8>& side_input_data, float side_input_scale,
3830     const dnn::BatchDescriptor& bias_descriptor,
3831     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3832     const dnn::BatchDescriptor& output_descriptor,
3833     DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
3834     const dnn::AlgorithmConfig& algorithm_config,
3835     dnn::ProfileResult* output_profile_result) {
3836   int cc_major, cc_minor;
3837   std::tie(cc_major, cc_minor) = GetCcMajorMinor(stream);
3838 
3839   if (cc_major < 6 || (cc_major == 6 && cc_minor < 1)) {
3840     return port::UnimplementedError(
3841         "cudnnConvolutionBiasActivationForward() for int8 is only supported on "
3842         "GPUs with compute capability 6.1 or later.");
3843   }
3844 
3845   return DoFusedConvolveImpl(
3846       stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3847       filter_descriptor, filter_data, convolution_descriptor, side_input_data,
3848       side_input_scale, bias_descriptor, biases, activation_mode,
3849       output_descriptor, output_data,
3850       GetConvAccumulatorType(dnn::DataType::kInt8), scratch_allocator,
3851       algorithm_config, output_profile_result);
3852 }
3853 
3854 port::Status CudnnSupport::DoFusedConvolve(
3855     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3856     const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
3857     const dnn::FilterDescriptor& filter_descriptor,
3858     const DeviceMemory<int8>& filter_data,
3859     const dnn::ConvolutionDescriptor& convolution_descriptor,
3860     const DeviceMemory<float>& side_input_data, float side_input_scale,
3861     const dnn::BatchDescriptor& bias_descriptor,
3862     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3863     const dnn::BatchDescriptor& output_descriptor,
3864     DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
3865     const dnn::AlgorithmConfig& algorithm_config,
3866     dnn::ProfileResult* output_profile_result) {
3867   int cc_major, cc_minor;
3868   stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
3869                                                                    &cc_minor);
3870   if (cc_major < 6 || (cc_major == 6 && cc_minor < 1)) {
3871     return port::UnimplementedError(
3872         "cudnnConvolutionBiasActivationForward() for int8 is only supported on "
3873         "GPUs with compute capability 6.1 or later.");
3874   }
3875 
3876   return DoFusedConvolveImpl(
3877       stream, conv_input_descriptor, conv_input_data, conv_input_scale,
3878       filter_descriptor, filter_data, convolution_descriptor, side_input_data,
3879       side_input_scale, bias_descriptor, biases, activation_mode,
3880       output_descriptor, output_data,
3881       GetConvAccumulatorType(dnn::DataType::kInt8), scratch_allocator,
3882       algorithm_config, output_profile_result);
3883 }
3884 
3885 port::Status CudnnSupport::DoPrepareForCtcLoss(
3886     Stream* stream, dnn::DataType element_type,
3887     const dnn::RnnStateTensorDescriptor& probs_desc,
3888     const dnn::RnnStateTensorDescriptor& grads_desc,
3889     absl::Span<const int> labels_data,
3890     absl::Span<const int> labels_lengths_data,
3891     absl::Span<const int> input_lengths_data,
3892     ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
3893     int* ctc_loss_algo_id) {
3894   auto cudnn = cudnn_->GetHandle(parent_, stream);
3895   // Query the workspace size.
3896   size_t workspace_size_in_bytes = 0;
3897 #if CUDNN_VERSION >= 7603
3898   CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
3899   const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
3900       static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
3901   const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
3902       static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
3903 
3904   // Try running with `algo`; if that succeeds, pick it. The non-deterministic
3905   // algorithm is tried first and is thus preferred when determinism is not
3906   // required.
3907   auto algo = RequireCudnnDeterminism() ? CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
3908                                         : CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC;
3909   cudnnStatus_t status = cudnnGetCTCLossWorkspaceSize(
3910       /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
3911       /*gradientsDesc=*/cudnn_grads_desc.handle(),
3912       /*labels=*/labels_data.data(),
3913       /*labelLengths=*/labels_lengths_data.data(),
3914       /*inputLengths=*/input_lengths_data.data(),
3915       /*algo=*/algo,
3916       /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
3917       /*sizeInBytes=*/&workspace_size_in_bytes);
3918   if (RequireCudnnDeterminism()) {
3919     RETURN_IF_CUDNN_ERROR(status);
3920   }
3921 
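  // If the preferred non-deterministic algorithm is not supported for this
  // problem, fall back to the deterministic algorithm and query again.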
3922   if (status != CUDNN_STATUS_SUCCESS) {
3923     algo = CUDNN_CTC_LOSS_ALGO_DETERMINISTIC;
3924     RETURN_IF_CUDNN_ERROR(cudnnGetCTCLossWorkspaceSize(
3925         /*handle=*/cudnn.handle(), /*probsDesc=*/cudnn_probs_desc.handle(),
3926         /*gradientsDesc=*/cudnn_grads_desc.handle(),
3927         /*labels=*/labels_data.data(),
3928         /*labelLengths=*/labels_lengths_data.data(),
3929         /*inputLengths=*/input_lengths_data.data(),
3930         /*algo=*/algo,
3931         /*ctcLossDesc=*/cudnn_ctc_loss_desc.handle(),
3932         /*sizeInBytes=*/&workspace_size_in_bytes));
3933   }
3934   *ctc_loss_algo_id = algo;
3935 #else
3936   return port::Status(port::error::INVALID_ARGUMENT,
3937                       "No supported cudnnGetCTCLossWorkspaceSize when "
3938                       "CUDNN_VERSION < 7.6.3");
3939 #endif
3940   // Allocate the workspace.
3941   if (workspace_size_in_bytes == 0) {
3942     *scratch_memory = DeviceMemory<uint8>();
3943     return port::Status::OK();
3944   }
3945   const auto scratch_or =
3946       scratch_allocator->AllocateBytes(workspace_size_in_bytes);
3947   if (scratch_or.ok()) {
3948     *scratch_memory = scratch_or.ValueOrDie();
3949     return port::Status::OK();
3950   }
3951   return port::InternalError(
3952       "Failed to allocate scratch memory for the CuDNN CTC Loss");
3953 }
3954 
3955 port::Status CudnnSupport::DoCtcLoss(
3956     Stream* stream, dnn::DataType element_type,
3957     const dnn::RnnStateTensorDescriptor& probs_desc,
3958     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
3959     absl::Span<const int> labels_lengths_data,
3960     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
3961     const dnn::RnnStateTensorDescriptor& grads_desc,
3962     DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
3963     int ctc_loss_algo_id) {
3964   // The current cuDNN CTC loss only supports the float data type.
3965   if (CUDNN_VERSION < 7603 || element_type != dnn::DataType::kFloat) {
3966     return port::Status(port::error::INVALID_ARGUMENT,
3967                         "CudnnCtcLossDescriptor is supported only when the "
3968                         "CUDNN_VERSION >= 7.6.3 and DataType is float");
3969   }
3970   CudnnCtcLossDescriptor cudnn_ctc_loss_desc(ToCudnnDataType(element_type));
3971   const CudnnRnnStateTensorDescriptor& cudnn_probs_desc =
3972       static_cast<const CudnnRnnStateTensorDescriptor&>(probs_desc);
3973   const CudnnRnnStateTensorDescriptor& cudnn_grads_desc =
3974       static_cast<const CudnnRnnStateTensorDescriptor&>(grads_desc);
3975   return DoCtcLossImpl(stream, cudnn_probs_desc, probs_data, labels_data,
3976                        labels_lengths_data, input_lengths_data, costs_data,
3977                        cudnn_grads_desc, grads_data, cudnn_ctc_loss_desc,
3978                        scratch_memory, ctc_loss_algo_id);
3979 }
3980 
3981 bool CudnnSupport::DoTransformTensor(Stream* stream,
3982                                      const dnn::BatchDescriptor& input_desc,
3983                                      dnn::DataType input_type,
3984                                      const DeviceMemoryBase& input_data,
3985                                      const dnn::BatchDescriptor& output_desc,
3986                                      dnn::DataType output_type, float scale,
3987                                      DeviceMemoryBase* output_data) {
3988   float beta = 0.0f;
3989   CudnnTensorDescriptor input_tensor_desc(
3990       input_desc, ToCudnnDataType(input_type, input_desc.layout()));
3991   CudnnTensorDescriptor output_tensor_desc(
3992       output_desc, ToCudnnDataType(output_type, output_desc.layout()));
3993   auto cudnn = cudnn_->GetHandle(parent_, stream);
3994   const auto status = [&] {
3995     RETURN_IF_CUDNN_ERROR(cudnnTransformTensor(
3996         cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
3997         &beta, output_tensor_desc.handle(), output_data->opaque()));
3998     return port::Status::OK();
3999   }();
4000   return IsStatusOk(status, /*report_error=*/true);
4001 }
4002 
4003 template <class T>
4004 port::Status CudnnSupport::DoConvolveBackwardBiasImpl(
4005     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4006     const DeviceMemory<T>& input_data,
4007     const dnn::BatchDescriptor& bias_descriptor,
4008     DeviceMemory<T>* backward_bias_data) {
4009   cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
4010   CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
4011   CudnnTensorDescriptor bias_nd(bias_descriptor, cudnn_type);
4012 
4013   // Alpha is the scaling factor for input.
4014   float alpha = 1.0;
4015   // Beta is the scaling factor for output.
4016   float beta = 0.0;
4017 
4018   auto cudnn = cudnn_->GetHandle(parent_, stream);
4019   RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardBias(
4020       cudnn.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta,
4021       bias_nd.handle(), backward_bias_data->opaque()));
4022   return port::Status::OK();
4023 }
4024 
4025 bool CudnnSupport::DoConvolveBackwardBias(
4026     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4027     const DeviceMemory<double>& input_data,
4028     const dnn::BatchDescriptor& bias_descriptor,
4029     DeviceMemory<double>* backward_bias_data) {
4030   return IsStatusOk(
4031       DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
4032                                  bias_descriptor, backward_bias_data),
4033       /*report_error=*/true);
4034 }
4035 
4036 bool CudnnSupport::DoConvolveBackwardBias(
4037     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4038     const DeviceMemory<float>& input_data,
4039     const dnn::BatchDescriptor& bias_descriptor,
4040     DeviceMemory<float>* backward_bias_data) {
4041   return IsStatusOk(
4042       DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
4043                                  bias_descriptor, backward_bias_data),
4044       /*report_error=*/true);
4045 }
4046 
4047 bool CudnnSupport::DoConvolveBackwardBias(
4048     Stream* stream, const dnn::BatchDescriptor& input_descriptor,
4049     const DeviceMemory<Eigen::half>& input_data,
4050     const dnn::BatchDescriptor& bias_descriptor,
4051     DeviceMemory<Eigen::half>* backward_bias_data) {
4052   return IsStatusOk(
4053       DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
4054                                  bias_descriptor, backward_bias_data),
4055       /*report_error=*/true);
4056 }
4057 
4058 bool CudnnSupport::DoMatMul(Stream* stream,
4059                             const DeviceMemory<float>& input_data,
4060                             const DeviceMemory<float>& weights,
4061                             const dnn::BatchDescriptor& input_dimensions,
4062                             const dnn::BatchDescriptor& output_dimensions,
4063                             DeviceMemory<float>* output_data) {
4064   if (input_dimensions.count() != output_dimensions.count()) {
4065     LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
4066     return false;
4067   }
4068 
4069   // We do not permute the input or output; instead we just
4070   // reinterpret the layout. We are working with row-major matrices
4071   // and the rows of the input and output correspond to batch, so
4072   // batch has to be outermost in both the input and output.
4073   //
4074   // By adding transposes to the BLAS gemm call we could perhaps make
4075   // the kYXDepthBatch layout work as well, but there has been no need
4076   // for that so far.
4077   if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
4078       input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4079     LOG(ERROR) << "Unsupported MatMul input layout.";
4080     return false;
4081   }
4082   if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
4083       output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4084     LOG(ERROR) << "Unsupported MatMul output layout.";
4085     return false;
4086   }
4087 
4088   if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
4089     // This is a fast path that also supports the kBatchYXDepth layout.
4090 
4091     // The matrices here are in row-major format while BLAS expects
4092     // column-major, i.e. our matrices are transposed as far as BLAS
4093     // is concerned. So we need to compute output^T =
4094     // input^T*weights^T. There is no parameter for transposing the
4095     // output in BLAS gemm, but instead we can transpose both sides of
4096     // the equality to see that this is equivalent to
4097     // output=weights*input. So we only need to swap the order of
4098     // weights and input in the matrix product to correct for the
4099     // row-major versus column-major difference.
4100     const float alpha = 1.0f;  // Take the matrix product without scaling it.
4101     const float beta = 0.0f;   // Ignore the original values in output_data.
4102     const int64 m = output_dimensions.NodesAcrossFeatureMaps();
4103     const int64 n = input_dimensions.count();
4104     const int64 k = input_dimensions.NodesAcrossFeatureMaps();
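    // BLAS sees the row-major arrays as their transposes, so this call
    // computes output^T (m x n) = weights^T (m x k) * input^T (k x n), which
    // is exactly the row-major product output = input * weights.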
4105     stream->ThenBlasGemm(blas::Transpose::kNoTranspose,
4106                          blas::Transpose::kNoTranspose, m, n, k, alpha, weights,
4107                          m, input_data, k, beta, output_data, m);
4108   } else {
4109     // This is a slower and more complex path that supports output
4110     // width() * height() > 1, though it only supports the
4111     // kBatchYXDepth layout. It does support kBatchDepthYX if output
4112     // feature_map_count() == 1, as then there is no difference
4113     // between the two layouts.
4114     //
4115     // The operation here is the same as above, except that we have to
4116     // do the matrix multiplication for each (y,x) output coordinate
4117     // separately. We then interpret weights as containing K = width()
4118     // * height() different matrices, which we all multiply onto the
4119     // matrix from input_data, yielding K matrix products. We then
4120     // combine these together into one matrix by concatenating all the
4121     // first rows of these matrices, then all the second rows, and so
4122     // on. We can do this with a batched matrix multiplication, where
4123     // the result is written to a different submatrix of the output
4124     // for each matrix multiplication.
4125     //
4126     // The reason that we only support the kBatchYXDepth output layout
4127     // is that we have to do something in the depth for each (y,x)
4128     // coordinate. The kBatchYXDepth layout has the depth information
4129     // for each point (y,x) in contiguous memory while the
4130     // kBatchDepthYX layout does not.
4131     //
4132     // TODO(broune): Consider a special case for when output depth ==
4133     // 1, as then possibly this could all be done as one matrix
4134     // multiplication instead of a batched one, which should be
4135     // faster. Another possibility would be to add a weights layout
4136     // parameter and then support kBatchDepthYX for a different
4137     // weights layout.
    if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
        !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
          output_dimensions.feature_map_count() == 1)) {
      LOG(ERROR) << "Unsupported MatMul output layout.";
      return false;
    }

    const float alpha = 1.0f;  // Take the matrix product without scaling it.
    const float beta = 0.0f;   // Ignore the original values in output_data.
    const uint64 m = output_dimensions.feature_map_count();
    const uint64 n = input_dimensions.count();
    const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
    const int lda = m;
    const int ldb = k;
    const int ldc = output_dimensions.NodesAcrossFeatureMaps();
    const int batch_count = output_dimensions.NodesPerFeatureMap();

    std::vector<DeviceMemory<float>> a(batch_count);
    std::vector<DeviceMemory<float>> b(batch_count);
    std::vector<DeviceMemory<float>> c(batch_count);
    for (int i = 0; i < batch_count; ++i) {
      const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
                                 output_dimensions.feature_map_count();
      a[i] = DeviceMemory<float>::MakeFromByteSize(
          const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
              weights_offset,
          (weights.ElementCount() - weights_offset) * sizeof(float));

      b[i] = input_data;

      const int output_offset = i * output_dimensions.feature_map_count();
      c[i] = DeviceMemory<float>::MakeFromByteSize(
          const_cast<float*>(
              reinterpret_cast<const float*>(output_data->opaque())) +
              output_offset,
          (output_data->ElementCount() - output_offset) * sizeof(float));
    }
    const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
      std::vector<DeviceMemory<float>*> ptrs;
      ptrs.reserve(v.size());
      for (auto& mem : v) {
        ptrs.push_back(&mem);
      }
      return ptrs;
    };

    stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
                                blas::Transpose::kNoTranspose, m, n, k, alpha,
                                toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
                                ldc, batch_count);
  }

  return stream->ok();
}

bool CudnnSupport::DoBiasAdd(Stream* stream,
                             const DeviceMemory<float>& input_data,
                             const DeviceMemory<float>& biases,
                             const dnn::BatchDescriptor& dimensions,
                             DeviceMemory<float>* output_data) {
  CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);

  dnn::BatchDescriptor bias_dimensions;
  bias_dimensions.set_count(1)
      .set_feature_map_count(dimensions.feature_map_count())
      .set_height(1)
      .set_width(1)
      .set_layout(dnn::DataLayout::kBatchYXDepth);
  CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);

  // cudnnAddTensor after R3 is in-place, so we need to copy input_data to
  // output_data before doing the addition, unless the input and
  // output are at the same address.
  if (input_data.opaque() != output_data->opaque()) {
    stream->ThenMemcpy(output_data, input_data,
                       dimensions.ElementCount() * sizeof(float));
    if (!stream->ok()) {
      LOG(ERROR)
          << "stream " << stream
          << " could not enqueue a tensor copy as part of bias addition.";
      return false;
    }
  }

  const float alpha = 1.0f;
  const float beta = 1.0f;
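  // With alpha == beta == 1.0f, cudnnAddTensor below computes
  // output = bias + output, broadcasting the per-feature-map bias across the
  // batch and spatial dimensions.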

  auto cudnn = cudnn_->GetHandle(parent_, stream);

  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnAddTensor(
        cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(),
        &beta, input_descriptor.handle(), output_data->opaque()));
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

bool CudnnSupport::DoActivate(Stream* stream,
                              dnn::ActivationMode activation_mode,
                              const dnn::BatchDescriptor& dimensions,
                              const DeviceMemory<float>& input_data,
                              DeviceMemory<float>* output_data,
                              uint64 options) {
  CudnnActivationDescriptor activation_desc(
      activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());

  CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
  // Alpha is the input scaling factor.
  float alpha = 1.0;
  // Beta is the output scaling factor.
  float beta = 0.0;
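  // With alpha = 1 and beta = 0, cudnnActivationForward below computes
  // output = f(input) for the selected activation f, overwriting any previous
  // contents of output_data.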

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnActivationForward(
        cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
        input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()));
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

bool CudnnSupport::DoPoolForward(
    Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
    const dnn::BatchDescriptor& input_dimensions,
    const DeviceMemory<double>& input_data,
    const dnn::BatchDescriptor& output_dimensions,
    DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
  // Alpha is the scaling factor for input.
  double alpha = 1.0;
  // Beta is the scaling factor for output.
  double beta = 0.0;
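  // cudnnPoolingForward computes output = alpha * pool(input) + beta * output,
  // so with alpha = 1 and beta = 0 the pooling result simply overwrites
  // output_data. The float, half and int8 overloads below use the same
  // convention.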

  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
        cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
        input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

bool CudnnSupport::DoPoolForward(
    Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
    const dnn::BatchDescriptor& input_dimensions,
    const DeviceMemory<float>& input_data,
    const dnn::BatchDescriptor& output_dimensions,
    DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
  // Alpha is the scaling factor for input.
  float alpha = 1.0;
  // Beta is the scaling factor for output.
  float beta = 0.0;

  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
        cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
        input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

bool CudnnSupport::DoPoolForward(
    Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
    const dnn::BatchDescriptor& input_dimensions,
    const DeviceMemory<Eigen::half>& input_data,
    const dnn::BatchDescriptor& output_dimensions,
    DeviceMemory<Eigen::half>* output_data,
    ScratchAllocator* workspace_allocator) {
  // Alpha is the scaling factor for input.
  float alpha = 1.0;
  // Beta is the scaling factor for output.
  float beta = 0.0;

  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
  auto cudnn = cudnn_->GetHandle(parent_, stream);
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
        cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
        input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

bool CudnnSupport::DoPoolForward(
    Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
    const dnn::BatchDescriptor& input_dimensions,
    const DeviceMemory<int8>& input_data,
    const dnn::BatchDescriptor& output_dimensions,
    DeviceMemory<int8>* output_data, ScratchAllocator* workspace_allocator) {
  // Alpha is the scaling factor for input.
  float alpha = 1.0;
  // Beta is the scaling factor for output.
  float beta = 0.0;

  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_INT8);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_INT8);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
        cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
        input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

bool CudnnSupport::DoPoolBackward(
    Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
    const dnn::BatchDescriptor& input_dimensions,
    const DeviceMemory<double>& input_data,
    const dnn::BatchDescriptor& output_dimensions,
    const DeviceMemory<double>& output_data,
    const DeviceMemory<double>& input_diff_data,
    DeviceMemory<double>* output_diff_data,
    ScratchAllocator* workspace_allocator) {
  // Alpha is the scaling factor for input.
  double alpha = 1.0;
  // Beta is the scaling factor for output.
  double beta = 0.0;

  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);

  auto cudnn = cudnn_->GetHandle(parent_, stream);
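  // cudnnPoolingBackward consumes the forward result y, its gradient dy and
  // the forward input x, and produces the input gradient dx. Here (and in the
  // overloads below) output_data is y, input_diff_data is dy, input_data is x,
  // and output_diff_data receives dx.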
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
        cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
        output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
        src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
        output_diff_data->opaque()));
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

bool CudnnSupport::DoPoolBackward(
    Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
    const dnn::BatchDescriptor& input_dimensions,
    const DeviceMemory<float>& input_data,
    const dnn::BatchDescriptor& output_dimensions,
    const DeviceMemory<float>& output_data,
    const DeviceMemory<float>& input_diff_data,
    DeviceMemory<float>* output_diff_data,
    ScratchAllocator* workspace_allocator) {
  // Alpha is the scaling factor for input.
  float alpha = 1.0;
  // Beta is the scaling factor for output.
  float beta = 0.0;

  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
        cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
        output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
        src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
        output_diff_data->opaque()));
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

bool CudnnSupport::DoPoolBackward(
    Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
    const dnn::BatchDescriptor& input_dimensions,
    const DeviceMemory<Eigen::half>& input_data,
    const dnn::BatchDescriptor& output_dimensions,
    const DeviceMemory<Eigen::half>& output_data,
    const DeviceMemory<Eigen::half>& input_diff_data,
    DeviceMemory<Eigen::half>* output_diff_data,
    ScratchAllocator* workspace_allocator) {
  // Alpha is the scaling factor for input.
  float alpha = 1.0;
  // Beta is the scaling factor for output.
  float beta = 0.0;

  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
        cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
        output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
        src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
        output_diff_data->opaque()));
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

bool CudnnSupport::DoNormalizeWithDimensions(
    Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
    const dnn::BatchDescriptor& dimensions,
    const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
  // Check for unsupported modes.
  if (normalize_descriptor.wrap_around()) {
    LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
    return false;
  }
  if (normalize_descriptor.segment_size()) {
    LOG(ERROR) << "CUDA LRN does not support segmentation";
    return false;
  }

  CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
  CudnnNormalizeDescriptor normalize(normalize_descriptor);

  // Alpha is the scaling factor for input.
  float alpha = 1.0f;
  // Beta is the scaling factor for output.
  float beta = 0.0f;
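  // cudnnLRNCrossChannelForward computes
  // output = alpha * LRN(input) + beta * output, where the local response
  // normalization for each element runs across neighboring feature maps
  // (CUDNN_LRN_CROSS_CHANNEL_DIM1).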

  auto cudnn = cudnn_->GetHandle(parent_, stream);

  // Launch the normalization.
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward(
        cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
        &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(),
        output_data->opaque()));
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

bool CudnnSupport::DoNormalizeBackwardWithDimensions(
    Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
    const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
    const DeviceMemory<float>& normalized_data,
    const DeviceMemory<float>& normalized_variable_gradient,
    DeviceMemory<float>* raw_variable_gradient,
    ScratchAllocator* workspace_allocator) {
  // Check for unsupported modes.
  if (normalize_descriptor.wrap_around()) {
    LOG(ERROR) << "CUDA LRN does not support wrap-around mode";
    return false;
  }
  if (normalize_descriptor.segment_size()) {
    LOG(ERROR) << "CUDA LRN does not support segmentation";
    return false;
  }

  CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
  CudnnNormalizeDescriptor normalize(normalize_descriptor);

  float alpha = 1.0f;
  float beta = 0.0f;

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward(
        cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
        &alpha, dims.handle(), normalized_data.opaque(), dims.handle(),
        normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
        &beta, dims.handle(), raw_variable_gradient->opaque()));
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

bool CudnnSupport::DoDepthConcatenate(
    Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    port::ArraySlice<const DeviceMemory<float>*> input_data,
    DeviceMemory<float>* output_data) {
  CHECK_EQ(input_dimensions.size(), input_data.size());

  for (const auto& dimensions : input_dimensions) {
    if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
      LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only "
                    "supports the kBatchDepthYX layout.";
      return false;
    }
  }

  if (input_dimensions.empty()) {
    return true;  // Nothing to do.
  }

  dnn::BatchDescriptor output_dimensions =
      dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);

  const int64 area = output_dimensions.width() * output_dimensions.height();
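  // index() returns the flat offset of element (batch, depth, yx) in a
  // kBatchDepthYX-layout buffer that has max_depth feature maps of `area`
  // elements each.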
  const auto index = [area](int64 batch, int64 depth, int64 yx,
                            int64 max_depth) {
    return (batch * max_depth + depth) * area + yx;
  };

  std::vector<float> output_host(output_dimensions.ElementCount());
  std::vector<float> tmp;
  int64 depth_sum = 0;
  for (size_t i = 0; i < input_data.size(); ++i) {
    const auto& dimensions = input_dimensions[i];
    tmp.resize(dimensions.ElementCount());
    stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
    port::Status block_status = stream->BlockHostUntilDone();
    if (!block_status.ok()) {
      LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
      return false;
    }

    for (int64 batch = 0; batch < output_dimensions.count(); ++batch) {
      for (int64 yx = 0; yx < area; ++yx) {
        for (int64 depth = 0; depth < dimensions.feature_map_count(); ++depth) {
          output_host[index(batch, depth + depth_sum, yx,
                            output_dimensions.feature_map_count())] =
              tmp[index(batch, depth, yx, dimensions.feature_map_count())];
        }
      }
    }
    depth_sum += dimensions.feature_map_count();
  }
  stream->ThenMemcpyH2D<float>(output_host, output_data);
  return true;
}

bool CudnnSupport::DoElementwiseOperate(
    Stream* stream, dnn::ElementwiseOperation operation,
    port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    port::ArraySlice<const DeviceMemory<float>*> input_data,
    const dnn::BatchDescriptor& output_dimensions,
    DeviceMemory<float>* output_data) {
  LOG(FATAL) << "not yet implemented";  // TODO(leary)
  return false;
}

bool CudnnSupport::DoXYPad(Stream* stream,
                           const dnn::BatchDescriptor& dimensions,
                           const DeviceMemory<float>& input_data,
                           int64 left_pad, int64 right_pad, int64 top_pad,
                           int64 bottom_pad, DeviceMemory<float>* output_data) {
  LOG(FATAL) << "not yet implemented";  // TODO(leary)
  return false;
}

bool CudnnSupport::DoXYSlice(Stream* stream,
                             const dnn::BatchDescriptor& dimensions,
                             const DeviceMemory<float>& input_data,
                             int64 left_trim, int64 right_trim, int64 top_trim,
                             int64 bottom_trim,
                             DeviceMemory<float>* output_data) {
  LOG(FATAL) << "not yet implemented";  // TODO(leary)
  return false;
}

bool CudnnSupport::DoMemcpyD2HQuantized(
    Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
    dnn::QuantizedActivationMode mode, void* host_dst, int64 size) {
  LOG(ERROR) << "quantized memcpy not supported by cuDNN";
  return false;
}

bool CudnnSupport::DoMemcpyH2DQuantized(
    Stream* stream, const void* host_src, int64 size,
    dnn::QuantizedActivationMode mode,
    DeviceMemory<float>* gpu_unquantized_dst) {
  LOG(ERROR) << "quantized memcpy not supported by cuDNN";
  return false;
}

bool CudnnSupport::DeriveOutputBatchDescriptor(
    const dnn::BatchDescriptor& batch_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    dnn::BatchDescriptor* output_batch_descriptor) {
  CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
  CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
  CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);

  int dn = batch_descriptor.ndims() + 2;
  std::vector<int> dims(dn);  // in BDYX
  const auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
        conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()));
    output_batch_descriptor->set_count(dims[0])
        .set_feature_map_count(dims[1])
        .set_layout(batch_descriptor.layout());

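    // cuDNN reports the spatial output dimensions outermost first (e.g.
    // ..., height, width), while dnn::DimIndex counts from the innermost
    // dimension (X == 0), hence the reversed indexing below.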
    for (int i = 0; i < batch_descriptor.ndims(); i++) {
      output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
                                               dims.rbegin()[i]);
    }
    return port::Status::OK();
  }();
  return IsStatusOk(status, /*report_error=*/true);
}

}  // namespace gpu

void initialize_cudnn() {
  port::Status status =
      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
          cuda::kCudaPlatformId, gpu::kCuDnnPlugin, "cuDNN",
          [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
            gpu::GpuExecutor* cuda_executor =
                dynamic_cast<gpu::GpuExecutor*>(parent);
            if (cuda_executor == nullptr) {
              LOG(ERROR) << "Attempting to initialize an instance of the cuDNN "
                         << "support library with a non-CUDA StreamExecutor";
              return nullptr;
            }

            gpu::CudnnSupport* dnn = new gpu::CudnnSupport(cuda_executor);
            if (!dnn->Init().ok()) {
              // Note: Init() will log a more specific error.
              delete dnn;
              return nullptr;
            }
            return dnn;
          });

  if (!status.ok()) {
    LOG(ERROR) << "Unable to register cuDNN factory: "
               << status.error_message();
  }

  PluginRegistry::Instance()->SetDefaultFactory(
      cuda::kCudaPlatformId, PluginKind::kDnn, gpu::kCuDnnPlugin);
}

}  // namespace stream_executor

#pragma clang diagnostic pop

REGISTER_MODULE_INITIALIZER(register_cudnn,
                            { stream_executor::initialize_cudnn(); });