1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/rocm/rocm_dnn.h"
17 
18 #include <functional>
19 #include <memory>
20 
21 #include "absl/strings/str_cat.h"
22 #include "third_party/eigen3/Eigen/Core"
23 #include "rocm/include/miopen/miopen.h"
24 #include "tensorflow/core/lib/hash/hash.h"
25 #include "tensorflow/core/util/env_var.h"
26 #include "tensorflow/stream_executor/dnn.h"
27 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
28 #include "tensorflow/stream_executor/gpu/gpu_driver.h"
29 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
30 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
31 #include "tensorflow/stream_executor/gpu/gpu_timer.h"
32 #include "tensorflow/stream_executor/lib/env.h"
33 #include "tensorflow/stream_executor/lib/error.h"
34 #include "tensorflow/stream_executor/lib/initialize.h"
35 #include "tensorflow/stream_executor/lib/threadpool.h"
36 #include "tensorflow/stream_executor/platform/dso_loader.h"
37 #include "tensorflow/stream_executor/platform/logging.h"
38 #include "tensorflow/stream_executor/plugin_registry.h"
39 #include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
40 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
41 #include "tensorflow/stream_executor/scratch_allocator.h"
42 #include "tensorflow/stream_executor/stream.h"
43 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
44 
45 namespace {
46 
// Converts (via narrowing) a value of type WideT to a value of type NarrowT,
// and checks that the value is unchanged by the conversion.
49 template <typename WideT, typename NarrowT>
NarrowT CheckedNarrowing(const WideT& wide) {
51   NarrowT narrow = wide;
52   CHECK_EQ(narrow, wide)
53       << "checked narrowing failed; values not equal post-conversion";
54   return narrow;
55 }
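
// Illustrative use (this is how the std::transform calls later in this file
// invoke it): CheckedNarrowing<int64, int>(int64{7}) returns 7, while a value
// that does not fit in the narrow type trips the CHECK above.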
56 
57 const int kConvDebugVlogLevel = 3;
58 
59 }  // namespace
60 
61 namespace stream_executor {
62 
63 using dnn::AlgorithmDesc;
64 using dnn::BatchDescriptor;
65 using dnn::ConvolutionDescriptor;
66 using dnn::FilterDescriptor;
67 using dnn::NormalizeDescriptor;
68 using dnn::PoolingDescriptor;
69 
70 namespace gpu {
71 
72 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kMIOpenPlugin);
73 
string ToString(miopenStatus_t status) {
75   switch (status) {
76     case miopenStatusSuccess:
77       return "miopenStatusSuccess";
78     case miopenStatusNotInitialized:
79       return "miopenStatusNotInitialized";
80     case miopenStatusAllocFailed:
81       return "miopenStatusAllocFailed";
82     case miopenStatusBadParm:
83       return "miopenStatusBadParm";
84     case miopenStatusInternalError:
85       return "miopenStatusInternalError";
86     case miopenStatusInvalidValue:
87       return "miopenStatusInvalidValue";
88     case miopenStatusNotImplemented:
89       return "miopenStatusNotImplemented";
90     case miopenStatusUnknownError:
91       return "miopenStatusUnknownError";
92     default:
93       return absl::StrCat("<unknown miopen status: ", static_cast<int>(status),
94                           ">");
95   }
96 }
97 
string ToString(miopenConvFwdAlgorithm_t algorithm) {
99   string s;
100   switch (algorithm) {
101     case miopenConvolutionFwdAlgoGEMM:
102       s = "GEMM";
103       break;
104     case miopenConvolutionFwdAlgoDirect:
105       s = "Direct";
106       break;
107     case miopenConvolutionFwdAlgoFFT:
108       s = "FFT";
109       break;
110     case miopenConvolutionFwdAlgoWinograd:
111       s = "Winograd";
112       break;
113     case miopenConvolutionFwdAlgoImplicitGEMM:
114       s = "Implicit GEMM";
115       break;
116   }
117   return s;
118 }
119 
string ToString(miopenConvBwdWeightsAlgorithm_t algorithm) {
121   string s;
122   switch (algorithm) {
123     case miopenConvolutionBwdWeightsAlgoGEMM:
124       s = "GEMM";
125       break;
126     case miopenConvolutionBwdWeightsAlgoDirect:
127       s = "Direct";
128       break;
129     case miopenConvolutionBwdWeightsAlgoWinograd:
130       s = "Winograd";
131       break;
132     case miopenConvolutionBwdWeightsAlgoImplicitGEMM:
133       s = "Implicit GEMM";
134       break;
135   }
136   return s;
137 }
138 
string ToString(miopenConvBwdDataAlgorithm_t algorithm) {
140   string s;
141   switch (algorithm) {
142     case miopenConvolutionBwdDataAlgoGEMM:
143       s = "GEMM";
144       break;
145     case miopenConvolutionBwdDataAlgoDirect:
146       s = "Direct";
147       break;
148     case miopenConvolutionBwdDataAlgoFFT:
149       s = "FFT";
150       break;
151     case miopenConvolutionBwdDataAlgoWinograd:
152       s = "Winograd";
153       break;
154     case miopenTransposeBwdDataAlgoGEMM:
155       s = "Transpose GEMM";
156       break;
157     case miopenConvolutionBwdDataAlgoImplicitGEMM:
158       s = "Implicit GEMM";
159       break;
160   }
161   return s;
162 }
163 
string ToString(miopenConvAlgorithm_t algorithm) {
165   string s;
166   switch (algorithm) {
167     case miopenConvolutionAlgoGEMM:
168       s = "GEMM";
169       break;
170     case miopenConvolutionAlgoDirect:
171       s = "Direct";
172       break;
173     case miopenConvolutionAlgoFFT:
174       s = "FFT";
175       break;
176     case miopenConvolutionAlgoWinograd:
177       s = "Winograd";
178       break;
179     case miopenConvolutionAlgoImplicitGEMM:
180       s = "Implicit GEMM";
181       break;
182   }
183   return s;
184 }
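
// The ToString overloads above are used to make diagnostics readable, e.g.
// the LOG(FATAL)/LOG(ERROR) messages below that report a failing
// miopenStatus_t; they do not influence algorithm selection.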
185 
186 // RAII wrapper for all calls to MIOpen with a MIOpen handle argument.
187 //
188 // See MIOpenAccess::GetHandle() for details.
189 class MIOpenHandle {
190  public:
191   // Takes ownership of the executor context and the lock to access MIOpen
192   // using handle.
  MIOpenHandle(gpu::ScopedActivateExecutorContext context,
194                std::unique_ptr<absl::MutexLock> lock, miopenHandle_t handle)
195       : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
196 
197   // Returns MIOpen handle. To be passed directly to MIOpen APIs, don't keep
198   // a copy.
  miopenHandle_t handle() const { return handle_; }
200 
201  private:
202   gpu::ScopedActivateExecutorContext context_;
203   std::unique_ptr<absl::MutexLock> lock_;
204   miopenHandle_t handle_;  // Not owned.
205 };
206 
207 namespace wrap {
208 
209 #ifdef PLATFORM_GOOGLE
210 #define STREAM_EXECUTOR_MIOPEN_WRAP(__name)      \
211   struct WrapperShim__##__name {                 \
212     template <typename... Args>                  \
213     miopenStatus_t operator()(Args... args) {    \
214       miopenStatus_t retval = ::__name(args...); \
215       return retval;                             \
216     }                                            \
217   } __name;
218 
219 #else
220 
221 #define STREAM_EXECUTOR_MIOPEN_WRAP(__name)                               \
222   struct DynLoadShim__##__name {                                          \
223     static const char* kName;                                             \
224     using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
225     static void* GetDsoHandle() {                                         \
226       auto s = internal::CachedDsoLoader::GetMiopenDsoHandle();           \
227       return s.ValueOrDie();                                              \
228     }                                                                     \
229     static FuncPtrT LoadOrDie() {                                         \
230       void* f;                                                            \
231       auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
232                                                           kName, &f);     \
233       CHECK(s.ok()) << "could not find " << kName                         \
234                     << " in miopen DSO; dlerror: " << s.error_message();  \
235       return reinterpret_cast<FuncPtrT>(f);                               \
236     }                                                                     \
237     static FuncPtrT DynLoad() {                                           \
238       static FuncPtrT f = LoadOrDie();                                    \
239       return f;                                                           \
240     }                                                                     \
241     template <typename... Args>                                           \
242     miopenStatus_t operator()(Args... args) {                             \
243       return DynLoad()(args...);                                          \
244     }                                                                     \
245   } __name;                                                               \
246   const char* DynLoadShim__##__name::kName = #__name;
247 
248 #endif
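
// Both variants of STREAM_EXECUTOR_MIOPEN_WRAP above define a callable object
// in this namespace with the same name and argument list as the wrapped
// MIOpen entry point. For example (illustrative),
//
//   STREAM_EXECUTOR_MIOPEN_WRAP(miopenDestroy)
//
// lets the rest of this file write wrap::miopenDestroy(handle), which either
// forwards directly to ::miopenDestroy (PLATFORM_GOOGLE) or resolves the
// symbol from the MIOpen DSO on first use.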
249 
250 // clang-format off
251 #define MIOPEN_DNN_ROUTINE_EACH(__macro)                             \
252   __macro(miopenBatchNormalizationBackward)                          \
253   __macro(miopenBatchNormalizationForwardInference)                  \
254   __macro(miopenBatchNormalizationForwardTraining)                   \
255   __macro(miopenGetConvolutionForwardOutputDim)                      \
256   __macro(miopenGetConvolutionNdForwardOutputDim)                    \
257   __macro(miopenFindConvolutionForwardAlgorithm)                     \
258   __macro(miopenCreateTensorDescriptor)                              \
259   __macro(miopenDestroyTensorDescriptor)                             \
260   __macro(miopenSetNdPoolingDescriptor)                              \
261   __macro(miopenSetPoolingIndexType)                                 \
262   __macro(miopenSetLRNDescriptor)                                    \
263   __macro(miopenLRNGetWorkSpaceSize)                                 \
264   __macro(miopenCreateConvolutionDescriptor)                         \
265   __macro(miopenCreatePoolingDescriptor)                             \
266   __macro(miopenDestroyPoolingDescriptor)                            \
267   __macro(miopenCreateLRNDescriptor)                                 \
268   __macro(miopenDestroyLRNDescriptor)                                \
269   __macro(miopenDestroyConvolutionDescriptor)                        \
270   __macro(miopenCreateWithStream)                                    \
271   __macro(miopenDestroy)                                             \
272   __macro(miopenSetStream)                                           \
273   __macro(miopenSetAllocator)                                        \
274   __macro(miopenActivationForward)                                   \
275   __macro(miopenConvolutionForward)                                  \
276   __macro(miopenConvolutionBackwardBias)                             \
277   __macro(miopenConvolutionForwardGetWorkSpaceSize)                  \
278   __macro(miopenInitConvolutionDescriptor)                           \
279   __macro(miopenInitConvolutionNdDescriptor)                         \
280   __macro(miopenGetConvolutionDescriptor)                            \
281   __macro(miopenGetConvolutionNdDescriptor)                          \
282   __macro(miopenSetConvolutionGroupCount)                            \
283   __macro(miopenSet4dTensorDescriptor)                               \
284   __macro(miopenGetTensorDescriptor)                                 \
285   __macro(miopenSetTensorDescriptor)                                 \
286   __macro(miopenGetTensorDescriptorSize)                             \
287   __macro(miopenPoolingForward)                                      \
288   __macro(miopenPoolingGetWorkSpaceSizeV2)                           \
289   __macro(miopenPoolingBackward)                                     \
290   __macro(miopenLRNForward)                                          \
291   __macro(miopenLRNBackward)                                         \
292   __macro(miopenOpTensor)                                            \
293   __macro(miopenConvolutionBackwardData)                             \
294   __macro(miopenConvolutionBackwardWeights)                          \
295   __macro(miopenConvolutionBackwardWeightsGetWorkSpaceSize)          \
296   __macro(miopenFindConvolutionBackwardDataAlgorithm)                \
297   __macro(miopenFindConvolutionBackwardWeightsAlgorithm)             \
298   __macro(miopenConvolutionBackwardDataGetWorkSpaceSize)             \
299   __macro(miopenCreateRNNDescriptor)                                 \
300   __macro(miopenSetRNNDescriptor)                                    \
301   __macro(miopenDestroyRNNDescriptor)                                \
302   __macro(miopenGetRNNParamsSize)                                    \
303   __macro(miopenGetRNNLayerParam)                                    \
304   __macro(miopenGetRNNLayerBias)                                     \
305   __macro(miopenGetRNNWorkspaceSize)                                 \
306   __macro(miopenGetRNNTrainingReserveSize)                           \
307   __macro(miopenRNNForwardInference)                                 \
308   __macro(miopenRNNForwardTraining)                                  \
309   __macro(miopenRNNBackwardData)                                     \
310   __macro(miopenRNNBackwardWeights)                                  \
311   __macro(miopenGetRNNLayerParamOffset)                              \
312   __macro(miopenGetRNNLayerParamSize)                                \
313   __macro(miopenGetRNNLayerBiasOffset)                               \
314   __macro(miopenGetRNNLayerBiasSize)                                 \
315   __macro(miopenGetRNNParamsDescriptor)                              \
316   __macro(miopenCreateActivationDescriptor)                          \
317   __macro(miopenSetActivationDescriptor)                             \
318   __macro(miopenGetActivationDescriptor)                             \
319   __macro(miopenDestroyActivationDescriptor)                         \
320   __macro(miopenCreateFusionPlan)                                    \
321   __macro(miopenCreateOpConvForward)                                 \
322   __macro(miopenCreateOpBiasForward)                                 \
323   __macro(miopenCreateOpActivationForward)                           \
324   __macro(miopenCreateOpActivationBackward)                          \
325   __macro(miopenCreateOpBatchNormInference)                          \
326   __macro(miopenCreateOpBatchNormForward)                            \
327   __macro(miopenCreateOpBatchNormBackward)                           \
328   __macro(miopenCompileFusionPlan)                                   \
329   __macro(miopenFusionPlanGetOp)                                     \
330   __macro(miopenCreateOperatorArgs)                                  \
331   __macro(miopenSetOpArgsConvForward)                                \
332   __macro(miopenSetOpArgsBiasForward)                                \
333   __macro(miopenSetOpArgsActivForward)                               \
334   __macro(miopenSetOpArgsActivBackward)                              \
335   __macro(miopenSetOpArgsBatchNormInference)                         \
336   __macro(miopenSetOpArgsBatchNormForward)                           \
337   __macro(miopenSetOpArgsBatchNormBackward)                          \
338   __macro(miopenExecuteFusionPlan)                                   \
339   __macro(miopenDestroyOperatorArgs)                                 \
340   __macro(miopenDestroyFusionPlan)                                   \
341   __macro(miopenConvolutionForwardGetSolutionCount)                  \
342   __macro(miopenConvolutionForwardGetSolution)                       \
343   __macro(miopenConvolutionForwardGetSolutionWorkspaceSize)          \
344   __macro(miopenConvolutionForwardCompileSolution)                   \
345   __macro(miopenConvolutionForwardImmediate)                         \
346   __macro(miopenConvolutionBackwardDataGetSolutionCount)             \
347   __macro(miopenConvolutionBackwardDataGetSolution)                  \
348   __macro(miopenConvolutionBackwardDataGetSolutionWorkspaceSize)     \
349   __macro(miopenConvolutionBackwardDataCompileSolution)              \
350   __macro(miopenConvolutionBackwardDataImmediate)                    \
351   __macro(miopenConvolutionBackwardWeightsGetSolutionCount)          \
352   __macro(miopenConvolutionBackwardWeightsGetSolution)               \
353   __macro(miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize)  \
354   __macro(miopenConvolutionBackwardWeightsCompileSolution)           \
355   __macro(miopenConvolutionBackwardWeightsImmediate)                 \
356   __macro(miopenCreateCTCLossDescriptor)                             \
357   __macro(miopenSetCTCLossDescriptor)                                \
358   __macro(miopenGetCTCLossWorkspaceSize)                             \
359   __macro(miopenCTCLoss)                                             \
360   __macro(miopenDestroyCTCLossDescriptor)
361 // clang-format on
362 
363 MIOPEN_DNN_ROUTINE_EACH(STREAM_EXECUTOR_MIOPEN_WRAP)
364 
365 #undef MIOPEN_DNN_ROUTINE_EACH
366 
367 }  // namespace wrap
368 
369 namespace {
370 
// These routines should ideally be provided as an MIOpen API.
// They are called for *every* _ROCMmFusedOp*::Compute call, and they need to
// be efficient! Instead of calculating the hash value by querying the MIOpen
// Get* APIs for the descriptor components, it would be a lot more efficient if
// MIOpen calculated the hash value when creating the descriptor, stored it on
// the descriptor data structure, and provided an API routine to query it.
377 
378 const int kMaxMIOpenTensorSize = 5;
379 
uint64 GetHashValue(miopenTensorDescriptor_t tensor_desc) {
381   miopenDataType_t datatype = miopenFloat;
382   int dims[kMaxMIOpenTensorSize] = {0};
383   int strides[kMaxMIOpenTensorSize] = {0};
384   wrap::miopenGetTensorDescriptor(tensor_desc, &datatype, dims, strides);
385 
386   uint64 hash_value = tensorflow::hash<int>()(datatype);
387   for (int dim : dims)
388     hash_value =
389         tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(dim));
390   for (int stride : strides)
391     hash_value =
392         tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(stride));
393 
394   return hash_value;
395 }
396 
uint64 GetHashValue(miopenConvolutionDescriptor_t conv_desc) {
398   miopenConvolutionMode_t c_mode = miopenConvolution;
399   int nd = 0;
400   wrap::miopenGetConvolutionNdDescriptor(conv_desc, 0, &nd, nullptr, nullptr,
401                                          nullptr, &c_mode);
402 
403   std::vector<int> stride(nd);
404   std::vector<int> pad(nd);
405   std::vector<int> dilation(nd);
406 
407   wrap::miopenGetConvolutionNdDescriptor(
408       conv_desc, nd, &nd, pad.data(), stride.data(), dilation.data(), &c_mode);
409 
410   uint64 hash_value = tensorflow::hash<int>()(c_mode);
  auto hash64Combine = [&hash_value](int element) {
    hash_value = tensorflow::Hash64Combine(hash_value,
                                           tensorflow::hash<int>()(element));
  };
414   std::for_each(pad.begin(), pad.end(), hash64Combine);
415   std::for_each(stride.begin(), stride.end(), hash64Combine);
416   std::for_each(dilation.begin(), dilation.end(), hash64Combine);
417 
418   return hash_value;
419 }
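
// The hash values computed above feed the fusion-plan cache below: the fusion
// plan classes combine them into a single key (see the GetFusionOpHashValue
// call in ScopedFusionPlanConvolutionBiasActivation) which is then passed to
// CachedFusionPlans::FindOrCreate.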
420 
421 // Class to implement a cache of compiled fusion plans
422 class CachedFusionPlans {
423  public:
424   // Check if we already have a fusion_plan corresponding to the given hash
425   // value.
426   // If we do, then
427   //   return true (+ the cached fusion plan via given pointer)
428   // Else
429   //   create a new fusion plan descriptor,
430   //   associate it with the given hash value in the cache
431   //   return false (+ newly created fusion plan via given pointer)
  static bool FindOrCreate(uint64 hash,
433                            miopenFusionPlanDescriptor_t* fusion_plan,
434                            miopenFusionDirection_t fusion_direction,
435                            miopenTensorDescriptor_t input_descriptor) {
436     absl::MutexLock lock{&cached_plans_mutex};
437 
438     bool found_cached_plan = false;
439 
440     auto it = cached_plans.find(hash);
441     if (it != cached_plans.end()) {
442       *fusion_plan = it->second;
443       found_cached_plan = true;
444     } else {
445       auto status = wrap::miopenCreateFusionPlan(fusion_plan, fusion_direction,
446                                                  input_descriptor);
447       if (status != miopenStatusSuccess) {
448         LOG(FATAL) << "call to miopenCreateFusionPlan failed: "
449                    << ToString(status);
450       } else {
451         cached_plans[hash] = *fusion_plan;
452       }
453     }
454 
455     return found_cached_plan;
456   }
457 
458   // Need to figure out the right place to call this routine
  static void Clear() {
460     absl::MutexLock lock{&cached_plans_mutex};
461 
462     for (auto it : cached_plans) {
463       auto status = wrap::miopenDestroyFusionPlan(it.second);
464       if (status != miopenStatusSuccess) {
465         LOG(FATAL) << "call to miopenDestroyFusionPlan failed: "
466                    << ToString(status);
467       }
468     }
469 
470     cached_plans.clear();
471 
472     unsupported_plans.clear();
473   }
474 
  // Returns true if the fusion plan corresponding to this hash is unsupported.
  static bool IsUnsupportedFusionPlan(uint64 hash) {
477     absl::MutexLock lock{&cached_plans_mutex};
478     return unsupported_plans.count(hash) > 0;
479   }
480 
481   // Mark the given hash value as corresponding to an unsupported fusion plan
  static void MarkFusionPlanUnsupported(uint64 hash) {
483     absl::MutexLock lock{&cached_plans_mutex};
484     unsupported_plans.insert(hash);
485   }
486 
487  private:
488   // Mutex to guard access to all data within this class
489   static absl::Mutex cached_plans_mutex;
490 
491   // Map of hash-value to MIOpen Fusion plan descriptors
  // Need to be able to share this across more than one stream and hence static
493   static std::map<uint64, miopenFusionPlanDescriptor_t> cached_plans;
494 
495   // Set of hash-values that correspond to MIOpen Fusion plans that will fail
496   // compile and hence are not supported.
497   static std::set<uint64> unsupported_plans;
498 };
499 
500 absl::Mutex CachedFusionPlans::cached_plans_mutex;
501 std::map<uint64, miopenFusionPlanDescriptor_t> CachedFusionPlans::cached_plans;
502 std::set<uint64> CachedFusionPlans::unsupported_plans;
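
// Illustrative use of the cache, mirroring ScopedFusionPlanConvolutionBiasActivation
// below (`hash` and `input_descriptor` are assumed to be provided by the caller):
//
//   miopenFusionPlanDescriptor_t fusion_plan;
//   bool already_compiled = CachedFusionPlans::FindOrCreate(
//       hash, &fusion_plan, miopenVerticalFusion, input_descriptor);
//   if (!already_compiled) {
//     // add the fusion ops and compile the plan with miopenCompileFusionPlan
//   }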
503 
dnn::ProfileResult GetProfileResultFromConvSolution(
505     miopenConvSolution_t solution) {
506   dnn::ProfileResult profile_result;
507   profile_result.set_algorithm(
508       {static_cast<AlgorithmDesc::Index>(solution.solution_id), false});
509   profile_result.set_elapsed_time_in_ms(solution.time);
510   profile_result.set_scratch_size(solution.workspace_size);
511   return profile_result;
512 }
513 
dnn::ProfileResult GetProfileResultFromConvAlgoPerf(
515     dnn::ConvolutionKind kind, miopenConvAlgoPerf_t algorithm) {
516   dnn::ProfileResult profile_result;
517   switch (kind) {
518     case dnn::ConvolutionKind::FORWARD:
519       profile_result.set_algorithm(
520           {static_cast<AlgorithmDesc::Index>(algorithm.fwd_algo), false});
521       break;
522     case dnn::ConvolutionKind::BACKWARD_DATA:
523       profile_result.set_algorithm(
524           {static_cast<AlgorithmDesc::Index>(algorithm.bwd_data_algo), false});
525       break;
526     case dnn::ConvolutionKind::BACKWARD_FILTER:
527       profile_result.set_algorithm(
528           {static_cast<AlgorithmDesc::Index>(algorithm.bwd_weights_algo),
529            false});
530       break;
531     default:
532       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
533       break;
534   }
535   profile_result.set_elapsed_time_in_ms(algorithm.time);
536   profile_result.set_scratch_size(algorithm.memory);
537   return profile_result;
538 }
539 }  // namespace
540 
// Wraps a MIOpen handle and provides access to it through MIOpenHandle
// instances, which also lock a mutex, acquire the ROCm context, and set
// the stream that MIOpen should use to enqueue any work.
544 //
545 // Note: MIOpenSupport::miopen_ should be the only instantiation of this class.
546 class MIOpenAccess {
547  public:
548   // Takes ownership of the handle.
  explicit MIOpenAccess(miopenHandle_t handle) : handle_(handle) {}
550 
  ~MIOpenAccess() {
552     absl::MutexLock lock(&mutex_);
553     wrap::miopenDestroy(handle_);
554   }
555 
556   // Creates a MIOpenHandle instance for stream.
557   //
558   // MIOpen API calls using the same handle instance need to be serialized
559   // across threads. This is guaranteed by MIOpenHandle instances locking the
560   // mutex owned by this class.
561   //
562   // Most MIOpen APIs taking a handle perform work on a HIP stream. The
563   // MIOpenHandle instance acquires the executor's ROCm context and sets MIOpen
564   // to use the provided stream.
565   //
566   // The stream argument may be null, which translates to the null stream.
567   // The null stream synchronizes with all other streams and it is
568   // therefore a bad idea (performance wise) to call any MIOpen APIs that
569   // enqueue work in the stream.
  MIOpenHandle GetHandle(GpuExecutor* executor, Stream* stream) {
571     auto lock = absl::make_unique<absl::MutexLock>(&mutex_);
572     mutex_.AssertHeld();
573     gpu::ScopedActivateExecutorContext context(executor);
574     hipStream_t hip_stream = stream ? AsGpuStreamValue(stream) : nullptr;
575     auto status = wrap::miopenSetStream(handle_, hip_stream);
576     CHECK_EQ(status, miopenStatusSuccess) << "Failed to set MIOpen stream.";
577     return MIOpenHandle(std::move(context), std::move(lock), handle_);
578   }
579 
580  private:
581   // Guards the enqueueing of MIOpen operations via the handle_ below.
582   absl::Mutex mutex_;
583 
584   // MIOpen library handle.
585   miopenHandle_t handle_ TF_GUARDED_BY(mutex_);  // Owned.
586 };
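
// Typical access pattern (a minimal sketch; `stream` is assumed non-null and
// wrap::miopenSomeApi stands in for any MIOpen call taking a handle): hold the
// returned MIOpenHandle for the duration of the MIOpen calls so that the mutex
// and ROCm context stay acquired.
//
//   auto miopen = miopen_->GetHandle(parent_, stream);
//   auto status = wrap::miopenSomeApi(miopen.handle(), /* ... */);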
587 
MIOpenSupport::MIOpenSupport(GpuExecutor* parent) : parent_(parent) {
589   // by default, the Get*Algorithm API will return the list of all applicable
590   // algorithms
591   return_best_algo_only_ = false;
592   // but if the env var TF_ROCM_RETURN_BEST_ALGO_ONLY is set, only the best
593   // (i.e. most efficient) algorithm will be returned
594   tensorflow::ReadBoolFromEnvVar("TF_ROCM_RETURN_BEST_ALGO_ONLY", false,
595                                  &return_best_algo_only_);
596 
597   // by default, use Find Mode APIs for convolution
598   use_immediate_mode_ = false;
  // switch to Immediate Mode if the env var TF_ROCM_USE_IMMEDIATE_MODE is set
600   tensorflow::ReadBoolFromEnvVar("TF_ROCM_USE_IMMEDIATE_MODE", false,
601                                  &use_immediate_mode_);
602 
603   bool enable_pooling_cache = false;
604   tensorflow::ReadBoolFromEnvVar("TF_ROCM_BW_POOL_CACHE", false,
605                                  &enable_pooling_cache);
606   if (enable_pooling_cache) m_pooling_cache_allowed = true;
607 }
608 
port::Status MIOpenSupport::Init() {
610   ScopedActivateExecutorContext context(parent_);
611   miopenHandle_t miopen_handle = nullptr;
612   auto status = wrap::miopenCreateWithStream(
613       reinterpret_cast<miopenHandle_t*>(&miopen_handle), (hipStream_t)(0));
614   if (status == miopenStatusSuccess) {
615     miopen_.reset(new MIOpenAccess(miopen_handle));
616     return port::Status::OK();
617   }
618 
619   CHECK_EQ(miopen_handle, nullptr);
620   LOG(ERROR) << "could not create miopen handle: " << ToString(status);
621   if (status == miopenStatusNotInitialized) {
622     auto result = rocm::Diagnostician::FindKernelDriverVersion();
623     if (!result.ok()) {
624       LOG(ERROR) << "error retrieving driver version: "
625                  << rocm::DriverVersionStatusToString(result);
626     } else {
627       const auto& version = result.ValueOrDie();
628       LOG(INFO) << "possibly insufficient driver version: "
629                 << rocm::DriverVersionToString(version);
630     }
631   }
632 
633   return port::Status{port::error::INTERNAL,
634                       absl::StrCat("miopen library could not create a handle: ",
635                                    ToString(status))};
636 }
637 
638 port::StatusOr<perftools::gputools::dnn::VersionInfo>
MIOpenSupport::GetVersion() {
640   // ROCM TODO: retrieve MIOpen version with its API
641   return perftools::gputools::dnn::VersionInfo(1, 3, 0);
642 }
643 
644 // Turns a BatchDescriptor structure into a miopen tensor handle within a scope.
645 class ScopedTensorDescriptor {
646  public:
  ScopedTensorDescriptor(const BatchDescriptor& batch_descriptor,
648                          miopenDataType_t elem_type)
649       : handle_(nullptr) {
650     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
651     if (status != miopenStatusSuccess) {
652       LOG(FATAL) << "could not create miopen tensor descriptor: "
653                  << ToString(status);
654     }
655 
656     switch (batch_descriptor.layout()) {
657       case dnn::DataLayout::kBatchYXDepth:
658       case dnn::DataLayout::kBatchDepthYX: {
659         const int nd = batch_descriptor.ndims() + 2;
660 
661         // MIOpen requires the strides and dims to be ordered as BDYX.
662         std::vector<int64> strides64 =
663             batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
664         std::vector<int64> dims64 =
665             batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
666 
667         // MIOpen requires arrays of ints.
668         std::vector<int> strides(nd);
669         std::vector<int> dims(nd);
670         std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
671                        &CheckedNarrowing<int64, int>);
672         std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
673                        &CheckedNarrowing<int64, int>);
674         status = wrap::miopenSetTensorDescriptor(handle_, elem_type, nd,
675                                                  dims.data(), strides.data());
676 
677         if (status != miopenStatusSuccess) {
678           LOG(FATAL) << "could not convert BatchDescriptor "
679                      << batch_descriptor.ToString()
680                      << " to miopen tensor descriptor: " << ToString(status);
681         }
682       } break;
683       default:
684         LOG(FATAL) << "Unsupported tensor format "
685                    << DataLayoutString(batch_descriptor.layout());
686         break;
687     }
688   }
689 
  ~ScopedTensorDescriptor() {
691     auto status = wrap::miopenDestroyTensorDescriptor(handle_);
692     if (status != miopenStatusSuccess) {
693       LOG(ERROR) << "could not destroy miopen tensor descriptor: "
694                  << ToString(status);
695     }
696   }
697 
  miopenTensorDescriptor_t handle() const { return handle_; }
699 
700  private:
701   miopenTensorDescriptor_t handle_;  // Owned.
702 
703   SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor);
704 };
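
// Illustrative RAII usage of the scoped descriptor wrappers in this file
// (variable names are hypothetical):
//
//   ScopedTensorDescriptor input_nd(batch_descriptor, miopenFloat);
//   // ... pass input_nd.handle() to MIOpen calls ...
//   // The miopen tensor descriptor is destroyed when input_nd goes out of
//   // scope.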
705 
706 // Turns a FilterDescriptor structure into a miopen filter handle within a
707 // scope.
708 class ScopedFilterDescriptor {
709  public:
  ScopedFilterDescriptor(const FilterDescriptor& filter_descriptor,
711                          const BatchDescriptor& batch_descriptor,
712                          miopenDataType_t elem_type)
713       : handle_(nullptr) {
714     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
715     if (status != miopenStatusSuccess) {
716       LOG(FATAL) << "could not create miopen filter descriptor: "
717                  << ToString(status);
718     }
719 
720     const int nd = batch_descriptor.ndims() + 2;
721 
722     std::vector<int> dims(2 + filter_descriptor.ndims());
723     dims[0] = filter_descriptor.output_feature_map_count();
724     dims[1] = filter_descriptor.input_feature_map_count();
725     const auto& spatial_dims = filter_descriptor.input_filter_dims();
726     std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
727 
728     status = wrap::miopenSetTensorDescriptor(handle_, elem_type, nd,
729                                              dims.data(), nullptr);
730     if (status != miopenStatusSuccess) {
731       LOG(FATAL) << "could not set miopen filter descriptor: "
732                  << ToString(status);
733     }
734   }
735 
  ~ScopedFilterDescriptor() {
737     auto status = wrap::miopenDestroyTensorDescriptor(handle_);
738     if (status != miopenStatusSuccess) {
739       LOG(ERROR) << "could not destroy miopen filter descriptor: "
740                  << ToString(status);
741     }
742   }
743 
  miopenTensorDescriptor_t handle() const { return handle_; }
745 
746  private:
747   // miopen filter descriptor this object creates. Owned.
748   miopenTensorDescriptor_t handle_;
749 
750   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor);
751 };
752 
753 // Turns a ConvolutionDescriptor structure into a miopen convolution handle
754 // within a scope.
755 class ScopedConvolutionDescriptor {
756  public:
  ScopedConvolutionDescriptor(
758       const ConvolutionDescriptor& convolution_descriptor,
759       miopenDataType_t data_type)
760       : handle_(nullptr) {
761     auto status = wrap::miopenCreateConvolutionDescriptor(&handle_);
762     if (status != miopenStatusSuccess) {
763       LOG(FATAL) << "could not create miopen convolution descriptor: "
764                  << ToString(status);
765     }
766     const auto& strides64 = convolution_descriptor.strides();
767     const auto& padding64 = convolution_descriptor.padding();
768     if (convolution_descriptor.pad_alignment() ==
769         dnn::PadAlignment::kTensorFlowPadding) {
770       LOG(ERROR) << "TensorFlow padding alignment is not supported.";
771     }
772 
773     // MIOpen requires arrays of ints.
774     std::vector<int> strides(convolution_descriptor.ndims());
775     std::vector<int> padding(convolution_descriptor.ndims());
776     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
777                    &CheckedNarrowing<int64, int>);
778     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
779                    &CheckedNarrowing<int64, int>);
780 
781     std::vector<int> upscale(convolution_descriptor.ndims());
782     const auto& dilations64 = convolution_descriptor.dilations();
783     std::transform(dilations64.cbegin(), dilations64.cend(), upscale.begin(),
784                    &CheckedNarrowing<int64, int>);
785 
786     status = wrap::miopenInitConvolutionNdDescriptor(
787         handle_, convolution_descriptor.ndims(), padding.data(), strides.data(),
788         upscale.data(), miopenConvolution);
789     if (status != miopenStatusSuccess) {
790       LOG(FATAL) << "could not set miopen convolution descriptor: "
791                  << ToString(status);
792     }
793 
794     VLOG(2) << "Requesting grouped convolution: "
795             << convolution_descriptor.group_count();
796     status = wrap::miopenSetConvolutionGroupCount(
797         handle_, convolution_descriptor.group_count());
798     if (status != miopenStatusSuccess) {
799       LOG(FATAL) << "could not set miopen convolution group count: "
800                  << ToString(status);
801     }
802   }
  ~ScopedConvolutionDescriptor() {
804     auto status = wrap::miopenDestroyConvolutionDescriptor(handle_);
805     if (status != miopenStatusSuccess) {
806       LOG(ERROR) << "could not destroy miopen convolution descriptor: "
807                  << ToString(status);
808     }
809   }
810 
  miopenConvolutionDescriptor_t handle() const { return handle_; }
812 
813  private:
814   miopenConvolutionDescriptor_t handle_;  // Owned.
815 
816   SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
817 };
818 
819 // Turns a PoolingDescriptor structure into a miopen pooling descriptor handle
820 // within a scope.
821 class ScopedPoolingDescriptor {
822  public:
  ScopedPoolingDescriptor(const PoolingDescriptor& pooling_descriptor)
824       : handle_(nullptr) {
825     auto status = wrap::miopenCreatePoolingDescriptor(&handle_);
826     if (status != miopenStatusSuccess) {
827       LOG(FATAL) << "could not create miopen pooling descriptor: "
828                  << ToString(status);
829     }
830 
831     absl::Span<const int64> strides64 = pooling_descriptor.strides();
832     absl::Span<const int64> padding64 = pooling_descriptor.padding();
833     absl::Span<const int64> shape64 = pooling_descriptor.window();
834 
835     const int nd = pooling_descriptor.ndims();
836     std::vector<int> shape(nd);
837     std::vector<int> padding(nd);
838     std::vector<int> strides(nd);
839     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
840                    &CheckedNarrowing<int64, int>);
841     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
842                    &CheckedNarrowing<int64, int>);
843     std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
844                    &CheckedNarrowing<int64, int>);
845 
846     status = wrap::miopenSetNdPoolingDescriptor(
847         handle_,
848         (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
849              ? miopenPoolingMax
850              : miopenPoolingAverage),
851         nd, shape.data(), padding.data(), strides.data());
852 
    // Note: The index type has to be uint32 for now because the MIOpen API
    // assumes all input indexes are of the same type. Since a tensor
    // descriptor can only use the int32 type, the index type here needs to be
    // aligned with the tensor index type of the (input) tensor descriptor.
857     status = wrap::miopenSetPoolingIndexType(handle_, miopenIndexUint32);
858 
859     if (status != miopenStatusSuccess) {
860       LOG(FATAL) << "could not set miopen pooling descriptor: "
861                  << ToString(status);
862     }
863   }
  ~ScopedPoolingDescriptor() {
865     auto status = wrap::miopenDestroyPoolingDescriptor(handle_);
866     if (status != miopenStatusSuccess) {
867       LOG(ERROR) << "could not destroy miopen pooling descriptor: "
868                  << ToString(status);
869     }
870   }
871 
  miopenPoolingDescriptor_t handle() const { return handle_; }
873 
874  private:
875   miopenPoolingDescriptor_t handle_;  // Owned.
876 
877   SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
878 };
879 
880 // Turns a NormalizeDescriptor structure into a miopen LRN descriptor handle.
881 class ScopedNormalizeDescriptor {
882  public:
  ScopedNormalizeDescriptor(const NormalizeDescriptor& normalize_descriptor)
884       : handle_(nullptr) {
885     auto status = wrap::miopenCreateLRNDescriptor(&handle_);
886     if (status != miopenStatusSuccess) {
887       LOG(FATAL) << "could not create miopen LRN descriptor: "
888                  << ToString(status);
889     }
890 
891     // The range specifies that the indices in the closed range
892     // [i - range, i + range] should be included in the normalization for index
893     // i. The lrnN value is the total number of elements in the range, so
894     // lrnN = 2*range + 1.
895     unsigned lrn_N = 2 * normalize_descriptor.range() + 1;
896 
897     // Note that SE defines the normalization operation as
898     //
899     //  U_i = V_i / ((bias +  alpha      * (sum_j V_j^2)) ^ beta)
900     //
901     // but MIOpen defines it as
902     //
903     //  U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
904     //
905     // i.e. there is a factor of n difference between the meaning of the alphas
906     // in the two contexts. The MIOpen alpha is n times the SE alpha.
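    // For example (illustrative): with range = 2, lrn_N = 5, so an SE alpha of
    // 0.0001 corresponds to a MIOpen alpha of 0.0005.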
907     double lrn_alpha = lrn_N * normalize_descriptor.alpha();
908 
909     double lrn_beta = normalize_descriptor.beta();
910     double lrn_k = normalize_descriptor.bias();
911     status = wrap::miopenSetLRNDescriptor(handle_, miopenLRNCrossChannel, lrn_N,
912                                           lrn_alpha, lrn_beta, lrn_k);
913     if (status != miopenStatusSuccess) {
914       LOG(FATAL) << "could not set miopen LRN descriptor: " << ToString(status);
915     }
916   }
917 
  ~ScopedNormalizeDescriptor() {
919     auto status = wrap::miopenDestroyLRNDescriptor(handle_);
920     if (status != miopenStatusSuccess) {
921       LOG(ERROR) << "could not destroy miopen LRN descriptor: "
922                  << ToString(status);
923     }
924   }
925 
  miopenLRNDescriptor_t handle() const { return handle_; }
927 
928  private:
929   miopenLRNDescriptor_t handle_;  // Owned.
930 
931   SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor);
932 };
933 
// Turns an activation mode into a miopen activation mode descriptor within a
// scope.
936 class ScopedActivationDescriptor {
937  public:
  ScopedActivationDescriptor(dnn::ActivationMode activation_mode)
939       : handle_(nullptr),
940         miopen_activation_mode_(miopenActivationPASTHRU),
941         alpha_(0.0),
942         beta_(0.0),
943         gamma_(0.0) {
944     auto status = wrap::miopenCreateActivationDescriptor(&handle_);
945     if (status != miopenStatusSuccess) {
946       LOG(FATAL) << "call to miopenCreateActivationDescriptor failed: "
947                  << ToString(status);
948     } else {
949       switch (activation_mode) {
950         case dnn::ActivationMode::kNone:
951           miopen_activation_mode_ = miopenActivationPASTHRU;
952           break;
953 
954         case dnn::ActivationMode::kSigmoid:
955           miopen_activation_mode_ = miopenActivationLOGISTIC;
956           break;
957 
958         case dnn::ActivationMode::kRelu:
959           miopen_activation_mode_ = miopenActivationRELU;
960           break;
961 
962         case dnn::ActivationMode::kRelu6:
963           miopen_activation_mode_ = miopenActivationRELU;
964           alpha_ = 6.0;
965           break;
966 
967         case dnn::ActivationMode::kTanh:
968           miopen_activation_mode_ = miopenActivationTANH;
969           break;
970 
971         default:
972           LOG(FATAL) << "Activation mode ("
973                      << dnn::ActivationModeString(activation_mode)
974                      << ") not yet implemented";
975           break;
976       }
977 
978       status = wrap::miopenSetActivationDescriptor(
979           handle_, miopen_activation_mode_, alpha_, beta_, gamma_);
980       if (status != miopenStatusSuccess) {
981         LOG(FATAL) << "call to miopenSetActivationDescriptor failed: "
982                    << ToString(status);
983       }
984     }
985   }
986 
  ~ScopedActivationDescriptor() {
988     auto status = wrap::miopenDestroyActivationDescriptor(handle_);
989     if (status != miopenStatusSuccess) {
990       LOG(FATAL) << "call to miopenDestroyActivationDescriptor failed: "
991                  << ToString(status);
992     }
993   }
994 
  miopenActivationDescriptor_t handle() const { return handle_; }
996 
  uint64 GetHashValue() {
998     uint64 hash_value = tensorflow::hash<int>()(miopen_activation_mode_);
999     hash_value = tensorflow::Hash64Combine(hash_value,
1000                                            tensorflow::hash<double>()(alpha_));
1001     hash_value = tensorflow::Hash64Combine(hash_value,
1002                                            tensorflow::hash<double>()(beta_));
1003     hash_value = tensorflow::Hash64Combine(hash_value,
1004                                            tensorflow::hash<double>()(gamma_));
1005 
1006     return hash_value;
1007   }
1008 
1009  private:
1010   miopenActivationDescriptor_t handle_;  // Owned.
1011 
1012   SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor);
1013 
1014  public:
  // These values are cached here to avoid calling miopenGetActivationDescriptor
  // to retrieve them. miopenGetActivationDescriptor gets called twice during
  // each call to execute a fusion plan (that involves the activation op): once
  // when calculating the hash value for the fusion op, and again before
  // calling SetOpArgs for the activation op.
1020   miopenActivationMode_t miopen_activation_mode_;
1021   double alpha_;
1022   double beta_;
1023   double gamma_;
1024 };
1025 
1026 // base class for all fusion plan implementations to derive from
1027 class ScopedFusionPlanBase {
1028  public:
  ScopedFusionPlanBase(miopenHandle_t miopen_handle,
1030                        const miopenFusionDirection_t fuse_direction,
1031                        const miopenTensorDescriptor_t input_descriptor)
1032       : miopen_handle_(miopen_handle),
1033         fusion_plan_(nullptr),
1034         fusion_args_(nullptr),
1035         fusion_plan_compiled_(false) {
1036     auto status = wrap::miopenCreateOperatorArgs(&fusion_args_);
1037     if (status != miopenStatusSuccess) {
1038       LOG(FATAL) << "call to miopenCreateOperatorArgs failed: "
1039                  << ToString(status);
1040     }
1041   }
1042 
  virtual ~ScopedFusionPlanBase() {
1044     auto status = wrap::miopenDestroyOperatorArgs(fusion_args_);
1045     if (status != miopenStatusSuccess) {
      LOG(FATAL) << "call to miopenDestroyOperatorArgs failed: "
1047                  << ToString(status);
1048     }
1049   }
1050 
  miopenStatus_t Execute(miopenTensorDescriptor_t input_descriptor,
1052                          const void* input_data,
1053                          miopenTensorDescriptor_t output_descriptor,
1054                          void* output_data) {
1055     auto status = wrap::miopenExecuteFusionPlan(
1056         miopen_handle_, fusion_plan_, input_descriptor, input_data,
1057         output_descriptor, output_data, fusion_args_);
1058     if (status != miopenStatusSuccess) {
1059       LOG(FATAL) << "call to miopenExecuteFusionPlan failed: "
1060                  << ToString(status);
1061     }
1062 
1063     return status;
1064   }
1065 
  bool CompilationSucceeded() { return fusion_plan_compiled_; }
1067 
1068  protected:
  miopenStatus_t SetConvolutionArgs(const int op_idx, const float* alpha,
1070                                     const float* beta, const void* data) {
1071     miopenFusionOpDescriptor_t conv_op;
1072     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &conv_op);
1073     if (status != miopenStatusSuccess) {
1074       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1075                  << ToString(status);
1076     }
1077 
1078     status = wrap::miopenSetOpArgsConvForward(fusion_args_, conv_op, alpha,
1079                                               beta, data);
1080     if (status != miopenStatusSuccess) {
1081       LOG(FATAL) << "call to miopenSetOpArgsConvForward failed: "
1082                  << ToString(status);
1083     }
1084     return status;
1085   }
1086 
  miopenStatus_t SetBiasArgs(const int op_idx, const float* alpha,
1088                              const float* beta, const void* data) {
1089     miopenFusionOpDescriptor_t bias_op;
1090     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &bias_op);
1091     if (status != miopenStatusSuccess) {
1092       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1093                  << ToString(status);
1094     }
1095 
1096     status = wrap::miopenSetOpArgsBiasForward(fusion_args_, bias_op, alpha,
1097                                               beta, data);
1098     if (status != miopenStatusSuccess) {
1099       LOG(FATAL) << "call to miopenSetOpArgsBiasForward failed: "
1100                  << ToString(status);
1101     }
1102     return status;
1103   }
1104 
  miopenStatus_t SetBatchNormInferenceArgs(const int op_idx, const float* alpha,
1106                                            const float* beta, const void* scale,
1107                                            const void* offset, const void* mean,
1108                                            const void* variance,
1109                                            double epsilon) {
1110     miopenFusionOpDescriptor_t batchnorm_op;
1111     auto status =
1112         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
1113     if (status != miopenStatusSuccess) {
1114       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1115                  << ToString(status);
1116     }
1117 
1118     status = wrap::miopenSetOpArgsBatchNormInference(fusion_args_, batchnorm_op,
1119                                                      alpha, beta, scale, offset,
1120                                                      mean, variance, epsilon);
1121     if (status != miopenStatusSuccess) {
1122       LOG(FATAL) << "call to miopenSetOpArgsBatchNormInference failed: "
1123                  << ToString(status);
1124     }
1125     return status;
1126   }
1127 
  miopenStatus_t SetBatchNormForwardArgs(
1129       const int op_idx, const float* alpha, const float* beta,
1130       const void* scale, const void* offset, void* running_mean,
1131       void* running_variance, void* saved_mean, void* saved_inv_variance,
1132       double epsilon, double exponential_average_factor) {
1133     miopenFusionOpDescriptor_t batchnorm_op;
1134     auto status =
1135         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
1136     if (status != miopenStatusSuccess) {
1137       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1138                  << ToString(status);
1139     }
1140 
1141     status = wrap::miopenSetOpArgsBatchNormForward(
1142         fusion_args_, batchnorm_op, alpha, beta, scale, offset, saved_mean,
1143         saved_inv_variance, running_mean, running_variance, epsilon,
1144         exponential_average_factor);
1145     if (status != miopenStatusSuccess) {
1146       LOG(FATAL) << "call to miopenSetOpArgsBatchNormForward failed: "
1147                  << ToString(status);
1148     }
1149     return status;
1150   }
1151 
  miopenStatus_t SetBatchNormBackwardArgs(const int op_idx, const float* alpha,
1153                                           const float* beta, const void* x,
1154                                           const void* scale, const void* offset,
1155                                           void* scale_grad, void* offset_grad,
1156                                           const void* saved_mean,
1157                                           const void* saved_inv_variance) {
1158     miopenFusionOpDescriptor_t batchnorm_op;
1159     auto status =
1160         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
1161     if (status != miopenStatusSuccess) {
1162       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1163                  << ToString(status);
1164     }
1165 
1166     status = wrap::miopenSetOpArgsBatchNormBackward(
1167         fusion_args_, batchnorm_op, alpha, beta, x, scale, offset, scale_grad,
1168         offset_grad, saved_mean, saved_inv_variance);
1169     if (status != miopenStatusSuccess) {
1170       LOG(FATAL) << "call to miopenSetOpArgsBatchNormBackward failed: "
1171                  << ToString(status);
1172     }
1173     return status;
1174   }
1175 
  miopenStatus_t SetActivationForwardArgs(const int op_idx, const float* alpha,
1177                                           const float* beta, double activ_alpha,
1178                                           double activ_beta,
1179                                           double activ_gamma) {
1180     miopenFusionOpDescriptor_t actv_op;
1181     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &actv_op);
1182     if (status != miopenStatusSuccess) {
1183       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1184                  << ToString(status);
1185     }
1186 
1187     status =
1188         wrap::miopenSetOpArgsActivForward(fusion_args_, actv_op, alpha, beta,
1189                                           activ_alpha, activ_beta, activ_gamma);
1190     if (status != miopenStatusSuccess) {
1191       LOG(FATAL) << "call to miopenSetOpArgsActivForward failed: "
1192                  << ToString(status);
1193     }
1194     return status;
1195   }
1196 
  miopenStatus_t SetActivationBackwardArgs(const int op_idx, const float* alpha,
1198                                            const float* beta, const void* y,
1199                                            double activ_alpha,
1200                                            double activ_beta,
1201                                            double activ_gamma) {
1202     miopenFusionOpDescriptor_t actv_op;
1203     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &actv_op);
1204     if (status != miopenStatusSuccess) {
1205       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1206                  << ToString(status);
1207     }
1208 
1209     status = wrap::miopenSetOpArgsActivBackward(fusion_args_, actv_op, alpha,
1210                                                 beta, y, nullptr, activ_alpha,
1211                                                 activ_beta, activ_gamma);
1212     if (status != miopenStatusSuccess) {
1213       LOG(FATAL) << "call to miopenSetOpArgsActivBackward failed: "
1214                  << ToString(status);
1215     }
1216     return status;
1217   }
1218 
1219   miopenHandle_t miopen_handle_;
1220   miopenFusionPlanDescriptor_t fusion_plan_;
1221   miopenOperatorArgs_t fusion_args_;  // Owned.
1222   bool fusion_plan_compiled_;
1223 
1224   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBase);
1225 };
1226 
1227 // class to represent the Convolution+Bias+Activation fusion plan
1228 class ScopedFusionPlanConvolutionBiasActivation : public ScopedFusionPlanBase {
1229  public:
1230   ScopedFusionPlanConvolutionBiasActivation(
1231       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1232       miopenTensorDescriptor_t filter_descriptor,
1233       miopenConvolutionDescriptor_t conv_descriptor,
1234       miopenTensorDescriptor_t bias_descriptor,
1235       ScopedActivationDescriptor& activation_descriptor)
1236       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1237                              input_descriptor) {
1238     uint64 hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1239                                        filter_descriptor, conv_descriptor,
1240                                        bias_descriptor, activation_descriptor);
1241 
1242     bool is_compiled = CachedFusionPlans::FindOrCreate(
1243         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1244     if (!is_compiled) {
1245       miopenFusionOpDescriptor_t conv_op;
1246       auto status = wrap::miopenCreateOpConvForward(
1247           fusion_plan_, &conv_op, conv_descriptor, filter_descriptor);
1248       if (status != miopenStatusSuccess) {
1249         LOG(FATAL) << "call to miopenCreateOpConvForward failed: "
1250                    << ToString(status);
1251       }
1252 
1253       miopenFusionOpDescriptor_t bias_op;
1254       status = wrap::miopenCreateOpBiasForward(fusion_plan_, &bias_op,
1255                                                bias_descriptor);
1256       if (status != miopenStatusSuccess) {
1257         LOG(FATAL) << "call to miopenCreateOpBiasForward failed: "
1258                    << ToString(status);
1259       }
1260 
1261       miopenFusionOpDescriptor_t actv_op;
1262       status = wrap::miopenCreateOpActivationForward(
1263           fusion_plan_, &actv_op,
1264           activation_descriptor.miopen_activation_mode_);
1265       if (status != miopenStatusSuccess) {
1266         LOG(FATAL) << "call to miopenCreateOpActivationForward failed: "
1267                    << ToString(status);
1268       }
1269 
1270       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1271       if (status != miopenStatusSuccess) {
1272         VLOG(2) << "call to miopenCompileFusionPlan (CBA) failed: "
1273                 << ToString(status);
1274 
1275         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1276       } else {
1277         VLOG(2) << "Fusion Plan compile succeeded (CBA)";
1278         fusion_plan_compiled_ = true;
1279       }
1280     } else {
1281       // The fusion plan was already compiled; check whether that compilation failed.
1282       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1283     }
1284   }
1285 
1286   miopenStatus_t SetConvolutionArgs(const void* filter_data) {
1287     float alpha = 1.0;
1288     float beta = 0.0;
1289     return ScopedFusionPlanBase::SetConvolutionArgs(k_conv_op_idx, &alpha,
1290                                                     &beta, filter_data);
1291   }
1292 
1293   miopenStatus_t SetBiasArgs(const void* bias_data) {
1294     float alpha = 1.0;
1295     float beta = 0.0;
1296     return ScopedFusionPlanBase::SetBiasArgs(k_bias_op_idx, &alpha, &beta,
1297                                              bias_data);
1298   }
1299 
1300   miopenStatus_t SetActivationForwardArgs(
1301       ScopedActivationDescriptor& activation_descriptor) {
1302     float alpha = 1.0;
1303     float beta = 0.0;
1304 
1305     return ScopedFusionPlanBase::SetActivationForwardArgs(
1306         k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_,
1307         activation_descriptor.beta_, activation_descriptor.gamma_);
1308   }
1309 
1310   uint64 GetFusionOpHashValue(
1311       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1312       miopenTensorDescriptor_t filter_descriptor,
1313       miopenConvolutionDescriptor_t conv_descriptor,
1314       miopenTensorDescriptor_t bias_descriptor,
1315       ScopedActivationDescriptor& activation_descriptor) {
1316     uint64 hash_value = tensorflow::Hash64("ConvolutionBiasActivation");
1317 
1318     hash_value = tensorflow::Hash64Combine(
1319         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1320 
1321     hash_value =
1322         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1323     hash_value =
1324         tensorflow::Hash64Combine(hash_value, GetHashValue(filter_descriptor));
1325     hash_value =
1326         tensorflow::Hash64Combine(hash_value, GetHashValue(conv_descriptor));
1327     hash_value =
1328         tensorflow::Hash64Combine(hash_value, GetHashValue(bias_descriptor));
1329     hash_value = tensorflow::Hash64Combine(
1330         hash_value, activation_descriptor.GetHashValue());
1331     return hash_value;
1332   }
1333 
1334  private:
1335   const int k_conv_op_idx = 0;
1336   const int k_bias_op_idx = 1;
1337   const int k_actv_op_idx = 2;
1338 
1339   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanConvolutionBiasActivation);
1340 };
1341 
1342 // class to represent the BatchNorm+Activation (inference) fusion plan
1343 class ScopedFusionPlanBatchNormActivationInference
1344     : public ScopedFusionPlanBase {
1345  public:
1346   ScopedFusionPlanBatchNormActivationInference(
1347       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1348       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1349       ScopedActivationDescriptor& activation_descriptor)
1350       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1351                              input_descriptor) {
1352     uint64 hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1353                                        scale_offset_mean_variance_descriptor,
1354                                        activation_descriptor);
1355 
1356     bool is_compiled = CachedFusionPlans::FindOrCreate(
1357         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1358 
1359     if (!is_compiled) {
1360       miopenFusionOpDescriptor_t batchnorm_op;
1361       auto status = wrap::miopenCreateOpBatchNormInference(
1362           fusion_plan_, &batchnorm_op, miopenBNSpatial,
1363           scale_offset_mean_variance_descriptor);
1364 
1365       if (status != miopenStatusSuccess) {
1366         LOG(FATAL) << "call to miopenCreateOpBatchNormInference failed: "
1367                    << ToString(status);
1368       }
1369 
1370       miopenFusionOpDescriptor_t actv_op;
1371       status = wrap::miopenCreateOpActivationForward(
1372           fusion_plan_, &actv_op,
1373           activation_descriptor.miopen_activation_mode_);
1374       if (status != miopenStatusSuccess) {
1375         LOG(FATAL) << "call to miopenCreateOpActivationForward failed: "
1376                    << ToString(status);
1377       }
1378 
1379       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1380       if (status != miopenStatusSuccess) {
1381         VLOG(2) << "call to miopenCompileFusionPlan (BnA inference) failed: "
1382                 << ToString(status);
1383 
1384         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1385       } else {
1386         VLOG(2) << "Fusion Plan compile succeeded (BnA inference)";
1387         fusion_plan_compiled_ = true;
1388       }
1389     } else {
1390       // The fusion plan was already compiled; check whether that compilation failed.
1391       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1392     }
1393   }
1394 
1395   miopenStatus_t SetBatchNormInferenceArgs(const void* scale,
1396                                            const void* offset, const void* mean,
1397                                            const void* variance,
1398                                            double epsilon) {
1399     float alpha = 1.0;
1400     float beta = 0.0;
1401     return ScopedFusionPlanBase::SetBatchNormInferenceArgs(
1402         k_batchnorm_op_idx, &alpha, &beta, scale, offset, mean, variance,
1403         epsilon);
1404   }
1405 
1406   miopenStatus_t SetActivationForwardArgs(
1407       ScopedActivationDescriptor& activation_descriptor) {
1408     float alpha = 1.0;
1409     float beta = 0.0;
1410 
1411     return ScopedFusionPlanBase::SetActivationForwardArgs(
1412         k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_,
1413         activation_descriptor.beta_, activation_descriptor.gamma_);
1414   }
1415 
1416   uint64 GetFusionOpHashValue(
1417       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1418       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1419       ScopedActivationDescriptor& activation_descriptor) {
1420     uint64 hash_value = tensorflow::Hash64("BatchNormActivationInference");
1421 
1422     hash_value = tensorflow::Hash64Combine(
1423         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1424 
1425     hash_value =
1426         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1427 
1428     hash_value = tensorflow::Hash64Combine(
1429         hash_value, GetHashValue(scale_offset_mean_variance_descriptor));
1430 
1431     hash_value = tensorflow::Hash64Combine(
1432         hash_value, activation_descriptor.GetHashValue());
1433     return hash_value;
1434   }
1435 
1436  private:
1437   const int k_batchnorm_op_idx = 0;
1438   const int k_actv_op_idx = 1;
1439 
1440   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBatchNormActivationInference);
1441 };
1442 
1443 // class to represent the BatchNorm+Activation (training-forward) fusion plan
1444 class ScopedFusionPlanBatchNormActivationForward : public ScopedFusionPlanBase {
1445  public:
1446   ScopedFusionPlanBatchNormActivationForward(
1447       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1448       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1449       ScopedActivationDescriptor& activation_descriptor)
1450       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1451                              input_descriptor) {
1452     uint64 hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1453                                        scale_offset_mean_variance_descriptor,
1454                                        activation_descriptor);
1455 
1456     bool is_compiled = CachedFusionPlans::FindOrCreate(
1457         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1458 
1459     if (!is_compiled) {
1460       miopenFusionOpDescriptor_t batchnorm_op;
1461       auto status = wrap::miopenCreateOpBatchNormForward(
1462           fusion_plan_, &batchnorm_op, miopenBNSpatial,
1463           true /* runningMeanVariance */);
1464 
1465       if (status != miopenStatusSuccess) {
1466         LOG(FATAL) << "call to miopenCreateOpBatchNormForward failed: "
1467                    << ToString(status);
1468       }
1469 
1470       miopenFusionOpDescriptor_t actv_op;
1471       status = wrap::miopenCreateOpActivationForward(
1472           fusion_plan_, &actv_op,
1473           activation_descriptor.miopen_activation_mode_);
1474       if (status != miopenStatusSuccess) {
1475         LOG(FATAL) << "call to miopenCreateOpActivationForward failed: "
1476                    << ToString(status);
1477       }
1478 
1479       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1480       if (status != miopenStatusSuccess) {
1481         VLOG(2) << "call to miopenCompileFusionPlan (BnA forward) failed: "
1482                 << ToString(status);
1483 
1484         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1485       } else {
1486         VLOG(2) << "Fusion Plan compile succeeded (BnA forward)";
1487         fusion_plan_compiled_ = true;
1488       }
1489     } else {
1490       // The fusion plan was already compiled; check whether that compilation failed.
1491       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1492     }
1493   }
1494 
1495   miopenStatus_t SetBatchNormForwardArgs(const void* scale, const void* offset,
1496                                          void* batch_mean, void* batch_var,
1497                                          void* saved_mean, void* saved_var,
1498                                          double epsilon) {
1499     float alpha = 1.0;
1500     float beta = 0.0;
1501     return ScopedFusionPlanBase::SetBatchNormForwardArgs(
1502         k_batchnorm_op_idx, &alpha, &beta, scale, offset, batch_mean, batch_var,
1503         saved_mean, saved_var, epsilon, /*exponential_average_factor=*/1.0);
1504   }
1505 
1506   miopenStatus_t SetActivationForwardArgs(
1507       ScopedActivationDescriptor& activation_descriptor) {
1508     float alpha = 1.0;
1509     float beta = 0.0;
1510 
1511     return ScopedFusionPlanBase::SetActivationForwardArgs(
1512         k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_,
1513         activation_descriptor.beta_, activation_descriptor.gamma_);
1514   }
1515 
1516   uint64 GetFusionOpHashValue(
1517       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1518       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1519       ScopedActivationDescriptor& activation_descriptor) {
1520     uint64 hash_value = tensorflow::Hash64("BatchNormActivationForward");
1521 
1522     hash_value = tensorflow::Hash64Combine(
1523         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1524 
1525     hash_value =
1526         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1527 
1528     hash_value = tensorflow::Hash64Combine(
1529         hash_value, GetHashValue(scale_offset_mean_variance_descriptor));
1530 
1531     hash_value = tensorflow::Hash64Combine(
1532         hash_value, activation_descriptor.GetHashValue());
1533     return hash_value;
1534   }
1535 
1536  private:
1537   const int k_batchnorm_op_idx = 0;
1538   const int k_actv_op_idx = 1;
1539 
1540   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBatchNormActivationForward);
1541 };
1542 
1543 // class to represent the BatchNorm+Activation (training-backward) fusion plan
1544 class ScopedFusionPlanBatchNormActivationBackward
1545     : public ScopedFusionPlanBase {
1546  public:
1547   ScopedFusionPlanBatchNormActivationBackward(
1548       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1549       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1550       ScopedActivationDescriptor& activation_descriptor)
1551       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1552                              input_descriptor) {
1553     uint64 hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1554                                        scale_offset_mean_variance_descriptor,
1555                                        activation_descriptor);
1556 
1557     bool is_compiled = CachedFusionPlans::FindOrCreate(
1558         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1559 
1560     if (!is_compiled) {
1561       miopenFusionOpDescriptor_t batchnorm_op;
1562       auto status = wrap::miopenCreateOpBatchNormBackward(
1563           fusion_plan_, &batchnorm_op, miopenBNSpatial);
1564 
1565       if (status != miopenStatusSuccess) {
1566         LOG(FATAL) << "call to miopenCreateOpBatchNormBackward failed: "
1567                    << ToString(status);
1568       }
1569 
1570       miopenFusionOpDescriptor_t actv_op;
1571       status = wrap::miopenCreateOpActivationBackward(
1572           fusion_plan_, &actv_op,
1573           activation_descriptor.miopen_activation_mode_);
1574       if (status != miopenStatusSuccess) {
1575         LOG(FATAL) << "call to miopenCreateOpActivationBackward failed: "
1576                    << ToString(status);
1577       }
1578 
1579       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1580       if (status != miopenStatusSuccess) {
1581         VLOG(2) << "call to miopenCompileFusionPlan (BnA backward) failed: "
1582                 << ToString(status);
1583 
1584         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1585       } else {
1586         VLOG(2) << "Fusion Plan compile succeeded (BnA backward)";
1587         fusion_plan_compiled_ = true;
1588       }
1589     } else {
1590       // The fusion plan was already compiled; check whether that compilation failed.
1591       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1592     }
1593   }
1594 
1595   miopenStatus_t SetBatchNormBackwardArgs(const void* x, const void* scale,
1596                                           const void* offset,
1597                                           const void* saved_mean,
1598                                           const void* saved_var,
1599                                           void* scale_grad, void* offset_grad) {
1600     float alpha = 1.0;
1601     float beta = 0.0;
1602 
1603     return ScopedFusionPlanBase::SetBatchNormBackwardArgs(
1604         k_batchnorm_op_idx, &alpha, &beta, x, scale, offset, scale_grad,
1605         offset_grad, saved_mean, saved_var);
1606   }
1607 
1608   miopenStatus_t SetActivationBackwardArgs(
1609       ScopedActivationDescriptor& activation_descriptor, const void* y) {
1610     float alpha = 1.0;
1611     float beta = 0.0;
1612 
1613     return ScopedFusionPlanBase::SetActivationBackwardArgs(
1614         k_actv_op_idx, &alpha, &beta, y, activation_descriptor.alpha_,
1615         activation_descriptor.beta_, activation_descriptor.gamma_);
1616   }
1617 
1618   uint64 GetFusionOpHashValue(
1619       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1620       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1621       ScopedActivationDescriptor& activation_descriptor) {
1622     uint64 hash_value = tensorflow::Hash64("BatchNormActivationBackward");
1623 
1624     hash_value = tensorflow::Hash64Combine(
1625         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1626 
1627     hash_value =
1628         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1629 
1630     hash_value = tensorflow::Hash64Combine(
1631         hash_value, GetHashValue(scale_offset_mean_variance_descriptor));
1632 
1633     hash_value = tensorflow::Hash64Combine(
1634         hash_value, activation_descriptor.GetHashValue());
1635     return hash_value;
1636   }
1637 
1638  private:
1639   const int k_batchnorm_op_idx = 0;
1640   const int k_actv_op_idx = 1;
1641 
1642   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBatchNormActivationBackward);
1643 };
1644 
1645 namespace {
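// Helpers that map StreamExecutor dnn:: enums onto their MIOpen counterparts;
// unsupported values abort via LOG(FATAL). For example,
// ToMIOpenDataType(dnn::DataType::kFloat) yields miopenFloat.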
1646 miopenDataType_t ToMIOpenDataType(
1647     dnn::DataType data_type,
1648     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
1649   switch (data_type) {
1650     case dnn::DataType::kFloat:
1651       return miopenFloat;
1652     case dnn::DataType::kHalf:
1653       return miopenHalf;
1654     case dnn::DataType::kDouble:
1655     default:
1656       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
1657   }
1658 }
1659 
1660 miopenRNNInputMode_t ToMIOpenRnnInputMode(dnn::RnnInputMode input_mode) {
1661   switch (input_mode) {
1662     case dnn::RnnInputMode::kRnnLinearSkip:
1663       return miopenRNNlinear;
1664     case dnn::RnnInputMode::kRnnSkipInput:
1665       return miopenRNNskip;
1666     default:
1667       LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
1668   }
1669 }
1670 
1671 miopenRNNDirectionMode_t ToMIOpenRnnDirectionMode(
1672     dnn::RnnDirectionMode direction_mode) {
1673   switch (direction_mode) {
1674     case dnn::RnnDirectionMode::kRnnUnidirectional:
1675       return miopenRNNunidirection;
1676     case dnn::RnnDirectionMode::kRnnBidirectional:
1677       return miopenRNNbidirection;
1678     default:
1679       LOG(FATAL) << "Invalid RNN direction mode: "
1680                  << static_cast<int>(direction_mode);
1681   }
1682 }
1683 
1684 miopenRNNMode_t ToMIOpenRnnMode(dnn::RnnMode rnn_mode) {
1685   switch (rnn_mode) {
1686     case dnn::RnnMode::kRnnRelu:
1687       return miopenRNNRELU;
1688     case dnn::RnnMode::kRnnTanh:
1689       return miopenRNNTANH;
1690     case dnn::RnnMode::kRnnLstm:
1691       return miopenLSTM;
1692     case dnn::RnnMode::kRnnGru:
1693       return miopenGRU;
1694     default:
1695       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1696   }
1697 }
1698 
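// MixinBase lets MIOpenDescriptorCommon<Base> either derive from an existing
// dnn:: descriptor interface (Base != void) or stand alone (Base == void).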
1699 template <typename Base>
1700 class MixinBase : public Base {};
1701 template <>
1702 class MixinBase<void> {};
1703 
1704 }  // namespace
1705 
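// Logs the MIOpen error, records it on the enclosing descriptor via
// SetFailure(), and returns from the current (void-returning) function.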
1706 #define RETURN_IF_MIOPEN_ERROR(STATUS, ...)                              \
1707   if (!SE_PREDICT_TRUE((STATUS) == miopenStatusSuccess)) {               \
1708     string error_msg = absl::StrCat(ToString(STATUS), " ", __VA_ARGS__); \
1709     SetFailure(port::Status(port::error::UNKNOWN, error_msg));           \
1710     LOG(ERROR) << error_msg;                                             \
1711     return;                                                              \
1712   }
1713 
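// Common base for the MIOpen descriptor wrappers below: carries a
// port::Status that RETURN_IF_MIOPEN_ERROR updates and that is exposed
// through ok() / Status().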
1714 template <typename Base>
1715 class MIOpenDescriptorCommon : public MixinBase<Base> {
1716  public:
1717   bool ok() const { return status_.ok(); }
1718   port::Status Status() const { return status_; }
1719 
1720  protected:
1721   void SetFailure(const port::Status& status) { status_.Update(status); }
1722   port::Status status_;
1723 };
1724 
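// Wraps the MIOpen tensor descriptor for the packed RNN weight/bias blob,
// together with its total size and the per-region weight/bias layout.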
1725 class MIOpenRnnParamsDescriptor : public MIOpenDescriptorCommon<void> {
1726  public:
1727   typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion;
1728   typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1729   MIOpenRnnParamsDescriptor(miopenHandle_t miopen_handle,
1730                             const MIOpenRnnDescriptor& rnn_desc);
1731   ~MIOpenRnnParamsDescriptor() {
1732     auto status = wrap::miopenDestroyTensorDescriptor(handle_);
1733     RETURN_IF_MIOPEN_ERROR(status, "Failed to destroy RNN tensor descriptor");
1734   }
1735   miopenTensorDescriptor_t handle() const {
1736     if (!ok()) return nullptr;
1737     return handle_;
1738   }
1739   int64 params_size_in_bytes() const { return params_size_in_bytes_; }
1740   ParamsRegions params_weights() const {
1741     if (!ok()) return ParamsRegions();
1742     return weights_;
1743   }
1744   ParamsRegions params_biases() const {
1745     if (!ok()) return ParamsRegions();
1746     return biases_;
1747   }
1748 
1749  private:
1750   int GetRegionCountPerLayer() const;
1751   miopenTensorDescriptor_t handle_;
1752   const MIOpenRnnDescriptor* rnn_desc_;
1753   int64 params_size_in_bytes_;
1754   ParamsRegions weights_;
1755   ParamsRegions biases_;
1756   port::Status status_;
1757   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnParamsDescriptor);
1758 };
1759 
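// Owns a miopenRNNDescriptor_t plus the matching params descriptor, and
// implements the dnn::RnnDescriptor interface for the MIOpen backend.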
1760 class MIOpenRnnDescriptor : public MIOpenDescriptorCommon<dnn::RnnDescriptor> {
1761  public:
1762   MIOpenRnnDescriptor(miopenHandle_t miopen_handle, int num_layers,
1763                       int hidden_size, int input_size,
1764                       miopenRNNInputMode_t input_mode,
1765                       miopenRNNDirectionMode_t direction_mode,
1766                       miopenRNNMode_t rnn_mode, miopenDataType_t data_type,
1767                       float dropout, uint64 seed,
1768                       ScratchAllocator* state_allocator)
1769       : rnn_desc_(nullptr),
1770         num_layers_(num_layers),
1771         hidden_size_(hidden_size),
1772         input_size_(input_size),
1773         input_mode_(input_mode),
1774         direction_mode_(direction_mode),
1775         rnn_mode_(rnn_mode),
1776         data_type_(data_type) {
1777     // Create the RNN handle
1778     auto status = wrap::miopenCreateRNNDescriptor(&rnn_desc_);
1779     RETURN_IF_MIOPEN_ERROR(status, "Unable to create RNN descriptor");
1780     status = wrap::miopenSetRNNDescriptor(
1781         rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
1782         num_layers /*numLayers*/, input_mode /*inputMode*/,
1783         direction_mode /*direction*/, rnn_mode /*mode*/,
1784         miopenRNNwithBias /*biasMode*/, miopenRNNdefault /*algo*/,
1785         data_type /*dataType*/);
1786     RETURN_IF_MIOPEN_ERROR(status, "Unable to update RNN descriptor");
1787     // Create the params handle.
1788     miopen_params_desc_.reset(
1789         new MIOpenRnnParamsDescriptor(miopen_handle, *this));
1790     if (!miopen_params_desc_->ok()) {
1791       SetFailure(miopen_params_desc_->Status());
1792       return;
1793     }
1794   }
1795   ~MIOpenRnnDescriptor() override {
1796     if (rnn_desc_) {
1797       auto status = wrap::miopenDestroyRNNDescriptor(rnn_desc_);
1798       RETURN_IF_MIOPEN_ERROR(status, "Unable to destroy RNN descriptor");
1799     }
1800   }
1801   miopenRNNDescriptor_t handle() const {
1802     if (!ok()) return nullptr;
1803     return rnn_desc_;
1804   }
1805   int num_layers() const { return num_layers_; }
1806   int hidden_size() const { return hidden_size_; }
1807   int input_size() const { return input_size_; }
1808   miopenRNNInputMode_t input_mode() const { return input_mode_; }
1809   miopenRNNDirectionMode_t direction_mode() const { return direction_mode_; }
1810   miopenRNNMode_t rnn_mode() const { return rnn_mode_; }
1811   miopenDataType_t data_type() const { return data_type_; }
1812   int64 ParamsSizeInBytes() const override {
1813     return miopen_params_desc_->params_size_in_bytes();
1814   }
1815   miopenTensorDescriptor_t params_handle() const {
1816     if (!miopen_params_desc_) return nullptr;
1817     return miopen_params_desc_->handle();
1818   }
1819   ParamsRegions ParamsWeightRegions() const override {
1820     if (!ok()) return ParamsRegions();
1821     return miopen_params_desc_->params_weights();
1822   }
1823   ParamsRegions ParamsBiasRegions() const override {
1824     if (!ok()) return ParamsRegions();
1825     return miopen_params_desc_->params_biases();
1826   }
1827 
1828  private:
1829   miopenRNNDescriptor_t rnn_desc_;
1830   int num_layers_;
1831   int hidden_size_;
1832   int input_size_;
1833   miopenRNNInputMode_t input_mode_;
1834   miopenRNNDirectionMode_t direction_mode_;
1835   miopenRNNMode_t rnn_mode_;
1836   miopenDataType_t data_type_;
1837   port::Status status_;
1838   // no dropout in MIOpen.
1839   // std::unique_ptr<miopenDropoutDescriptor> miopen_dropout_desc_;
1840   std::unique_ptr<MIOpenRnnParamsDescriptor> miopen_params_desc_;
1841   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnDescriptor);
1842 };
1843 
1844 // Returns the number of weight/bias parameter regions per layer for the
1845 // RNN mode used by this descriptor.
1846 int MIOpenRnnParamsDescriptor::GetRegionCountPerLayer() const {
1847   auto rnn_mode = rnn_desc_->rnn_mode();
1848   switch (rnn_mode) {
1849     case miopenRNNRELU:
1850     case miopenRNNTANH:
1851       return 2;
1852     case miopenLSTM:
1853       return 8;
1854     case miopenGRU:
1855       return 6;
1856     default:
1857       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1858   }
1859 }
1860 
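// Describes the per-timestep input/output of an RNN: a single
// {batch_size, data_size} tensor descriptor replicated seq_length times.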
1861 class MIOpenRnnSequenceTensorDescriptor
1862     : public MIOpenDescriptorCommon<dnn::RnnSequenceTensorDescriptor> {
1863  public:
1864   MIOpenRnnSequenceTensorDescriptor(int seq_length, int batch_size,
1865                                     int data_size, miopenDataType_t data_type)
1866       : seq_length_(seq_length),
1867         batch_size_(batch_size),
1868         data_size_(data_size),
1869         data_type_(data_type) {
1870     miopenTensorDescriptor_t handle = nullptr;
1871     if (seq_length <= 0) {
1872       string error_msg =
1873           absl::StrCat("sequence length must be positive: ", seq_length);
1874       LOG(ERROR) << error_msg;
1875       SetFailure(port::Status(port::error::UNKNOWN, error_msg));
1876       return;
1877     }
1878     auto status = wrap::miopenCreateTensorDescriptor(&handle);
1879     RETURN_IF_MIOPEN_ERROR(status, "Failed to create tensor descriptor");
1880     std::array<int, 2> dims = {{batch_size, data_size}};
1881     status = wrap::miopenSetTensorDescriptor(
1882         handle /*tensorDesc*/, data_type /*dataType*/, 2 /*nbDims*/,
1883         dims.data() /*dimA*/, nullptr /*strideA*/);
1884     RETURN_IF_MIOPEN_ERROR(status, "Failed to update tensor descriptor");
1885     // Replicate handle across the number of steps.
1886     handles_.assign(seq_length, handle);
1887   }
1888 
1889   ~MIOpenRnnSequenceTensorDescriptor() override {
1890     // Only the first one needs to be destroyed. All others are the same.
1891     auto status = wrap::miopenDestroyTensorDescriptor(handles_[0]);
1892     RETURN_IF_MIOPEN_ERROR(status,
1893                            "Failed to destroy sequence tensor descriptor");
1894   }
1895 
1896   const miopenTensorDescriptor_t* handles() const {
1897     if (!ok()) return nullptr;
1898     CHECK(!handles_.empty()) << "handles cannot be empty";
1899     return handles_.data();
1900   }
1901 
1902   int seq_length() const { return seq_length_; }
1903   int batch_size() const { return batch_size_; }
1904   int data_size() const { return data_size_; }
1905 
1906  private:
1907   int seq_length_;
1908   int batch_size_;
1909   int data_size_;
1910   miopenDataType_t data_type_;
1911   std::vector<miopenTensorDescriptor_t> handles_;
1912   port::Status status_;
1913   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnSequenceTensorDescriptor);
1914 };
1915 
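// Describes an RNN hidden/cell state tensor of shape
// {num_layers, batch_size, data_size}.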
1916 class MIOpenRnnStateTensorDescriptor
1917     : public MIOpenDescriptorCommon<dnn::RnnStateTensorDescriptor> {
1918  public:
1919   MIOpenRnnStateTensorDescriptor(int num_layers, int batch_size, int data_size,
1920                                  miopenDataType_t data_type)
1921       : handle_(nullptr),
1922         num_layers_(num_layers),
1923         batch_size_(batch_size),
1924         data_size_(data_size),
1925         data_type_(data_type) {
1926     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
1927     RETURN_IF_MIOPEN_ERROR(status, "Failed to create tensor descriptor");
1928     std::array<int, 3> dims = {{num_layers, batch_size, data_size}};
1929     status = wrap::miopenSetTensorDescriptor(
1930         handle_ /*tensorDesc*/, data_type /*dataType*/, 3 /*nbDims*/,
1931         dims.data() /*dimA*/, nullptr /*strideA*/);
1932     RETURN_IF_MIOPEN_ERROR(status, "Failed to update tensor descriptor");
1933   }
1934 
1935   ~MIOpenRnnStateTensorDescriptor() override {
1936     if (handle_) {
1937       auto status = wrap::miopenDestroyTensorDescriptor(handle_);
1938       RETURN_IF_MIOPEN_ERROR(status, "Unable to destroy RNN state tensor");
1939     }
1940   }
1941 
1942   miopenTensorDescriptor_t handle() const {
1943     if (!ok()) return nullptr;
1944     return handle_;
1945   }
1946   int num_layers() const { return num_layers_; }
1947   int batch_size() const { return batch_size_; }
1948   int data_size() const { return data_size_; }
1949 
1950  private:
1951   miopenTensorDescriptor_t handle_;
1952   int num_layers_;
1953   int batch_size_;
1954   int data_size_;
1955   port::Status status_;
1956   miopenDataType_t data_type_;
1957   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnStateTensorDescriptor);
1958 };
1959 
1960 namespace {
1961 
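// Model dimensions extracted from the RNN descriptors; used for shape
// consistency checks and for sizing workspace/reserve-space queries.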
1962 struct RnnModelDims {
1963   int num_layers = 0;
1964   int batch_size = 0;
1965   int seq_length = 0;
1966   int hidden_size = 0;
1967   int input_size = 0;
1968   int dir_count = 0;
1969 };
1970 
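// Extracts the RNN model dimensions into *model_dims and verifies that the
// input/output/state descriptors are mutually consistent. Logs and returns
// false on any mismatch.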
1971 template <class T>
1972 bool ExtractAndCheckRnnForward(
1973     const MIOpenRnnDescriptor& rnn_desc,
1974     const MIOpenRnnSequenceTensorDescriptor& input_desc,
1975     const DeviceMemory<T>& input_data,
1976     const MIOpenRnnStateTensorDescriptor& input_h_desc,
1977     const DeviceMemory<T>& input_h_data,
1978     const MIOpenRnnStateTensorDescriptor& input_c_desc,
1979     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1980     const MIOpenRnnSequenceTensorDescriptor& output_desc,
1981     const DeviceMemory<T>& output_data,
1982     const MIOpenRnnStateTensorDescriptor& output_h_desc,
1983     const DeviceMemory<T>& output_h_data,
1984     const MIOpenRnnStateTensorDescriptor& output_c_desc,
1985     const DeviceMemory<T>& output_c_data, RnnModelDims* model_dims) {
1986   // extract model parameters
1987   model_dims->num_layers = rnn_desc.num_layers();
1988   model_dims->batch_size = input_desc.batch_size();
1989   model_dims->seq_length = input_desc.seq_length();
1990   model_dims->hidden_size = rnn_desc.hidden_size();
1991   model_dims->input_size = input_desc.data_size();
1992   model_dims->dir_count =
1993       (rnn_desc.direction_mode() == miopenRNNbidirection) ? 2 : 1;
1994 
1995   // check parameters
1996   if (!(input_h_desc.num_layers() ==
1997             model_dims->num_layers * model_dims->dir_count &&
1998         input_h_desc.batch_size() == model_dims->batch_size &&
1999         input_h_desc.data_size() == model_dims->hidden_size)) {
2000     LOG(ERROR) << "Invalid input_h shape";
2001     return false;
2002   }
2003   if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
2004         input_h_desc.batch_size() == input_c_desc.batch_size() &&
2005         input_h_desc.data_size() == input_c_desc.data_size())) {
2006     LOG(ERROR) << "Invalid input_c shape";
2007     return false;
2008   }
2009   if (!(output_desc.seq_length() == model_dims->seq_length &&
2010         output_desc.batch_size() == model_dims->batch_size &&
2011         output_desc.data_size() ==
2012             model_dims->hidden_size * model_dims->dir_count)) {
2013     LOG(ERROR) << "Invalid output shape";
2014     return false;
2015   }
2016   if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
2017         input_h_desc.batch_size() == output_h_desc.batch_size() &&
2018         input_h_desc.data_size() == output_h_desc.data_size())) {
2019     LOG(ERROR) << "Invalid output_h shape";
2020     return false;
2021   }
2022   if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
2023         input_h_desc.batch_size() == output_c_desc.batch_size() &&
2024         input_h_desc.data_size() == output_c_desc.data_size())) {
2025     LOG(ERROR) << "Invalid output_c shape";
2026     return false;
2027   }
2028 
2029   return true;
2030 }
2031 
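// Checks that the parameter size MIOpen reports for this RNN matches the
// size recorded in the MIOpenRnnDescriptor.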
2032 bool CheckRNNParameterSize(
2033     miopenHandle_t miopen_handle, const MIOpenRnnDescriptor& rnn_desc,
2034     const MIOpenRnnSequenceTensorDescriptor& input_desc) {
2035   size_t params_size_in_bytes = 0;
2036   auto status = wrap::miopenGetRNNParamsSize(
2037       miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2038       input_desc.handles()[0] /*xDesc*/, &params_size_in_bytes /*sizeInBytes*/,
2039       rnn_desc.data_type() /*dataType*/);
2040   if (status != miopenStatusSuccess) {
2041     LOG(ERROR) << "Unable to check RNN param size: " << ToString(status);
2042     return false;
2043   }
2044   return static_cast<int64>(params_size_in_bytes) ==
2045          rnn_desc.ParamsSizeInBytes();
2046 }
2047 
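// Queries the workspace size needed for the RNN call, allocates it from
// workspace_allocator, and zero-initializes it; an empty DeviceMemory is
// returned when no workspace is required.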
2048 bool CreateRnnWorkspace(Stream* stream, miopenHandle_t miopen_handle,
2049                         const MIOpenRnnDescriptor& rnn_desc,
2050                         const MIOpenRnnSequenceTensorDescriptor& input_desc,
2051                         ScratchAllocator* workspace_allocator,
2052                         DeviceMemory<uint8>* workspace) {
2053   // Query the workspace size.
2054   size_t workspace_size_in_bytes = 0;
2055   auto status = wrap::miopenGetRNNWorkspaceSize(
2056       miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2057       input_desc.seq_length() /*seqLength*/, input_desc.handles() /*xDesc*/,
2058       &workspace_size_in_bytes /*sizeInBytes*/);
2059   if (status != miopenStatusSuccess) {
2060     LOG(ERROR) << "Unable to query workspace size: " << ToString(status);
2061     return false;
2062   }
2063   // Allocate the workspace.
2064   if (workspace_size_in_bytes > 0) {
2065     auto allocated =
2066         workspace_allocator->AllocateBytes(workspace_size_in_bytes);
2067     if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) {
2068       LOG(ERROR) << "Failed to allocate RNN workspace";
2069 
2070       return false;
2071     }
2072     stream->ThenMemZero(workspace, workspace_size_in_bytes);
2073   } else {
2074     *workspace = DeviceMemory<uint8>();
2075   }
2076   return true;
2077 }
2078 
2079 }  // namespace
2080 
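// Shared implementation of the RNN forward pass: validates shapes, sets up
// the workspace (and, when training, the reserve space), then calls either
// miopenRNNForwardInference or miopenRNNForwardTraining.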
2081 template <class T>
2082 bool MIOpenSupport::DoRnnForwardImpl(
2083     Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
2084     const MIOpenRnnSequenceTensorDescriptor& input_desc,
2085     const DeviceMemory<T>& input_data,
2086     const MIOpenRnnStateTensorDescriptor& input_h_desc,
2087     const DeviceMemory<T>& input_h_data,
2088     const MIOpenRnnStateTensorDescriptor& input_c_desc,
2089     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2090     const MIOpenRnnSequenceTensorDescriptor& output_desc,
2091     DeviceMemory<T>* output_data,
2092     const MIOpenRnnStateTensorDescriptor& output_h_desc,
2093     DeviceMemory<T>* output_h_data,
2094     const MIOpenRnnStateTensorDescriptor& output_c_desc,
2095     DeviceMemory<T>* output_c_data, bool is_training,
2096     ScratchAllocator* reserve_space_allocator,
2097     ScratchAllocator* workspace_allocator) {
2098   // extract model parameters
2099   RnnModelDims model_dims;
2100   bool res = ExtractAndCheckRnnForward(
2101       rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
2102       input_c_desc, input_c_data, params, output_desc, *output_data,
2103       output_h_desc, *output_h_data, output_c_desc, *output_c_data,
2104       &model_dims);
2105   if (!res) {
2106     LOG(ERROR) << "Invalid parameters for RNN Model";
2107     return false;
2108   }
2109 
2110   auto miopen = miopen_->GetHandle(parent_, stream);
2111 
2112   // check params size
2113 
2114   if (!CheckRNNParameterSize(miopen.handle(), rnn_desc, input_desc)) {
2115     LOG(ERROR) << "Invalid parameters";
2116     return false;
2117   }
2118 
2119   // create the workspace
2120   DeviceMemory<uint8> workspace;
2121   if (!CreateRnnWorkspace(stream, miopen.handle(), rnn_desc, input_desc,
2122                           workspace_allocator, &workspace)) {
2123     LOG(ERROR) << "Unable to create rnn workspace";
2124 
2125     return false;
2126   }
2127 
2128   // query the reserve space size
2129   // allocate the reserve space
2130   DeviceMemory<uint8> reserve_space;
2131   if (is_training) {
2132     size_t reserve_space_size_in_bytes = 0;
2133     auto status = wrap::miopenGetRNNTrainingReserveSize(
2134         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2135         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2136         &reserve_space_size_in_bytes /*sizeInBytes*/);
2137     if (status != miopenStatusSuccess) {
2138       LOG(ERROR) << "Unable to query reserve space size: " << ToString(status);
2139       return false;
2140     }
2141 
2142     if (reserve_space_size_in_bytes > 0) {
2143       auto allocated =
2144           reserve_space_allocator->AllocateBytes(reserve_space_size_in_bytes);
2145       if (!allocated.ok() ||
2146           (reserve_space = allocated.ValueOrDie()) == nullptr) {
2147         LOG(ERROR) << "Failed to allocate RNN reserve space";
2148         return false;
2149       }
2150       stream->ThenMemZero(&reserve_space, reserve_space_size_in_bytes);
2151     }
2152   }
2153 
2154   // make the forward call
2155   if (!is_training) {
2156     auto status = wrap::miopenRNNForwardInference(
2157         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2158         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2159         input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
2160         input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/,
2161         input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/,
2162         params.opaque() /*w*/, output_desc.handles() /*yDesc*/,
2163         output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/,
2164         output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/,
2165         output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/,
2166         workspace.size() /*workSpaceSizeInBytes*/);
2167 
2168     if (status != miopenStatusSuccess) {
2169       LOG(ERROR) << "Failed to call miopenRNNForwardInference: "
2170                  << ToString(status);
2171       return false;
2172     }
2173   } else {
2174     auto status = wrap::miopenRNNForwardTraining(
2175         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2176         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2177         input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
2178         input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/,
2179         input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/,
2180         params.opaque() /*w*/, output_desc.handles() /*yDesc*/,
2181         output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/,
2182         output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/,
2183         output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/,
2184         workspace.size() /*workSpaceSizeInBytes*/,
2185         reserve_space.opaque() /*reserveSpace*/,
2186         reserve_space.size() /*reserveSpaceSizeInBytes*/);
2187     if (status != miopenStatusSuccess) {
2188       LOG(ERROR) << "Failed to call miopenRNNForwardTraining: "
2189                  << ToString(status);
2190       return false;
2191     }
2192   }
2193   return true;
2194 }
2195 
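// Backward-pass counterpart of DoRnnForwardImpl: re-validates shapes, zeroes
// the input gradient buffers (a workaround for missing initialization support
// in MIOpen), then calls miopenRNNBackwardData and, when parameter gradients
// are requested, miopenRNNBackwardWeights.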
2196 template <class T>
2197 bool MIOpenSupport::DoRnnBackwardImpl(
2198     Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
2199     const MIOpenRnnSequenceTensorDescriptor& input_desc,
2200     const DeviceMemory<T>& input_data,
2201     const MIOpenRnnStateTensorDescriptor& input_h_desc,
2202     const DeviceMemory<T>& input_h_data,
2203     const MIOpenRnnStateTensorDescriptor& input_c_desc,
2204     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2205     const MIOpenRnnSequenceTensorDescriptor& output_desc,
2206     const DeviceMemory<T>& output_data,
2207     const MIOpenRnnStateTensorDescriptor& output_h_desc,
2208     const DeviceMemory<T>& output_h_data,
2209     const MIOpenRnnStateTensorDescriptor& output_c_desc,
2210     const DeviceMemory<T>& output_c_data,
2211     const DeviceMemory<T>& output_backprop_data,
2212     const DeviceMemory<T>& output_h_backprop_data,
2213     const DeviceMemory<T>& output_c_backprop_data,
2214     DeviceMemory<T>* input_backprop_data,
2215     DeviceMemory<T>* input_h_backprop_data,
2216     DeviceMemory<T>* input_c_backprop_data,
2217     DeviceMemory<T>* params_backprop_data,
2218     DeviceMemory<uint8>* reserve_space_data,
2219     ScratchAllocator* workspace_allocator) {
2220   // extract model parameters
2221   RnnModelDims model_dims;
2222   bool res = ExtractAndCheckRnnForward(
2223       rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
2224       input_c_desc, input_c_data, params, output_desc, output_data,
2225       output_h_desc, output_h_data, output_c_desc, output_c_data, &model_dims);
2226   if (!res) {
2227     LOG(ERROR) << "Invalid parameters for RNN Model";
2228     return false;
2229   }
2230 
2231   auto miopen = miopen_->GetHandle(parent_, stream);
2232 
2233   // check params size
2234 
2235   if (!CheckRNNParameterSize(miopen.handle(), rnn_desc, input_desc)) {
2236     LOG(ERROR) << "Invalid parameters";
2237     return false;
2238   }
2239 
2240   // create the workspace
2241   DeviceMemory<uint8> workspace;
2242   if (!CreateRnnWorkspace(stream, miopen.handle(), rnn_desc, input_desc,
2243                           workspace_allocator, &workspace)) {
2244     LOG(ERROR) << "Unable to create rnn workspace";
2245     return false;
2246   }
2247 
2248   // workaround for missing initialization support in MIOpen.
2249   // TODO: remove this when MIOpen is ready.
2250   auto type_size = std::is_same<T, Eigen::half>::value ? 2 : sizeof(T);
2251   auto size_data = input_desc.seq_length() * input_desc.batch_size() *
2252                    input_desc.data_size();
2253   if ((size_data > 0) && (input_backprop_data->opaque() != nullptr))
2254     stream->ThenMemZero(input_backprop_data, size_data * type_size);
2255 
2256   size_data = input_h_desc.num_layers() * input_h_desc.batch_size() *
2257               input_h_desc.data_size();
2258   if ((size_data > 0) && (input_h_backprop_data->opaque() != nullptr))
2259     stream->ThenMemZero(input_h_backprop_data, size_data * type_size);
2260 
2261   size_data = input_c_desc.num_layers() * input_c_desc.batch_size() *
2262               input_c_desc.data_size();
2263   if ((size_data > 0) && (input_c_backprop_data->opaque() != nullptr))
2264     stream->ThenMemZero(input_c_backprop_data, size_data * type_size);
2265 
2266   // make the backward data call
2267   auto status = wrap::miopenRNNBackwardData(
2268       miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2269       model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/,
2270       output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/,
2271       output_backprop_data.opaque() /*dy*/, output_h_desc.handle() /*dhyDesc*/,
2272       output_h_backprop_data.opaque() /*dhy*/,
2273       output_c_desc.handle() /*dcyDesc*/,
2274       output_c_backprop_data.opaque() /*dcy*/,
2275       rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
2276       input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
2277       input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
2278       input_desc.handles() /*dxDesc*/, input_backprop_data->opaque() /*dx*/,
2279       input_h_desc.handle() /*dhxDesc*/,
2280       input_h_backprop_data->opaque() /*dhx*/,
2281       input_c_desc.handle() /*dcxDesc*/,
2282       input_c_backprop_data->opaque() /*dcx*/, workspace.opaque() /*workspace*/,
2283       workspace.size() /*workSpaceSizeInBytes*/,
2284       reserve_space_data->opaque() /*reserveSpace*/,
2285       reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
2286   if (status != miopenStatusSuccess) {
2287     LOG(ERROR) << "Failed to call miopenRNNBackwardData: " << ToString(status);
2288     return false;
2289   }
2290 
2291   if (params_backprop_data != nullptr) {
2292     // Clear the dw to zeros.
2293     stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
2294     // make the backward weight call
2295     status = wrap::miopenRNNBackwardWeights(
2296         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2297         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2298         input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
2299         input_h_data.opaque() /*hx*/, output_desc.handles() /*yDesc*/,
2300         output_data.opaque() /*y*/, rnn_desc.params_handle() /*dwDesc*/,
2301         params_backprop_data->opaque() /*dw*/, workspace.opaque() /*workspace*/,
2302         workspace.size() /*workSpaceSizeInBytes*/,
2303         reserve_space_data->opaque() /*reserveSpace*/,
2304         reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
2305     if (status != miopenStatusSuccess) {
2306       LOG(ERROR) << "Failed to call miopenRNNBackwardWeights: "
2307                  << ToString(status);
2308       return false;
2309     }
2310   }
2311 
2312   return true;
2313 }
2314 
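// Builds the packed RNN parameter descriptor: queries the parameter size for
// a dummy {1, input_size} input tensor and then asks MIOpen for the matching
// params tensor descriptor.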
2315 MIOpenRnnParamsDescriptor::MIOpenRnnParamsDescriptor(
2316     miopenHandle_t miopen_handle, const MIOpenRnnDescriptor& rnn_desc)
2317     : handle_(nullptr), rnn_desc_(&rnn_desc), params_size_in_bytes_(0) {
2318   miopenTensorDescriptor_t input_desc = nullptr;
2319   {
2320     // Query the params size.
2321     auto status = wrap::miopenCreateTensorDescriptor(&input_desc);
2322     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to create tensor descriptor");
2323     std::array<int, 2> dims = {{1, rnn_desc.input_size()}};
2324     status = wrap::miopenSetTensorDescriptor(
2325         input_desc /*tensorDesc*/, rnn_desc.data_type() /*dataType*/,
2326         2 /*nbDims*/, dims.data() /*dimA*/, nullptr /*strideA*/);
2327     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to set tensor descriptor");
2328 
2329     size_t params_size = 0;
2330     status = wrap::miopenGetRNNParamsSize(
2331         miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2332         input_desc /*xDesc*/, &params_size /*sizeInBytes*/,
2333         rnn_desc.data_type() /*dataType*/);
2334     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to get RNN parameter size");
2335     params_size_in_bytes_ = static_cast<int64>(params_size);
2336   }
2337 
2338   {
2339     // Create the params descriptor.
2340     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
2341     RETURN_IF_MIOPEN_ERROR(status,
2342                            "MIOpen fails to create RNN params descriptor");
2343     status = wrap::miopenGetRNNParamsDescriptor(miopen_handle,
2344                                                 rnn_desc.handle(), input_desc,
2345                                                 handle_, rnn_desc.data_type());
2346     RETURN_IF_MIOPEN_ERROR(status,
2347                            "MIOpen fails to update RNN filter descriptor");
2348   }
2349   {
2350     // Release the dummy input tensor descriptor.
2351     auto status = wrap::miopenDestroyTensorDescriptor(input_desc);
2352     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to destroy tensor descriptor");
2353   }
2354 }
2355 
2356 class MIOpenCTCLossDescriptor {
2357  public:
2358   explicit MIOpenCTCLossDescriptor(miopenDataType_t data_type) {
2359     auto status = wrap::miopenCreateCTCLossDescriptor(&handle_);
2360     if (status != miopenStatusSuccess) {
2361       LOG(FATAL) << "call to miopenCreateCTCLossDescriptor failed: "
2362                  << ToString(status);
2363     }
2364 
2365     bool apply_softmax_layer = true;
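    // Note (based on the MIOpen API, treat as an assumption): the third
    // argument below is the index of the blank label (0 here), and
    // apply_softmax_layer asks MIOpen to apply softmax to the activations
    // internally before computing the loss.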
2366     status = wrap::miopenSetCTCLossDescriptor(handle_, data_type, 0,
2367                                               apply_softmax_layer);
2368     if (status != miopenStatusSuccess) {
2369       LOG(FATAL) << "call to miopenSetCTCLossDescriptor failed: "
2370                  << ToString(status);
2371     }
2372   }
2373 
2374   ~MIOpenCTCLossDescriptor() {
2375     auto status = wrap::miopenDestroyCTCLossDescriptor(handle_);
2376     if (status != miopenStatusSuccess) {
2377       LOG(FATAL) << "call to miopenDestroyCTCLossDescriptor failed: "
2378                  << ToString(status);
2379     }
2380   }
2381 
2382   miopenCTCLossDescriptor_t handle() const { return handle_; }
2383 
2384  private:
2385   miopenCTCLossDescriptor_t handle_;  // Owned
2386 
2387   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenCTCLossDescriptor);
2388 };
2389 
2390 port::Status MIOpenSupport::DoPrepareForCtcLoss(
2391     Stream* stream, dnn::DataType element_type,
2392     const dnn::RnnStateTensorDescriptor& probs_desc,
2393     const dnn::RnnStateTensorDescriptor& grads_desc,
2394     absl::Span<const int> labels_data,
2395     absl::Span<const int> labels_lengths_data,
2396     absl::Span<const int> input_lengths_data,
2397     ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
2398     int* ctc_loss_algo_id) {
2399   auto miopen = miopen_->GetHandle(parent_, stream);
2400 
2401   MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
2402 
2403   // Query the workspace size.
2404   size_t workspace_size_in_bytes = 0;
2405 
2406   const MIOpenRnnStateTensorDescriptor& miopen_probs_desc =
2407       static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc);
2408 
2409   const MIOpenRnnStateTensorDescriptor& miopen_grads_desc =
2410       static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc);
2411 
2412   auto status = wrap::miopenGetCTCLossWorkspaceSize(
2413       miopen.handle(), miopen_probs_desc.handle(), miopen_grads_desc.handle(),
2414       labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(),
2415       MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, miopen_ctc_loss_desc.handle(),
2416       &workspace_size_in_bytes);
2417 
2418   if (status != miopenStatusSuccess) {
2419     LOG(FATAL) << "call to miopenGetCTCLossWorkspaceSize failed: "
2420                << ToString(status);
2421     return port::InternalError(
2422         "Failed to determine scratch memory size for MIOpen CTC Loss");
2423   }
2424 
2425   *scratch_memory = DeviceMemory<uint8>();
2426 
2427   // Allocate the workspace.
2428   if (workspace_size_in_bytes != 0) {
2429     if (scratch_allocator == nullptr) {
2430       return port::InternalError(
2431           absl::StrCat("An allocator must be specified when scratch memory is "
2432                        "needed"));
2433     }
2434     auto scratch_or = scratch_allocator->AllocateBytes(workspace_size_in_bytes);
2435     if (scratch_or.ok()) {
2436       *scratch_memory = scratch_or.ValueOrDie();
2437     } else {
2438       LOG(ERROR)
2439           << "Failed to allocate scratch memory - "
2440           << scratch_or.status().error_message() << "\n"
2441           << "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
2442              "larger number (e.g. 8192) to increase the max memory limit.\n"
2443           << "\tIncreasing the max memory limit might help resolve this "
2444              "error";
2445       return port::InternalError(absl::StrCat(
2446           "Failed to allocate scratch memory for MIOpen CTC Loss, of size: ",
2447           workspace_size_in_bytes));
2448     }
2449   }
2450 
2451   return port::Status::OK();
2452 }
2453 
2454 port::Status MIOpenSupport::DoCtcLossImpl(
2455     Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
2456     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2457     absl::Span<const int> labels_lengths_data,
2458     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2459     const MIOpenRnnStateTensorDescriptor& grads_desc,
2460     DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
2461     DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
2462   auto miopen = miopen_->GetHandle(parent_, stream);
2463 
2464   int kNumTimestamps = probs_desc.num_layers();
2465   int kBatchSize = probs_desc.batch_size();
2466   int kNumLabels = probs_desc.data_size();
2467   int total_size = kNumLabels * kNumTimestamps * kBatchSize;
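  // The cast below only silences an unused-variable warning; total_size is
  // computed purely as a potential debugging aid.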
2468   (void)total_size;
2469 
2470   auto status = wrap::miopenCTCLoss(
2471       miopen.handle(), probs_desc.handle(), probs_data.opaque(),
2472       labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(),
2473       costs_data.opaque(), grads_desc.handle(), grads_data.opaque(),
2474       MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, ctc_loss_desc.handle(),
2475       scratch_memory.opaque(), scratch_memory.size());
2476   if (status != miopenStatusSuccess) {
2477     LOG(FATAL) << "call to miopenCTCLoss failed: " << ToString(status);
2478     return port::InternalError("Failure during MIOpen CTC Loss");
2479   }
2480 
2481   return port::Status::OK();
2482 }
2483 
2484 port::Status MIOpenSupport::DoCtcLoss(
2485     Stream* stream, dnn::DataType element_type,
2486     const dnn::RnnStateTensorDescriptor& probs_desc,
2487     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2488     absl::Span<const int> labels_lengths_data,
2489     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2490     const dnn::RnnStateTensorDescriptor& grads_desc,
2491     DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
2492     int ctc_loss_algo_id) {
2493   // The current MIOpen CTC Loss implementation only supports float.
2494   if (element_type != dnn::DataType::kFloat) {
2495     return port::Status(port::error::INVALID_ARGUMENT,
2496                         "MIOpenCTCLossDescriptor is supported only when the "
2497                         "DataType is float");
2498   }
2499 
2500   MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
2501 
2502   const MIOpenRnnStateTensorDescriptor& miopen_probs_desc =
2503       static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc);
2504 
2505   const MIOpenRnnStateTensorDescriptor& miopen_grads_desc =
2506       static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc);
2507 
2508   return DoCtcLossImpl(stream, miopen_probs_desc, probs_data, labels_data,
2509                        labels_lengths_data, input_lengths_data, costs_data,
2510                        miopen_grads_desc, grads_data, miopen_ctc_loss_desc,
2511                        scratch_memory, ctc_loss_algo_id);
2512 }
2513 
2514 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2515 MIOpenSupport::createRnnDescriptor(
2516     int num_layers, int hidden_size, int input_size, int cell_size,
2517     int batch_size, dnn::RnnInputMode input_mode,
2518     dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
2519     dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
2520     float dropout, uint64 seed, ScratchAllocator* state_allocator,
2521     bool use_padded_io) {
2522   // ROCM TODO: batch_size is used by the dynamic persistent RNN algorithm,
2523   // which is not supported by MIOpen yet.
2524   if (use_padded_io) {
2525     return port::Status(port::error::INVALID_ARGUMENT,
2526                         "ROCm MIOpen only supports packed input output.");
2527   }
2528 
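  // A projection (LSTMP) layer is being requested when a non-zero cell_size
  // exceeds hidden_size; MIOpen has no equivalent, so reject it below.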
2529   bool use_projection = cell_size != 0 && hidden_size < cell_size;
2530   if (use_projection) {
2531     return port::Status(
2532         port::error::INVALID_ARGUMENT,
2533         "ROCm MIOpen does not support RNN ProjectionLayers yet.");
2534   }
2535 
2536   auto miopen = miopen_->GetHandle(parent_, nullptr);
2537   std::unique_ptr<MIOpenRnnDescriptor> rnn_desc(new MIOpenRnnDescriptor(
2538       miopen.handle(), num_layers, hidden_size, input_size,
2539       ToMIOpenRnnInputMode(input_mode),
2540       ToMIOpenRnnDirectionMode(direction_mode), ToMIOpenRnnMode(rnn_mode),
2541       ToMIOpenDataType(data_type), dropout, seed, state_allocator));
2542   if (!rnn_desc->ok()) {
2543     return rnn_desc->Status();
2544   }
2545   return port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>(
2546       std::move(rnn_desc));
2547 }
2548 
2549 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2550 MIOpenSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
2551                                                  int data_size,
2552                                                  dnn::DataType data_type) {
2553   std::unique_ptr<MIOpenRnnSequenceTensorDescriptor> seq_desc(
2554       new MIOpenRnnSequenceTensorDescriptor(seq_length, batch_size, data_size,
2555                                             ToMIOpenDataType(data_type)));
2556   if (!seq_desc->ok()) {
2557     return seq_desc->Status();
2558   }
2559   return port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>(
2560       std::move(seq_desc));
2561 }
2562 
2563 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2564 MIOpenSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
2565                                               int data_size,
2566                                               dnn::DataType data_type) {
2567   std::unique_ptr<MIOpenRnnStateTensorDescriptor> state_desc(
2568       new MIOpenRnnStateTensorDescriptor(num_layer, batch_size, data_size,
2569                                          ToMIOpenDataType(data_type)));
2570   if (!state_desc->ok()) {
2571     return state_desc->Status();
2572   }
2573   return port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>(
2574       std::move(state_desc));
2575 }
2576 
2577 bool MIOpenSupport::DoRnnForward(
2578     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2579     const dnn::RnnSequenceTensorDescriptor& input_desc,
2580     const DeviceMemory<Eigen::half>& input_data,
2581     const dnn::RnnStateTensorDescriptor& input_h_desc,
2582     const DeviceMemory<Eigen::half>& input_h_data,
2583     const dnn::RnnStateTensorDescriptor& input_c_desc,
2584     const DeviceMemory<Eigen::half>& input_c_data,
2585     const DeviceMemory<Eigen::half>& params,
2586     const dnn::RnnSequenceTensorDescriptor& output_desc,
2587     DeviceMemory<Eigen::half>* output_data,
2588     const dnn::RnnStateTensorDescriptor& output_h_desc,
2589     DeviceMemory<Eigen::half>* output_h_data,
2590     const dnn::RnnStateTensorDescriptor& output_c_desc,
2591     DeviceMemory<Eigen::half>* output_c_data, bool is_training,
2592     ScratchAllocator* reserve_space_allocator,
2593     ScratchAllocator* workspace_allocator,
2594     dnn::ProfileResult* output_profile_result) {
2595   // ROCM TODO: output_profile_result is ignored for now.
2596 
2597   const MIOpenRnnDescriptor& miopen_rnn_desc =
2598       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2599   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2600       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2601   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2602       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2603   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2604       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2605   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2606       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2607   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2608       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2609   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2610       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2611 
2612   return DoRnnForwardImpl<Eigen::half>(
2613       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2614       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2615       params, miopen_output_desc, output_data, miopen_output_h_desc,
2616       output_h_data, miopen_output_c_desc, output_c_data, is_training,
2617       reserve_space_allocator, workspace_allocator);
2618 }
2619 
2620 bool MIOpenSupport::DoRnnForward(
2621     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2622     const dnn::RnnSequenceTensorDescriptor& input_desc,
2623     const DeviceMemory<float>& input_data,
2624     const dnn::RnnStateTensorDescriptor& input_h_desc,
2625     const DeviceMemory<float>& input_h_data,
2626     const dnn::RnnStateTensorDescriptor& input_c_desc,
2627     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2628     const dnn::RnnSequenceTensorDescriptor& output_desc,
2629     DeviceMemory<float>* output_data,
2630     const dnn::RnnStateTensorDescriptor& output_h_desc,
2631     DeviceMemory<float>* output_h_data,
2632     const dnn::RnnStateTensorDescriptor& output_c_desc,
2633     DeviceMemory<float>* output_c_data, bool is_training,
2634     ScratchAllocator* reserve_space_allocator,
2635     ScratchAllocator* workspace_allocator,
2636     dnn::ProfileResult* output_profile_result) {
2637   // ROCM TODO: output_profile_result is ignored for now.
2638 
2639   const MIOpenRnnDescriptor& miopen_rnn_desc =
2640       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2641   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2642       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2643   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2644       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2645   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2646       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2647   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2648       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2649   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2650       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2651   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2652       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2653 
2654   return DoRnnForwardImpl<float>(
2655       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2656       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2657       params, miopen_output_desc, output_data, miopen_output_h_desc,
2658       output_h_data, miopen_output_c_desc, output_c_data, is_training,
2659       reserve_space_allocator, workspace_allocator);
2660 }
2661 
2662 bool MIOpenSupport::DoRnnForward(
2663     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2664     const dnn::RnnSequenceTensorDescriptor& input_desc,
2665     const DeviceMemory<double>& input_data,
2666     const dnn::RnnStateTensorDescriptor& input_h_desc,
2667     const DeviceMemory<double>& input_h_data,
2668     const dnn::RnnStateTensorDescriptor& input_c_desc,
2669     const DeviceMemory<double>& input_c_data,
2670     const DeviceMemory<double>& params,
2671     const dnn::RnnSequenceTensorDescriptor& output_desc,
2672     DeviceMemory<double>* output_data,
2673     const dnn::RnnStateTensorDescriptor& output_h_desc,
2674     DeviceMemory<double>* output_h_data,
2675     const dnn::RnnStateTensorDescriptor& output_c_desc,
2676     DeviceMemory<double>* output_c_data, bool is_training,
2677     ScratchAllocator* reserve_space_allocator,
2678     ScratchAllocator* workspace_allocator,
2679     dnn::ProfileResult* output_profile_result) {
2680   LOG(ERROR) << "miopen does not support double type RNN fwd yet";
2681   return false;
2682 }
2683 
2684 bool MIOpenSupport::DoRnnBackward(
2685     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2686     const dnn::RnnSequenceTensorDescriptor& input_desc,
2687     const DeviceMemory<Eigen::half>& input_data,
2688     const dnn::RnnStateTensorDescriptor& input_h_desc,
2689     const DeviceMemory<Eigen::half>& input_h_data,
2690     const dnn::RnnStateTensorDescriptor& input_c_desc,
2691     const DeviceMemory<Eigen::half>& input_c_data,
2692     const DeviceMemory<Eigen::half>& params,
2693     const dnn::RnnSequenceTensorDescriptor& output_desc,
2694     const DeviceMemory<Eigen::half>& output_data,
2695     const dnn::RnnStateTensorDescriptor& output_h_desc,
2696     const DeviceMemory<Eigen::half>& output_h_data,
2697     const dnn::RnnStateTensorDescriptor& output_c_desc,
2698     const DeviceMemory<Eigen::half>& output_c_data,
2699     const DeviceMemory<Eigen::half>& output_backprop_data,
2700     const DeviceMemory<Eigen::half>& output_h_backprop_data,
2701     const DeviceMemory<Eigen::half>& output_c_backprop_data,
2702     DeviceMemory<Eigen::half>* input_backprop_data,
2703     DeviceMemory<Eigen::half>* input_h_backprop_data,
2704     DeviceMemory<Eigen::half>* input_c_backprop_data,
2705     DeviceMemory<Eigen::half>* params_backprop_data,
2706     DeviceMemory<uint8>* reserve_space_data,
2707     ScratchAllocator* workspace_allocator,
2708     dnn::ProfileResult* output_profile_result) {
2709   // ROCM TODO: output_profile_result is ignored for now.
2710 
2711   const MIOpenRnnDescriptor& miopen_rnn_desc =
2712       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2713   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2714       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2715   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2716       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2717   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2718       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2719   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2720       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2721   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2722       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2723   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2724       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2725 
2726   return DoRnnBackwardImpl<Eigen::half>(
2727       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2728       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2729       params, miopen_output_desc, output_data, miopen_output_h_desc,
2730       output_h_data, miopen_output_c_desc, output_c_data, output_backprop_data,
2731       output_h_backprop_data, output_c_backprop_data, input_backprop_data,
2732       input_h_backprop_data, input_c_backprop_data, params_backprop_data,
2733       reserve_space_data, workspace_allocator);
2734 }
2735 
2736 bool MIOpenSupport::DoRnnBackward(
2737     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2738     const dnn::RnnSequenceTensorDescriptor& input_desc,
2739     const DeviceMemory<float>& input_data,
2740     const dnn::RnnStateTensorDescriptor& input_h_desc,
2741     const DeviceMemory<float>& input_h_data,
2742     const dnn::RnnStateTensorDescriptor& input_c_desc,
2743     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2744     const dnn::RnnSequenceTensorDescriptor& output_desc,
2745     const DeviceMemory<float>& output_data,
2746     const dnn::RnnStateTensorDescriptor& output_h_desc,
2747     const DeviceMemory<float>& output_h_data,
2748     const dnn::RnnStateTensorDescriptor& output_c_desc,
2749     const DeviceMemory<float>& output_c_data,
2750     const DeviceMemory<float>& output_backprop_data,
2751     const DeviceMemory<float>& output_h_backprop_data,
2752     const DeviceMemory<float>& output_c_backprop_data,
2753     DeviceMemory<float>* input_backprop_data,
2754     DeviceMemory<float>* input_h_backprop_data,
2755     DeviceMemory<float>* input_c_backprop_data,
2756     DeviceMemory<float>* params_backprop_data,
2757     DeviceMemory<uint8>* reserve_space_data,
2758     ScratchAllocator* workspace_allocator,
2759     dnn::ProfileResult* output_profile_result) {
2760   // ROCM TODO: output_profile_result is ignored for now.
2761 
2762   const MIOpenRnnDescriptor& miopen_rnn_desc =
2763       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2764   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2765       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2766   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2767       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2768   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2769       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2770   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2771       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2772   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2773       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2774   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2775       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2776 
2777   return DoRnnBackwardImpl<float>(
2778       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2779       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2780       params, miopen_output_desc, output_data, miopen_output_h_desc,
2781       output_h_data, miopen_output_c_desc, output_c_data, output_backprop_data,
2782       output_h_backprop_data, output_c_backprop_data, input_backprop_data,
2783       input_h_backprop_data, input_c_backprop_data, params_backprop_data,
2784       reserve_space_data, workspace_allocator);
2785 }
2786 
2787 bool MIOpenSupport::DoRnnBackward(
2788     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2789     const dnn::RnnSequenceTensorDescriptor& input_desc,
2790     const DeviceMemory<double>& input_data,
2791     const dnn::RnnStateTensorDescriptor& input_h_desc,
2792     const DeviceMemory<double>& input_h_data,
2793     const dnn::RnnStateTensorDescriptor& input_c_desc,
2794     const DeviceMemory<double>& input_c_data,
2795     const DeviceMemory<double>& params,
2796     const dnn::RnnSequenceTensorDescriptor& output_desc,
2797     const DeviceMemory<double>& output_data,
2798     const dnn::RnnStateTensorDescriptor& output_h_desc,
2799     const DeviceMemory<double>& output_h_data,
2800     const dnn::RnnStateTensorDescriptor& output_c_desc,
2801     const DeviceMemory<double>& output_c_data,
2802     const DeviceMemory<double>& output_backprop_data,
2803     const DeviceMemory<double>& output_h_backprop_data,
2804     const DeviceMemory<double>& output_c_backprop_data,
2805     DeviceMemory<double>* input_backprop_data,
2806     DeviceMemory<double>* input_h_backprop_data,
2807     DeviceMemory<double>* input_c_backprop_data,
2808     DeviceMemory<double>* params_backprop_data,
2809     DeviceMemory<uint8>* reserve_space_data,
2810     ScratchAllocator* workspace_allocator,
2811     dnn::ProfileResult* output_profile_result) {
2812   LOG(ERROR) << "miopen does not support double type RNN bwd yet";
2813   return false;
2814 }
2815 
2816 // This is the context required to use the TF scratch allocator:
2817 struct MIOpenAllocatorContext {
2818   MIOpenAllocatorContext(ScratchAllocator* scratch_allocator, Stream* stream)
2819       : scratch_allocator_(scratch_allocator), stream_(stream) {}
2820 
2821   ScratchAllocator* scratch_allocator_;
2822   Stream* stream_;
2823 };
2824 
2825 void* MIOpenAllocatorCallback(void* ctx, size_t size_in_bytes) {
2826   auto* mac = static_cast<MIOpenAllocatorContext*>(ctx);
2827   auto allocated = mac->scratch_allocator_->AllocateBytes(size_in_bytes);
2828 
2829   DeviceMemory<uint8> scratch;
2830   if (allocated.ok()) {
2831     scratch = allocated.ValueOrDie();
2832     return scratch.opaque();
2833   } else {
2834     return nullptr;
2835   }
2836 }
2837 
2838 void MIOpenDeallocatorCallback(void* ctx, void* mem) {
2839   // No deallocator is needed: the TensorFlow scratch allocator owns the
2840   // memory and will reclaim it automatically.
2841 }
2842 
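// A minimal sketch (not part of the original flow) of how the two callbacks
// above could be attached to an MIOpen handle, assuming the standard
// miopenSetAllocator() entry point; the context must outlive any MIOpen call
// that may allocate through it:
//
//   MIOpenAllocatorContext mac(scratch_allocator, stream);
//   wrap::miopenSetAllocator(miopen.handle(), MIOpenAllocatorCallback,
//                            MIOpenDeallocatorCallback, &mac);
//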
2843 port::Status MIOpenSupport::DoPrepareForConvolution(
2844     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2845     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2846     const dnn::FilterDescriptor& filter_descriptor,
2847     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2848     DeviceMemoryBase output_data,
2849     const dnn::ConvolutionDescriptor& convolution_descriptor,
2850     const dnn::AlgorithmConfig& algorithm_config,
2851     ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
2852     DeviceMemory<uint8>* scratch_memory) {
2853   absl::optional<dnn::AlgorithmDesc> input_algo_desc =
2854       algorithm_config.algorithm();
2855 
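  // On the ROCm path the algorithm (and its scratch size) are expected to have
  // been selected beforehand, e.g. via GetMIOpenConvolveAlgorithms, so both
  // fields of algorithm_config must be populated here.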
2856   assert(input_algo_desc.has_value());
2857 
2858   // An algorithm has been specified.
2859   *algorithm_desc = *input_algo_desc;
2860 
2861   assert(algorithm_config.scratch_size().has_value());
2862 
2863   size_t scratch_memory_size = *(algorithm_config.scratch_size());
2864 
2865   // Allocate the scratch memory.
2866   if (scratch_memory_size != 0) {
2867     if (scratch_allocator == nullptr) {
2868       return port::InternalError(
2869           absl::StrCat("An allocator must be specified when scratch memory is "
2870                        "needed"));
2871     }
2872     auto allocated = scratch_allocator->AllocateBytes(scratch_memory_size);
2873     if (allocated.ok()) {
2874       *scratch_memory = allocated.ValueOrDie();
2875     } else {
2876       LOG(ERROR)
2877           << "Failed to allocate scratch memory - "
2878           << allocated.status().error_message() << "\n"
2879           << "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
2880              "larger number (e.g. 8192) to increase the max memory limit.\n"
2881           << "\tIncreasing the max memory limit might help resolve this "
2882              "error";
2883       return port::InternalError(absl::StrCat(
2884           "Failed to allocate scratch memory of size: ", scratch_memory_size));
2885     }
2886   }
2887 
2888   return port::Status::OK();
2889 }
2890 
2891 // NOTE(keveman): Temporary data layout transformation until MIOpen supports
2892 // kBatchYXDepth for backward pass. This function allocates temporary memory,
2893 // lays out the source data into the temporary but in the kBatchDepthXY
2894 // layout, and returns the temporary memory. The caller is responsible for
2895 // deallocating the temporary. Since the allocation is done using Stream's
2896 // AllocateTemporaryMemory, a later BlockHostUntilDone could be used for
2897 // deallocation.
2898 //
2899 // transform_scratch is populated with a legitimate temporary allocation iff
2900 // the original output data needs to be transformed.
2901 static DeviceMemoryBase MaybeTransformLayout(
2902     Stream* stream, miopenHandle_t handle_,
2903     int miopen_type,  // Actually miopenDataType_t.
2904     BatchDescriptor* output_descriptor, DeviceMemoryBase backward_output_data,
2905     std::unique_ptr<TemporaryDeviceMemory<uint8>>* transform_scratch) {
2906   if (output_descriptor->layout() == dnn::DataLayout::kBatchDepthYX) {
2907     return backward_output_data;
2908   }
2909   CHECK(output_descriptor->layout() == dnn::DataLayout::kBatchYXDepth);
2910   *transform_scratch =
2911       stream->AllocateTemporaryArray<uint8>(backward_output_data.size())
2912           .ConsumeValueOrDie();
2913   BatchDescriptor transformed_output_descriptor;
2914   transformed_output_descriptor.CloneFrom(*output_descriptor);
2915   transformed_output_descriptor.set_layout(dnn::DataLayout::kBatchDepthYX);
2916   ScopedTensorDescriptor orig_out_back_nd{
2917       *output_descriptor, static_cast<miopenDataType_t>(miopen_type)};
2918   ScopedTensorDescriptor transformed_out_back_nd{
2919       transformed_output_descriptor,
2920       static_cast<miopenDataType_t>(miopen_type)};
2921 
2922   float alpha1 = 1.0f;
2923   float alpha2 = 0.0f;
2924   float beta = 0.0f;
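  // With alpha2 == 0 and beta == 0 the tensor "add" below degenerates into a
  // copy of the source tensor, but written through the transformed
  // (kBatchDepthYX) destination descriptor, which performs the layout
  // conversion.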
2925   auto status = wrap::miopenOpTensor(
2926       handle_, miopenTensorOpAdd, &alpha1, orig_out_back_nd.handle(),
2927       backward_output_data.opaque(), &alpha2, orig_out_back_nd.handle(),
2928       backward_output_data.opaque(), &beta, transformed_out_back_nd.handle(),
2929       (*transform_scratch)->mutable_device_memory()->opaque());
2930 
2931   if (status != miopenStatusSuccess) {
2932     LOG(FATAL) << "Failed to transform the data layout.";
2933   }
2934   output_descriptor->set_layout(dnn::DataLayout::kBatchDepthYX);
2935   return (*transform_scratch)->device_memory();
2936 }
2937 
2938 port::Status MIOpenSupport::DoConvolve(
2939     dnn::ConvolutionKind kind, dnn::DataType element_type,
2940     dnn::DataType output_type, Stream* stream,
2941     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2942     const dnn::FilterDescriptor& filter_descriptor,
2943     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2944     DeviceMemoryBase output_data,
2945     const dnn::ConvolutionDescriptor& convolution_descriptor,
2946     dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
2947     dnn::ProfileResult* output_profile_result) {
2948   auto miopen = miopen_->GetHandle(parent_, stream);
2949   ScopedTensorDescriptor input_nd{input_descriptor,
2950                                   ToMIOpenDataType(element_type)};
2951   ScopedTensorDescriptor output_nd{output_descriptor,
2952                                    ToMIOpenDataType(element_type)};
2953   ScopedFilterDescriptor filter{filter_descriptor, input_descriptor,
2954                                 ToMIOpenDataType(element_type)};
2955   ScopedConvolutionDescriptor conv{convolution_descriptor,
2956                                    ToMIOpenDataType(element_type)};
2957 
2958   // Alpha is the scaling factor for input.
2959   float alpha = 1.0;
2960   // Beta is the scaling factor for output.
2961   float beta = 0.0;
2962 
2963   const bool is_profiling = output_profile_result != nullptr;
2964 
2965   std::unique_ptr<GpuTimer> timer;
2966   if (is_profiling) {
2967     timer.reset(new GpuTimer(parent_));
2968     if (!timer->Init()) {
2969       return port::Status(port::error::INTERNAL, "Failed to init timer");
2970     }
2971     // The start and stop of the timer should be as close to the MIOpen call
2972     // as possible. Other threads may still enqueue work onto this stream, so
2973     // a single measurement can be skewed and multiple runs may be needed.
2974     if (!timer->Start(AsGpuStream(stream))) {
2975       timer->Destroy();
2976       return port::Status(port::error::INTERNAL, "Failed to start timer");
2977     }
2978   }
2979 
2980   miopenStatus_t status = miopenStatusSuccess;
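  // The meaning of algorithm_desc.algo_id() depends on the mode: in immediate
  // mode it carries an MIOpen solution id, while in find mode it carries the
  // value of the corresponding miopenConv*Algorithm_t enum.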
2981   switch (kind) {
2982     case dnn::ConvolutionKind::FORWARD: {
2983       if (use_immediate_mode_) {
2984         status = wrap::miopenConvolutionForwardImmediate(
2985             miopen.handle(), filter.handle(), filter_data.opaque(),
2986             input_nd.handle(), input_data.opaque(), conv.handle(),
2987             output_nd.handle(), output_data.opaque(), scratch_memory.opaque(),
2988             scratch_memory.size(),
2989             static_cast<uint64_t>(algorithm_desc.algo_id()));
2990       } else {
2991         status = wrap::miopenConvolutionForward(
2992             miopen.handle(), &alpha, input_nd.handle(), input_data.opaque(),
2993             filter.handle(), filter_data.opaque(), conv.handle(),
2994             static_cast<miopenConvFwdAlgorithm_t>(algorithm_desc.algo_id()),
2995             &beta, output_nd.handle(), output_data.opaque(),
2996             scratch_memory.opaque(), scratch_memory.size());
2997       }
2998 
2999       break;
3000     }
3001     case dnn::ConvolutionKind::BACKWARD_DATA: {
3002       // TBD: remove once MIOpen supports kBatchYXDepth for backward pass.
3003       BatchDescriptor output_back_descriptor;
3004       output_back_descriptor.CloneFrom(output_descriptor);
3005       std::unique_ptr<TemporaryDeviceMemory<uint8>> transform_scratch;
3006       output_data = MaybeTransformLayout(
3007           stream, miopen.handle(), ToMIOpenDataType(element_type),
3008           &output_back_descriptor, output_data, &transform_scratch);
3009 
3010       if (use_immediate_mode_) {
3011         status = wrap::miopenConvolutionBackwardDataImmediate(
3012             miopen.handle(), output_nd.handle(), output_data.opaque(),
3013             filter.handle(), filter_data.opaque(), conv.handle(),
3014             input_nd.handle(), input_data.opaque(), scratch_memory.opaque(),
3015             scratch_memory.size(),
3016             static_cast<uint64_t>(algorithm_desc.algo_id()));
3017       } else {
3018         status = wrap::miopenConvolutionBackwardData(
3019             miopen.handle(), &alpha, output_nd.handle(), output_data.opaque(),
3020             filter.handle(), filter_data.opaque(), conv.handle(),
3021             static_cast<miopenConvBwdDataAlgorithm_t>(algorithm_desc.algo_id()),
3022             &beta, input_nd.handle(), input_data.opaque(),
3023             scratch_memory.opaque(), scratch_memory.size());
3024       }
3025       break;
3026     }
3027     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3028       // TBD: remove once MIOpen supports kBatchYXDepth for backward pass.
3029       BatchDescriptor output_back_descriptor;
3030       output_back_descriptor.CloneFrom(output_descriptor);
3031       std::unique_ptr<TemporaryDeviceMemory<uint8>> transform_scratch;
3032       output_data = MaybeTransformLayout(
3033           stream, miopen.handle(), ToMIOpenDataType(element_type),
3034           &output_back_descriptor, output_data, &transform_scratch);
3035 
3036       if (use_immediate_mode_) {
3037         status = wrap::miopenConvolutionBackwardWeightsImmediate(
3038             miopen.handle(), output_nd.handle(), output_data.opaque(),
3039             input_nd.handle(), input_data.opaque(), conv.handle(),
3040             filter.handle(), filter_data.opaque(), scratch_memory.opaque(),
3041             scratch_memory.size(),
3042             static_cast<uint64_t>(algorithm_desc.algo_id()));
3043       } else {
3044         status = wrap::miopenConvolutionBackwardWeights(
3045             miopen.handle(), &alpha, output_nd.handle(), output_data.opaque(),
3046             input_nd.handle(), input_data.opaque(), conv.handle(),
3047             static_cast<miopenConvBwdWeightsAlgorithm_t>(
3048                 algorithm_desc.algo_id()),
3049             &beta, filter.handle(), filter_data.opaque(),
3050             scratch_memory.opaque(), scratch_memory.size());
3051       }
3052       break;
3053     }
3054     default:
3055       return port::InternalError(
3056           absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
3057   }
3058 
3059   if (is_profiling) {
3060     if (!timer->Stop(AsGpuStream(stream))) {
3061       timer->Destroy();
3062       return port::Status(port::error::INTERNAL, "Failed to stop timer");
3063     }
3064     if (status == miopenStatusSuccess) {
3065       dnn::AlgorithmDesc algotype(algorithm_desc.algo_id(), false);
3066       output_profile_result->set_algorithm(algotype);
3067       output_profile_result->set_elapsed_time_in_ms(
3068           timer->GetElapsedMilliseconds());
3069       output_profile_result->set_scratch_size(scratch_memory.size());
3070     }
3071     timer->Destroy();
3072   }
3073 
3074   if (status != miopenStatusSuccess) {
3075     return port::InternalError(absl::StrCat(
3076         "Failed to enqueue convolution on stream: ", ToString(status)));
3077   }
3078 
3079   return port::Status::OK();
3080 }
3081 
3082 bool MIOpenSupport::GetConvolveAlgorithms(
3083     // ROCM TODO: refactor cc_major / cc_minor
3084     bool with_winograd_nonfused, int cc_major, int cc_minor,
3085     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3086   out_algorithms->assign({
3087       // clang-format off
3088       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoGEMM, false),
3089       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoDirect, false),
3090       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoFFT, false),
3091       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoWinograd, false),
3092       // clang-format on
3093   });
3094   return true;
3095 }
3096 
3097 bool MIOpenSupport::GetMIOpenConvolveAlgorithms(
3098     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3099     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3100     const dnn::FilterDescriptor& filter_descriptor,
3101     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3102     DeviceMemoryBase output_data,
3103     const dnn::ConvolutionDescriptor& convolution_descriptor,
3104     ScratchAllocator* scratch_allocator,
3105     std::vector<dnn::ProfileResult>* out_algorithms) {
3106   return use_immediate_mode_
3107              ? GetMIOpenConvolveAlgorithmsImmediateMode(
3108                    kind, element_type, stream, input_descriptor, input_data,
3109                    filter_descriptor, filter_data, output_descriptor,
3110                    output_data, convolution_descriptor, scratch_allocator,
3111                    out_algorithms)
3112              : GetMIOpenConvolveAlgorithmsFindMode(
3113                    kind, element_type, stream, input_descriptor, input_data,
3114                    filter_descriptor, filter_data, output_descriptor,
3115                    output_data, convolution_descriptor, scratch_allocator,
3116                    out_algorithms);
3117 }
3118 
3119 bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode(
3120     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3121     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3122     const dnn::FilterDescriptor& filter_descriptor,
3123     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3124     DeviceMemoryBase output_data,
3125     const dnn::ConvolutionDescriptor& convolution_descriptor,
3126     ScratchAllocator* scratch_allocator,
3127     std::vector<dnn::ProfileResult>* out_algorithms) {
3128   auto miopen = miopen_->GetHandle(parent_, stream);
3129 
3130   ScopedTensorDescriptor input_nd{input_descriptor,
3131                                   ToMIOpenDataType(element_type)};
3132   ScopedTensorDescriptor output_nd{output_descriptor,
3133                                    ToMIOpenDataType(element_type)};
3134   ScopedFilterDescriptor filter{filter_descriptor, input_descriptor,
3135                                 ToMIOpenDataType(element_type)};
3136   ScopedConvolutionDescriptor conv{convolution_descriptor,
3137                                    ToMIOpenDataType(element_type)};
3138 
3139   // First determine the number of algorithms available.
3140   size_t maxSolutionCount = 0;
3141 
3142   switch (kind) {
3143     case dnn::ConvolutionKind::FORWARD: {
3144       auto status = wrap::miopenConvolutionForwardGetSolutionCount(
3145           miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
3146           output_nd.handle(), &maxSolutionCount);
3147       if (status != miopenStatusSuccess) {
3148         LOG(FATAL)
3149             << "call to miopenConvolutionForwardGetSolutionCount failed: "
3150             << ToString(status);
3151         return false;
3152       }
3153       break;
3154     }
3155     case dnn::ConvolutionKind::BACKWARD_DATA: {
3156       auto status = wrap::miopenConvolutionBackwardDataGetSolutionCount(
3157           miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
3158           input_nd.handle(), &maxSolutionCount);
3159       if (status != miopenStatusSuccess) {
3160         LOG(FATAL) << "call to miopenConvolutionBackwardDataGetSolutionCount "
3161                       "failed: "
3162                    << ToString(status);
3163         return false;
3164       }
3165       break;
3166     }
3167     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3168       auto status = wrap::miopenConvolutionBackwardWeightsGetSolutionCount(
3169           miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
3170           filter.handle(), &maxSolutionCount);
3171       if (status != miopenStatusSuccess) {
3172         LOG(FATAL)
3173             << "call to miopenConvolutionBackwardWeightsGetSolutionCount "
3174                "failed: "
3175             << ToString(status);
3176         return false;
3177       }
3178       break;
3179     }
3180     default: {
3181       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3182       return false;
3183       break;
3184     }
3185   }
3186 
3187   VLOG(kConvDebugVlogLevel)
3188       << "Number of conv solutions max: " << maxSolutionCount;
3189 
3190   if (return_best_algo_only_) {
3191     VLOG(kConvDebugVlogLevel) << "TF_ROCM_RETURN_BEST_ALGO_ONLY is set, "
3192                               << "setting maxSolutionCount to 1";
3193     maxSolutionCount = 1;
3194   }
3195 
3196   size_t solutionCount = 0;
3197   std::unique_ptr<miopenConvSolution_t[]> solutions(
3198       new miopenConvSolution_t[maxSolutionCount]);
3199 
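  // For each solution returned below, the corresponding *CompileSolution call
  // is also invoked so the kernel is built up front, and the solution (time,
  // workspace size, id) is then reported as a dnn::ProfileResult via
  // GetProfileResultFromConvSolution.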
3200   switch (kind) {
3201     case dnn::ConvolutionKind::FORWARD: {
3202       auto status = wrap::miopenConvolutionForwardGetSolution(
3203           miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
3204           output_nd.handle(), maxSolutionCount, &solutionCount,
3205           solutions.get());
3206 
3207       if (status != miopenStatusSuccess) {
3208         LOG(FATAL) << "call to miopenConvolutionForwardGetSolution failed: "
3209                    << ToString(status);
3210         return false;
3211       }
3212 
3213       VLOG(kConvDebugVlogLevel)
3214           << "Number of conv solutions actual: " << solutionCount;
3215 
3216       for (size_t i = 0; i < solutionCount; i++) {
3217         miopenConvSolution_t solution = solutions[i];
3218 
3219         VLOG(kConvDebugVlogLevel)
3220             << "solution " << i << " (time, mem, id, algo) =  " << solution.time
3221             << ", " << solution.workspace_size << ", " << solution.solution_id
3222             << ", " << ToString(solution.algorithm);
3223 
3224         status = wrap::miopenConvolutionForwardCompileSolution(
3225             miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
3226             output_nd.handle(), solution.solution_id);
3227 
3228         if (status != miopenStatusSuccess) {
3229           LOG(FATAL)
3230               << "call to miopenConvolutionForwardCompileSolution failed: "
3231               << ToString(status);
3232           return false;
3233         }
3234 
3235         out_algorithms->emplace_back(
3236             GetProfileResultFromConvSolution(solution));
3237       }
3238       break;
3239     }
3240 
3241     case dnn::ConvolutionKind::BACKWARD_DATA: {
3242       auto status = wrap::miopenConvolutionBackwardDataGetSolution(
3243           miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
3244           input_nd.handle(), maxSolutionCount, &solutionCount, solutions.get());
3245       if (status != miopenStatusSuccess) {
3246         LOG(FATAL)
3247             << "call to miopenConvolutionBackwardDataGetSolution failed: "
3248             << ToString(status);
3249         return false;
3250       }
3251 
3252       VLOG(kConvDebugVlogLevel)
3253           << "Number of conv solutions actual: " << solutionCount;
3254 
3255       for (size_t i = 0; i < solutionCount; i++) {
3256         miopenConvSolution_t solution = solutions[i];
3257 
3258         VLOG(kConvDebugVlogLevel)
3259             << "solution " << i << " (time, mem, id, algo) =  " << solution.time
3260             << ", " << solution.workspace_size << ", " << solution.solution_id
3261             << ", " << ToString(solution.algorithm);
3262 
3263         status = wrap::miopenConvolutionBackwardDataCompileSolution(
3264             miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
3265             input_nd.handle(), solution.solution_id);
3266 
3267         if (status != miopenStatusSuccess) {
3268           LOG(FATAL) << "call to miopenConvolutionBackwardDataCompileSolution "
3269                         "failed: "
3270                      << ToString(status);
3271           return false;
3272         }
3273 
3274         out_algorithms->emplace_back(
3275             GetProfileResultFromConvSolution(solution));
3276       }
3277       break;
3278     }
3279     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3280       auto status = wrap::miopenConvolutionBackwardWeightsGetSolution(
3281           miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
3282           filter.handle(), maxSolutionCount, &solutionCount, solutions.get());
3283       if (status != miopenStatusSuccess) {
3284         LOG(FATAL)
3285             << "call to miopenConvolutionBackwardWeightsGetSolution failed: "
3286             << ToString(status);
3287         return false;
3288       }
3289 
3290       VLOG(kConvDebugVlogLevel)
3291           << "Number of conv solutions actual: " << solutionCount;
3292 
3293       for (size_t i = 0; i < solutionCount; i++) {
3294         miopenConvSolution_t solution = solutions[i];
3295 
3296         VLOG(kConvDebugVlogLevel)
3297             << "solution " << i << " (time, mem, id, algo) =  " << solution.time
3298             << ", " << solution.workspace_size << ", " << solution.solution_id
3299             << ", " << ToString(solution.algorithm);
3300 
3301         status = wrap::miopenConvolutionBackwardWeightsCompileSolution(
3302             miopen.handle(), output_nd.handle(), input_nd.handle(),
3303             conv.handle(), filter.handle(), solution.solution_id);
3304 
3305         if (status != miopenStatusSuccess) {
3306           LOG(FATAL)
3307               << "call to miopenConvolutionBackwardWeightsCompileSolution "
3308                  "failed: "
3309               << ToString(status);
3310           return false;
3311         }
3312 
3313         out_algorithms->emplace_back(
3314             GetProfileResultFromConvSolution(solution));
3315       }
3316       break;
3317     }
3318     default: {
3319       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3320       return false;
3321       break;
3322     }
3323   }
3324 
3325   return true;
3326 }
3327 
3328 bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode(
3329     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3330     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3331     const dnn::FilterDescriptor& filter_descriptor,
3332     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3333     DeviceMemoryBase output_data,
3334     const dnn::ConvolutionDescriptor& convolution_descriptor,
3335     ScratchAllocator* scratch_allocator,
3336     std::vector<dnn::ProfileResult>* out_algorithms) {
3337   auto miopen = miopen_->GetHandle(parent_, stream);
3338 
3339   ScopedTensorDescriptor input_nd{input_descriptor,
3340                                   ToMIOpenDataType(element_type)};
3341   ScopedTensorDescriptor output_nd{output_descriptor,
3342                                    ToMIOpenDataType(element_type)};
3343   ScopedFilterDescriptor filter{filter_descriptor, input_descriptor,
3344                                 ToMIOpenDataType(element_type)};
3345   ScopedConvolutionDescriptor conv{convolution_descriptor,
3346                                    ToMIOpenDataType(element_type)};
3347 
3348   // Determine the workspace memory size that will be needed by the Find call
3349   size_t scratch_memory_size = 0;
3350   switch (kind) {
3351     case dnn::ConvolutionKind::FORWARD: {
3352       auto status = wrap::miopenConvolutionForwardGetWorkSpaceSize(
3353           miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
3354           output_nd.handle(), &scratch_memory_size);
3355       if (status != miopenStatusSuccess) {
3356         LOG(FATAL)
3357             << "call to miopenConvolutionForwardGetWorkSpaceSize failed: "
3358             << ToString(status);
3359         return false;
3360       }
3361       break;
3362     }
3363     case dnn::ConvolutionKind::BACKWARD_DATA: {
3364       auto status = wrap::miopenConvolutionBackwardDataGetWorkSpaceSize(
3365           miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
3366           input_nd.handle(), &scratch_memory_size);
3367       if (status != miopenStatusSuccess) {
3368         LOG(FATAL)
3369             << "call to miopenConvolutionBackwardDataGetWorkSpaceSize failed: "
3370             << ToString(status);
3371         return false;
3372       }
3373       break;
3374     }
3375     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3376       auto status = wrap::miopenConvolutionBackwardWeightsGetWorkSpaceSize(
3377           miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
3378           filter.handle(), &scratch_memory_size);
3379       if (status != miopenStatusSuccess) {
3380         LOG(FATAL)
3381             << "call to miopenConvolutionBackwardWeightsGetWorkSpaceSize "
3382                "failed: "
3383             << ToString(status);
3384         return false;
3385       }
3386       break;
3387     }
3388     default: {
3389       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3390       return false;
3391       break;
3392     }
3393   }
3394 
3395   // Allocate scratch memory
3396   DeviceMemory<uint8> scratch_memory;
3397   if (scratch_memory_size != 0) {
3398     if (scratch_allocator == nullptr) {
3399       LOG(FATAL)
3400           << "An allocator must be specified when scratch memory is needed";
3401       return false;
3402     }
3403     auto allocated = scratch_allocator->AllocateBytes(scratch_memory_size);
3404     if (allocated.ok()) {
3405       scratch_memory = allocated.ValueOrDie();
3406     } else {
3407       LOG(FATAL)
3408           << "Failed to allocate scratch memory - "
3409           << allocated.status().error_message() << "\n"
3410           << "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
3411              "larger number (e.g. 8192) to increase the max memory limit.\n"
3412           << "\tIncreasing the max memory limit might help resolve this "
3413              "error";
3414       return false;
3415     }
3416   }
3417 
3418   // Only get the best algorithm for Find Mode
3419   size_t requestedAlgorithmCount = 1;
3420 
3421   VLOG(kConvDebugVlogLevel)
3422       << "Number of conv algorithms to request: " << requestedAlgorithmCount;
3423 
3424   miopenConvAlgoPerf_t returnedAlgorithm;
3425 
3426   int returnedAlgorithmCount = 0;
3427   bool exhaustiveSearch = false;
3428 
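  // Run the Find call for the requested convolution kind: exhaustiveSearch is
  // false, so MIOpen skips the exhaustive tuning search and, with
  // requestedAlgorithmCount == 1, reports only the best algorithm it found.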
3429   switch (kind) {
3430     case dnn::ConvolutionKind::FORWARD: {
3431       auto status = wrap::miopenFindConvolutionForwardAlgorithm(
3432           miopen.handle(), input_nd.handle(), input_data.opaque(),
3433           filter.handle(), filter_data.opaque(), conv.handle(),
3434           output_nd.handle(), output_data.opaque(), requestedAlgorithmCount,
3435           &returnedAlgorithmCount, &returnedAlgorithm, scratch_memory.opaque(),
3436           scratch_memory_size, exhaustiveSearch);
3437       if (status != miopenStatusSuccess) {
3438         LOG(FATAL) << "call to miopenFindConvolutionForwardAlgorithm failed: "
3439                    << ToString(status);
3440         return false;
3441       }
3442       break;
3443     }
3444     case dnn::ConvolutionKind::BACKWARD_DATA: {
3445       auto status = wrap::miopenFindConvolutionBackwardDataAlgorithm(
3446           miopen.handle(), output_nd.handle(), output_data.opaque(),
3447           filter.handle(), filter_data.opaque(), conv.handle(),
3448           input_nd.handle(), input_data.opaque(), requestedAlgorithmCount,
3449           &returnedAlgorithmCount, &returnedAlgorithm, scratch_memory.opaque(),
3450           scratch_memory_size, exhaustiveSearch);
3451       if (status != miopenStatusSuccess) {
3452         LOG(FATAL)
3453             << "call to miopenFindConvolutionBackwardDataAlgorithm failed: "
3454             << ToString(status);
3455         return false;
3456       }
3457       break;
3458     }
3459     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3460       auto status = wrap::miopenFindConvolutionBackwardWeightsAlgorithm(
3461           miopen.handle(), output_nd.handle(), output_data.opaque(),
3462           input_nd.handle(), input_data.opaque(), conv.handle(),
3463           filter.handle(), filter_data.opaque(), requestedAlgorithmCount,
3464           &returnedAlgorithmCount, &returnedAlgorithm, scratch_memory.opaque(),
3465           scratch_memory_size, exhaustiveSearch);
3466       if (status != miopenStatusSuccess) {
3467         LOG(FATAL) << "call to miopenFindConvolutionBackwardWeightsAlgorithm "
3468                       "failed: "
3469                    << ToString(status);
3470         return false;
3471       }
3472       break;
3473     }
3474     default: {
3475       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3476       return false;
3477       break;
3478     }
3479   }
3480 
3481   out_algorithms->emplace_back(
3482       GetProfileResultFromConvAlgoPerf(kind, returnedAlgorithm));
3483 
3484   return true;
3485 }
3486 
3487 bool MIOpenSupport::GetRnnAlgorithms(
3488     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3489   // ROCM TODO: implement this with proper MIOpen API
3490   return true;
3491 }
3492 
3493 bool MIOpenSupport::GetConvolveBackwardDataAlgorithms(
3494     // ROCM TODO: refactor cc_major / cc_minor
3495     bool with_winograd_nonfused, int cc_major, int cc_minor,
3496     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3497   out_algorithms->assign({
3498       // clang-format off
3499       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoGEMM, false),
3500       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoDirect, false),
3501       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoFFT, false),
3502       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoWinograd, false),
3503       // clang-format on
3504   });
3505   return true;
3506 }
3507 
3508 bool MIOpenSupport::GetConvolveBackwardFilterAlgorithms(
3509     // ROCM TODO: refactor cc_major / cc_minor
3510     bool with_winograd_nonfused, int cc_major, int cc_minor,
3511     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3512   out_algorithms->assign({
3513       // clang-format off
3514       dnn::AlgorithmDesc(miopenConvolutionBwdWeightsAlgoGEMM, false),
3515       dnn::AlgorithmDesc(miopenConvolutionBwdWeightsAlgoDirect, false),
3516       // clang-format on
3517   });
3518   return true;
3519 }
3520 
3521 bool MIOpenSupport::DoBatchNormalizationForward(
3522     Stream* stream, const DeviceMemory<Eigen::half>& x,
3523     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3524     const DeviceMemory<float>& estimated_mean,
3525     const DeviceMemory<float>& estimated_variance,
3526     const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
3527     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3528     const double exponential_average_factor,
3529     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
3530     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3531     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3532     bool is_training, ScratchAllocator* reserve_space_allocator,
3533     ScratchAllocator* workspace_allocator) {
3534   return DoBatchNormalizationForwardImpl<Eigen::half, float>(
3535       stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
3536       estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc,
3537       epsilon, exponential_average_factor, activation_mode, y, batch_mean,
3538       batch_var, saved_mean, saved_inv_var, is_training);
3539 }
3540 
3541 bool MIOpenSupport::DoBatchNormalizationForward(
3542     Stream* stream, const DeviceMemory<float>& x,
3543     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3544     const DeviceMemory<float>& estimated_mean,
3545     const DeviceMemory<float>& estimated_variance,
3546     const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
3547     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3548     const double exponential_average_factor,
3549     dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
3550     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3551     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3552     bool is_training, ScratchAllocator* reserve_space_allocator,
3553     ScratchAllocator* workspace_allocator) {
3554   return DoBatchNormalizationForwardImpl<float, float>(
3555       stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, offset,
3556       estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc,
3557       epsilon, exponential_average_factor, activation_mode, y, batch_mean,
3558       batch_var, saved_mean, saved_inv_var, is_training);
3559 }
3560 
3561 template <class T, class U>
3562 bool MIOpenSupport::DoBatchNormalizationForwardImpl(
3563     Stream* stream, dnn::DataType input_data_type,
3564     dnn::DataType scale_data_type, const DeviceMemory<T>& x,
3565     const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
3566     const DeviceMemory<U>& estimated_mean,
3567     const DeviceMemory<U>& estimated_variance,
3568     const DeviceMemory<U>& side_input, const dnn::BatchDescriptor& x_desc,
3569     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3570     const double exponential_average_factor,
3571     dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
3572     DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
3573     DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
3574     bool is_training) {
3575   auto miopen = miopen_->GetHandle(parent_, stream);
3576 
3577   ScopedTensorDescriptor x_descriptor{x_desc,
3578                                       ToMIOpenDataType(input_data_type)};
3579   ScopedTensorDescriptor scale_offset_descriptor{
3580       scale_offset_desc, ToMIOpenDataType(scale_data_type)};
3581   miopenBatchNormMode_t mode = miopenBNSpatial;
3582   float one = 1.0;
3583   float zero = 0.0;
3584 
3585   auto status = miopenStatusInvalidValue;
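  // Training mode computes the batch statistics on the fly (saving the mean
  // and inverse variance for the backward pass); inference mode uses the
  // caller-supplied estimated mean and variance instead.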
3586   if (is_training) {
3587     status = wrap::miopenBatchNormalizationForwardTraining(
3588         miopen.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3589         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3590         const_cast<void*>(scale.opaque()), const_cast<void*>(offset.opaque()),
3591         exponential_average_factor, batch_mean->opaque(), batch_var->opaque(),
3592         epsilon, saved_mean->opaque(), saved_inv_var->opaque());
3593   } else {
3594     const void* maybe_inv_var = estimated_variance.opaque();
3595     status = wrap::miopenBatchNormalizationForwardInference(
3596         miopen.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3597         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3598         const_cast<void*>(scale.opaque()), const_cast<void*>(offset.opaque()),
3599         const_cast<void*>(estimated_mean.opaque()),
3600         const_cast<void*>(maybe_inv_var), epsilon);
3601   }
3602   if (status != miopenStatusSuccess) {
3603     LOG(ERROR) << "failed to enqueue forward batch normalization on stream: "
3604                << ToString(status);
3605     return false;
3606   }
3607   return true;
3608 }
3609 
3610 bool MIOpenSupport::DoBatchNormalizationBackward(
3611     Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
3612     const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
3613     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3614     const dnn::BatchDescriptor& x_desc,
3615     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3616     DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
3617     DeviceMemory<float>* offset_backprop,
3618     DeviceMemory<uint8>* reserve_space_data,
3619     ScratchAllocator* workspace_allocator) {
3620   return DoBatchNormalizationBackwardImpl<Eigen::half, float>(
3621       stream, miopenHalf, miopenFloat, y_backprop, x, scale, mean, inv_var,
3622       x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
3623       offset_backprop);
3624 }
3625 
3626 bool MIOpenSupport::DoBatchNormalizationBackward(
3627     Stream* stream, const DeviceMemory<float>& y_backprop,
3628     const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
3629     const DeviceMemory<float>& mean, const DeviceMemory<float>& variance,
3630     const dnn::BatchDescriptor& x_desc,
3631     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3632     DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
3633     DeviceMemory<float>* offset_backprop,
3634     DeviceMemory<uint8>* reserve_space_data,
3635     ScratchAllocator* workspace_allocator) {
3636   return DoBatchNormalizationBackwardImpl<float, float>(
3637       stream, miopenFloat, miopenFloat, y_backprop, x, scale, mean, variance,
3638       x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
3639       offset_backprop);
3640 }
3641 
3642 template <class T, class U>
3643 bool MIOpenSupport::DoBatchNormalizationBackwardImpl(
3644     Stream* stream, int miopen_input_type, int miopen_scale_type,
3645     const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
3646     const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
3647     const DeviceMemory<U>& variance, const dnn::BatchDescriptor& x_desc,
3648     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3649     DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
3650     DeviceMemory<U>* offset_backprop) {
3651   auto miopen = miopen_->GetHandle(parent_, stream);
3652   ScopedTensorDescriptor x_descriptor{
3653       x_desc, static_cast<miopenDataType_t>(miopen_input_type)};
3654   ScopedTensorDescriptor scale_offset_descriptor{
3655       scale_offset_desc, static_cast<miopenDataType_t>(miopen_scale_type)};
3656   miopenBatchNormMode_t mode = miopenBNSpatial;
3657   float one = 1.0;
3658   float zero = 0.0;
3659 
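  // A single MIOpen call produces all three gradients (x, scale and offset),
  // reusing the mean and (inverse) variance saved by the forward pass.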
3660   auto status = wrap::miopenBatchNormalizationBackward(
3661       miopen.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
3662       x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
3663       x_descriptor.handle(), x_backprop->opaque(),
3664       scale_offset_descriptor.handle(), scale.opaque(),
3665       scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
3666       mean.opaque(), variance.opaque());
3667   if (status != miopenStatusSuccess) {
3668     LOG(ERROR) << "failed to enqueue backward batch normalization on stream: "
3669                << ToString(status);
3670     return false;
3671   }
3672   return true;
3673 }
3674 
3675 port::Status MIOpenSupport::DoFusedConvolve(
3676     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3677     const DeviceMemory<double>& conv_input_data, double conv_input_scale,
3678     const dnn::FilterDescriptor& filter_descriptor,
3679     const DeviceMemory<double>& filter_data,
3680     const dnn::ConvolutionDescriptor& convolution_descriptor,
3681     const DeviceMemory<double>& side_input_data, double side_input_scale,
3682     const dnn::BatchDescriptor& bias_descriptor,
3683     const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
3684     const dnn::BatchDescriptor& output_descriptor,
3685     DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
3686     const dnn::AlgorithmConfig& algorithm_config,
3687     dnn::ProfileResult* output_profile_result) {
3688   return port::UnimplementedError("fused convolve not implemented yet");
3689 }
3690 
3691 port::Status MIOpenSupport::DoFusedConvolve(
3692     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3693     const DeviceMemory<float>& conv_input_data, float conv_input_scale,
3694     const dnn::FilterDescriptor& filter_descriptor,
3695     const DeviceMemory<float>& filter_data,
3696     const dnn::ConvolutionDescriptor& convolution_descriptor,
3697     const DeviceMemory<float>& side_input_data, float side_input_scale,
3698     const dnn::BatchDescriptor& bias_descriptor,
3699     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3700     const dnn::BatchDescriptor& output_descriptor,
3701     DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
3702     const dnn::AlgorithmConfig& algorithm_config,
3703     dnn::ProfileResult* output_profile_result) {
3704   return port::UnimplementedError("fused convolve not implemented yet");
3705 }
3706 
3707 port::Status MIOpenSupport::DoFusedConvolve(
3708     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3709     const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
3710     const dnn::FilterDescriptor& filter_descriptor,
3711     const DeviceMemory<Eigen::half>& filter_data,
3712     const dnn::ConvolutionDescriptor& convolution_descriptor,
3713     const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
3714     const dnn::BatchDescriptor& bias_descriptor,
3715     const DeviceMemory<Eigen::half>& biases,
3716     dnn::ActivationMode activation_mode,
3717     const dnn::BatchDescriptor& output_descriptor,
3718     DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
3719     const dnn::AlgorithmConfig& algorithm_config,
3720     dnn::ProfileResult* output_profile_result) {
3721   return port::UnimplementedError("fused convolve not implemented yet");
3722 }
3723 
3724 port::Status MIOpenSupport::DoFusedConvolve(
3725     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3726     const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
3727     const dnn::FilterDescriptor& filter_descriptor,
3728     const DeviceMemory<int8>& filter_data,
3729     const dnn::ConvolutionDescriptor& convolution_descriptor,
3730     const DeviceMemory<int8>& side_input_data, float side_input_scale,
3731     const dnn::BatchDescriptor& bias_descriptor,
3732     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3733     const dnn::BatchDescriptor& output_descriptor,
3734     DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
3735     const dnn::AlgorithmConfig& algorithm_config,
3736     dnn::ProfileResult* output_profile_result) {
3737   return port::UnimplementedError("fused convolve not implemented yet");
3738 }
3739 
3740 bool MIOpenSupport::DoTransformTensor(Stream* stream,
3741                                       const dnn::BatchDescriptor& input_desc,
3742                                       dnn::DataType input_type,
3743                                       const DeviceMemoryBase& input_data,
3744                                       const dnn::BatchDescriptor& output_desc,
3745                                       dnn::DataType output_type, float scale,
3746                                       DeviceMemoryBase* output_data) {
3747   // ROCM TODO: implement this operation
3748   LOG(ERROR) << "transform tensor not implemented yet";
3749   return false;
3750 }
3751 
3752 template <class T>
3753 bool MIOpenSupport::DoConvolveBackwardBiasImpl(
3754     Stream* stream, int miopen_type,  // Actually miopenDataType_t.
3755     const dnn::BatchDescriptor& input_descriptor,
3756     const DeviceMemory<T>& input_data,
3757     const dnn::BatchDescriptor& bias_descriptor,
3758     DeviceMemory<T>* backward_bias_data) {
3759   auto miopen = miopen_->GetHandle(parent_, stream);
3760 
3761   ScopedTensorDescriptor input_nd{input_descriptor,
3762                                   static_cast<miopenDataType_t>(miopen_type)};
3763   ScopedTensorDescriptor bias_nd{bias_descriptor,
3764                                  static_cast<miopenDataType_t>(miopen_type)};
3765 
3766   // Alpha is the scaling factor for input.
3767   float alpha = 1.0;
3768   // Beta is the scaling factor for output.
3769   float beta = 0.0;
3770 
3771   auto status = wrap::miopenConvolutionBackwardBias(
3772       miopen.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta,
3773       bias_nd.handle(), backward_bias_data->opaque());
3774   if (status != miopenStatusSuccess) {
3775     LOG(FATAL) << "failed to enqueue backward convolution on stream: "
3776                << ToString(status);
3777     return false;
3778   }
3779   return true;
3780 }
3781 
3782 bool MIOpenSupport::DoConvolveBackwardBias(
3783     Stream* stream, const BatchDescriptor& input_descriptor,
3784     const DeviceMemory<double>& input_data,
3785     const BatchDescriptor& bias_descriptor,
3786     DeviceMemory<double>* backward_bias_data) {
3787   LOG(ERROR) << "miopen does not support double bwd bias yet";
3788   return false;
3789 }
3790 
3791 bool MIOpenSupport::DoConvolveBackwardBias(
3792     Stream* stream, const BatchDescriptor& input_descriptor,
3793     const DeviceMemory<float>& input_data,
3794     const BatchDescriptor& bias_descriptor,
3795     DeviceMemory<float>* backward_bias_data) {
3796   return DoConvolveBackwardBiasImpl(stream, miopenFloat, input_descriptor,
3797                                     input_data, bias_descriptor,
3798                                     backward_bias_data);
3799 }
3800 
3801 bool MIOpenSupport::DoConvolveBackwardBias(
3802     Stream* stream, const BatchDescriptor& input_descriptor,
3803     const DeviceMemory<Eigen::half>& input_data,
3804     const BatchDescriptor& bias_descriptor,
3805     DeviceMemory<Eigen::half>* backward_bias_data) {
3806   return DoConvolveBackwardBiasImpl(stream, miopenHalf, input_descriptor,
3807                                     input_data, bias_descriptor,
3808                                     backward_bias_data);
3809 }
3810 
3811 bool MIOpenSupport::DoMatMul(Stream* stream,
3812                              const DeviceMemory<float>& input_data,
3813                              const DeviceMemory<float>& weights,
3814                              const dnn::BatchDescriptor& input_dimensions,
3815                              const dnn::BatchDescriptor& output_dimensions,
3816                              DeviceMemory<float>* output_data) {
3817   if (input_dimensions.count() != output_dimensions.count()) {
3818     LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
3819     return false;
3820   }
3821 
3822   // We do not permute the input or output; instead we just
3823   // reinterpret the layout. We are working with row-major matrices
3824   // and the rows of the input and output correspond to batch, so
3825   // batch has to be outermost in both the input and output.
3826   //
3827   // By adding transposes to the BLAS gemm call we could perhaps make
3828   // the kYXDepthBatch layout work as well, but there has been no need
3829   // for that so far.
3830   if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3831       input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3832     LOG(ERROR) << "Unsupported MatMul input layout.";
3833     return false;
3834   }
3835   if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3836       output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3837     LOG(ERROR) << "Unsupported MatMul output layout.";
3838     return false;
3839   }
3840 
3841   if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
3842     // This is a fast path that also supports the kBatchYXDepth layout.
3843 
3844     // The matrices here are in row-major format while BLAS expects
3845     // column-major, i.e. our matrices are transposed as far as BLAS
3846     // is concerned. So we need to compute output^T =
3847     // input^T*weights^T. There is no parameter for transposing the
3848     // output in BLAS gemm, but instead we can transpose both sides of
3849     // the equality to see that this is equivalent to
3850     // output=weights*input. So we only need to swap the order of
3851     // weights and input in the matrix product to correct for the
3852     // row-major versus column-major difference.
3853     const float alpha = 1.0f;  // Take the matrix product without scaling it.
3854     const float beta = 0.0f;   // Ignore the original values in output_data.
3855     const int64 m = output_dimensions.NodesAcrossFeatureMaps();
3856     const int64 n = input_dimensions.count();
3857     const int64 k = input_dimensions.NodesAcrossFeatureMaps();
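    // From BLAS's column-major point of view this computes
    // output (m x n) = weights (m x k) * input (k x n).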
3858     stream->ThenBlasGemm(blas::Transpose::kNoTranspose,
3859                          blas::Transpose::kNoTranspose, m, n, k, alpha, weights,
3860                          m, input_data, k, beta, output_data, m);
3861   } else {
3862     // This is a slower and more complex path that supports output
3863     // width() * height() > 1, though it only supports the
3864     // kBatchYXDepth layout. Does support kBatchDepthYX if output
3865     // feature_map_count() == 1, as then there is no difference
3866     // between the two layouts.
3867     //
3868     // The operation here is the same as above, except that we have to
3869     // do the matrix multiplication for each (y,x) output coordinate
3870     // separately. We then interpret weights as containing K = width()
3871     // * height() different matrices, which we all multiply onto the
3872     // matrix from input_data, yielding K matrix products. We then
3873     // combine these together into one matrix by concatenating all the
3874     // first rows of these matrices, then all the second rows and so
3875     // on. We can do this with a batched matrix multiplication, where
3876     // the result is written to a different submatrix of the output
3877     // for each matrix multiplication.
3878     //
3879     // The reason that we only support the kBatchYXDepth output layout
3880     // is that we have to do something in the depth for each (y,x)
3881     // coordinate. The kBatchYXDepth layout has the depth information
3882     // for each point (y,x) in contiguous memory while the
3883     // kBatchDepthYX layout does not.
3884     //
3885     // TODO(broune): Consider a special case for when output depth ==
3886     // 1, as then possibly this could all be done as one matrix
3887     // multiplication instead of a batched one, which should be
3888     // faster. Another possibility would be to add a weights layout
3889     // parameter and then support kBatchDepthYX for a different
3890     // weights layout.
3891     if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3892         !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
3893           output_dimensions.feature_map_count() == 1)) {
3894       LOG(ERROR) << "Unsupported MatMul output layout.";
3895       return false;
3896     }
3897 
3898     const float alpha = 1.0f;  // Take the matrix product without scaling it.
3899     const float beta = 0.0f;   // Ignore the original values in output_data.
3900     const uint64 m = output_dimensions.feature_map_count();
3901     const uint64 n = input_dimensions.count();
3902     const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
3903     const int lda = m;
3904     const int ldb = k;
3905     const int ldc = output_dimensions.NodesAcrossFeatureMaps();
3906     const int batch_count = output_dimensions.NodesPerFeatureMap();
3907 
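    // Build one (a, b, c) triple per output (y, x) coordinate: each multiplies
    // the shared input matrix by a different slice of the weights and writes
    // into an interleaved slice of the output.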
3908     std::vector<DeviceMemory<float>> a(batch_count);
3909     std::vector<DeviceMemory<float>> b(batch_count);
3910     std::vector<DeviceMemory<float>> c(batch_count);
3911     for (int i = 0; i < batch_count; ++i) {
3912       const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
3913                                  output_dimensions.feature_map_count();
3914       a[i] = DeviceMemory<float>::MakeFromByteSize(
3915           const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
3916               weights_offset,
3917           weights.ElementCount() - weights_offset);
3918 
3919       b[i] = input_data;
3920 
3921       const int output_offset = i * output_dimensions.feature_map_count();
3922       c[i] = DeviceMemory<float>::MakeFromByteSize(
3923           const_cast<float*>(
3924               reinterpret_cast<const float*>(output_data->opaque())) +
3925               output_offset,
3926           output_data->ElementCount() - output_offset);
3927     }
3928     const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
3929       std::vector<DeviceMemory<float>*> ptrs;
3930       ptrs.reserve(v.size());
3931       for (auto& mem : v) {
3932         ptrs.push_back(&mem);
3933       }
3934       return ptrs;
3935     };
3936 
3937     stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
3938                                 blas::Transpose::kNoTranspose, m, n, k, alpha,
3939                                 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
3940                                 ldc, batch_count);
3941   }
3942 
3943   return stream->ok();
3944 }
3945 
3946 bool MIOpenSupport::DoBiasAdd(Stream* stream,
3947                               const DeviceMemory<float>& input_data,
3948                               const DeviceMemory<float>& biases,
3949                               const dnn::BatchDescriptor& dimensions,
3950                               DeviceMemory<float>* output_data) {
3951   ScopedTensorDescriptor input_descriptor{dimensions, miopenFloat};
3952 
3953   BatchDescriptor bias_dimensions;
3954   bias_dimensions.set_count(1)
3955       .set_feature_map_count(dimensions.feature_map_count())
3956       .set_height(1)
3957       .set_width(1)
3958       .set_layout(dnn::DataLayout::kBatchYXDepth);
3959   ScopedTensorDescriptor bias_descriptor{bias_dimensions, miopenFloat};
3960 
3961   if (input_data.opaque() != output_data->opaque()) {
3962     stream->ThenMemcpy(output_data, input_data,
3963                        dimensions.ElementCount() * sizeof(float));
3964     if (!stream->ok()) {
3965       LOG(ERROR)
3966           << "stream " << stream
3967           << " could not enqueue a tensor copy as part of bias addition.";
3968       return false;
3969     }
3970   }
3971 
3972   auto miopen = miopen_->GetHandle(parent_, stream);
3973 
3974   const float alpha1 = 1.0f;
3975   const float alpha2 = 0.0f;
3976   const float beta = 1.0f;
3977 
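  // miopenOpTensor computes C = op(alpha1 * A, alpha2 * B) + beta * C; with
  // alpha1 = 1, alpha2 = 0 and beta = 1 this adds the broadcast bias tensor
  // onto the input values already copied into output_data.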
3978   auto status = wrap::miopenOpTensor(
3979       miopen.handle(), miopenTensorOpAdd, &alpha1, bias_descriptor.handle(),
3980       biases.opaque(), &alpha2, bias_descriptor.handle(), biases.opaque(),
3981       &beta, input_descriptor.handle(), output_data->opaque());
3982 
3983   if (status != miopenStatusSuccess) {
3984     LOG(ERROR) << "stream " << stream << " could not enqueue bias addition.";
3985     return false;
3986   }
3987 
3988   return true;
3989 }
3990 
3991 bool MIOpenSupport::DoActivate(Stream* stream,
3992                                dnn::ActivationMode activation_mode,
3993                                const dnn::BatchDescriptor& dimensions,
3994                                const DeviceMemory<float>& input_data,
3995                                DeviceMemory<float>* output_data,
3996                                uint64 options) {
3997   LOG(ERROR) << "miopen does not support activation yet";
3998   return false;
3999 }
4000 
4001 bool MIOpenSupport::DoPoolForward(
4002     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4003     const dnn::BatchDescriptor& input_dimensions,
4004     const DeviceMemory<double>& input_data,
4005     const dnn::BatchDescriptor& output_dimensions,
4006     DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
4007   LOG(ERROR) << "miopen does not support pooling for double type yet";
4008   return false;
4009 }
4010 
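// Returns true when this cached workspace was created for an identical pooling
// operation (same data type, tensor shapes and pooling parameters), so it can
// be reused as-is.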
4011 bool PoolingWorkspaceDescriptor::IsSame(
4012     const dnn::BatchDescriptor& input_dimensions,
4013     const dnn::BatchDescriptor& output_dimensions,
4014     const dnn::PoolingDescriptor& pooling_dimensions, int _type) {
4015   return dtype == _type &&
4016          input_dims ==
4017              input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
4018          output_dims ==
4019              output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
4020          op.mode() == pooling_dimensions.mode() &&
4021          op.window() == pooling_dimensions.window() &&
4022          op.padding() == pooling_dimensions.padding() &&
4023          op.strides() == pooling_dimensions.strides();
4024 }
4025 
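// Looks up the workspace cached for the given input pointer; fails if there is
// no entry or its recorded configuration does not match the current one.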
4026 bool PoolingWorkspaceCache::find(
4027     const void* p, const dnn::BatchDescriptor& input_dimensions,
4028     const dnn::BatchDescriptor& output_dimensions,
4029     const dnn::PoolingDescriptor& pooling_dimensions, int _type,
4030     PoolingWorkspaceDescriptor*& pdesc) {
4031   pdesc = 0;
4032   auto it = cache.find(p);
4033   if (it == cache.end()) {
4034     return false;
4035   }
4036   if (!it->second.IsSame(input_dimensions, output_dimensions,
4037                          pooling_dimensions, _type)) {
4038     return false;
4039   }
4040   pdesc = &it->second;
4041   return true;
4042 }
4043 
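// Caches (or replaces) the workspace associated with this input pointer,
// synchronizing the stream before a replaced workspace is released, and then
// trims the cache if it has grown past its budget.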
4044 void PoolingWorkspaceCache::insert(
4045     const void* p, const dnn::BatchDescriptor& input_dimensions,
4046     const dnn::BatchDescriptor& output_dimensions,
4047     const dnn::PoolingDescriptor& pooling_dimensions, int _type,
4048     std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace, size_t wsp_size,
4049     hipStream_t hip_stream) {
4050   PoolingWorkspaceDescriptor* desc = 0;
4051   auto it = cache.find(p);
4052   if (it != cache.end()) {
4053     // replacing an entry with the same pointer but different attributes
4054     // (if everything matches, the caller is expected to reuse the entry)
4055     desc = &it->second;
4056     hipStreamSynchronize(hip_stream);
4057     memory_used -= desc->workspace_size;
4058   } else {
4059     cache[p] = PoolingWorkspaceDescriptor();
4060     desc = &cache[p];
4061   }
4062   desc->input_dims = input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
4063   desc->output_dims =
4064       output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
4065   desc->op = pooling_dimensions;
4066   desc->dtype = _type;
4067   desc->timestamp = timestamp;
4068   timestamp++;
4069   desc->workspace = std::move(workspace);
4070   desc->workspace_size = wsp_size;
4071   memory_used += wsp_size;
4072   trim(hip_stream);
4073 }
4074 
4075 void PoolingWorkspaceCache::trim(hipStream_t hip_stream) {
4076   if (memory_used < memory_budget && cache.size() < trim_size) return;
4077   bool must_sync = true;
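  // Repeatedly evict entries that were inserted longest ago (older than the
  // newest ~75% of insertions), synchronizing the stream before their
  // workspaces are freed, until the budget is met or the cache is small.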
4078   while (true) {
4079     int new_size = cache.size() - (cache.size() >> 2);
4080     std::vector<const void*> old_entries;
4081     for (auto& x : cache)
4082       if (x.second.timestamp + new_size < timestamp)
4083         old_entries.push_back(x.first);
4084     if (old_entries.empty()) break;
4085     if (must_sync) hipStreamSynchronize(hip_stream);
4086     must_sync = true;
4087     for (auto x : old_entries) {
4088       memory_used -= cache[x].workspace_size;
4089       cache.erase(x);
4090     }
4091     if (memory_used < memory_budget || cache.size() < 10) break;
4092   }
4093 }
4094 
4095 bool MIOpenSupport::DoPoolForward(
4096     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4097     const dnn::BatchDescriptor& input_dimensions,
4098     const DeviceMemory<float>& input_data,
4099     const dnn::BatchDescriptor& output_dimensions,
4100     DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
4101   auto miopen = miopen_->GetHandle(parent_, stream);
4102   // Alpha is the scaling factor for input.
4103   float alpha = 1.0;
4104   // Beta is the scaling factor for output.
4105   float beta = 0.0;
4106 
4107   ScopedTensorDescriptor src_desc{input_dimensions, miopenFloat};
4108   ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat};
4109   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
4110 
4111   bool do_backward = false;
4112   uint8* workspace = 0;
4113   size_t workspace_size = 0;
4114   std::unique_ptr<TemporaryDeviceMemory<uint8>> wsp_mem;
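  // Once a backward pass has enabled caching, the forward pass also produces
  // the MIOpen pooling workspace (do_backward == true) and caches it, keyed by
  // the input pointer, so a later DoPoolBackward can reuse it directly.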
4115   if (m_pooling_cache_enabled) {
4116     do_backward = true;
4117     auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
4118         pooling_desc.handle(), dest_desc.handle(), &workspace_size);
4119     if (status != miopenStatusSuccess) {
4120       LOG(ERROR)
4121           << "failed to obtain workspace size for backward pooling on stream: "
4122           << ToString(status);
4123       return false;
4124     }
4125     if (workspace_size != 0) {
4126       PoolingWorkspaceDescriptor* pdesc = 0;
4127       bool cache_hit =
4128           m_pooling_cache_allowed &&
4129           m_pooling_cache.find(input_data.opaque(), input_dimensions,
4130                                output_dimensions, pooling_dimensions,
4131                                miopenFloat, pdesc);
4132       if (cache_hit) {
4133         // reusing the same buffer
4134         workspace = reinterpret_cast<uint8*>(
4135             pdesc->workspace->mutable_device_memory()->opaque());
4136       } else {
4137         wsp_mem = stream->AllocateTemporaryArray<uint8>(workspace_size)
4138                       .ConsumeValueOrDie();
4139         workspace = reinterpret_cast<uint8*>(
4140             wsp_mem->mutable_device_memory()->opaque());
4141         m_pooling_cache.insert(input_data.opaque(), input_dimensions,
4142                                output_dimensions, pooling_dimensions,
4143                                miopenFloat, wsp_mem, workspace_size,
4144                                AsGpuStreamValue(stream));
4145       }
4146     }
4147   }
4148 
4149   auto status = wrap::miopenPoolingForward(
4150       miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4151       input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque(),
4152       do_backward, workspace, workspace_size);
4153   if (status != miopenStatusSuccess) {
4154     LOG(ERROR) << "failed to enqueue forward pooling on stream: "
4155                << ToString(status);
4156     return false;
4157   }
4158   return true;
4159 }
4160 
4161 bool MIOpenSupport::DoPoolForward(
4162     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4163     const dnn::BatchDescriptor& input_dimensions,
4164     const DeviceMemory<Eigen::half>& input_data,
4165     const dnn::BatchDescriptor& output_dimensions,
4166     DeviceMemory<Eigen::half>* output_data,
4167     ScratchAllocator* workspace_allocator) {
4168   auto miopen = miopen_->GetHandle(parent_, stream);
4169 
4170   // Alpha is the scaling factor for input.
4171   float alpha = 1.0;
4172   // Beta is the scaling factor for output.
4173   float beta = 0.0;
4174 
4175   ScopedTensorDescriptor src_desc{input_dimensions, miopenHalf};
4176   ScopedTensorDescriptor dest_desc{output_dimensions, miopenHalf};
4177   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
4178 
4179   auto status = wrap::miopenPoolingForward(
4180       miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4181       input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque(),
4182       false, nullptr, 0);
4183   if (status != miopenStatusSuccess) {
4184     LOG(ERROR) << "failed to enqueue forward pooling on stream: "
4185                << ToString(status);
4186     return false;
4187   }
4188   return true;
4189 }
4190 
4191 template <class T>
4192 bool MIOpenSupport::DoPoolBackwardImpl(
4193     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4194     const dnn::BatchDescriptor& input_dimensions,
4195     const DeviceMemory<T>& input_data,
4196     const dnn::BatchDescriptor& output_dimensions,
4197     const DeviceMemory<T>& output_data, const DeviceMemory<T>& input_diff_data,
4198     DeviceMemory<T>* output_diff_data, ScratchAllocator* workspace_allocator) {
4199   auto miopen = miopen_->GetHandle(parent_, stream);
4200   if (m_pooling_cache_allowed) m_pooling_cache_enabled = true;
4201   // Alpha is the scaling factor for input.
4202   float alpha = 1.0;
4203   // Beta is the scaling factor for output.
4204   float beta = 0.0;
4205 
4206   auto type =
4207       std::is_same<T, float>::value
4208           ? miopenFloat
4209           : (std::is_same<T, Eigen::half>::value ? miopenHalf
4210                                                  : (miopenDataType_t)-1);
4211 
4212   ScopedTensorDescriptor src_desc{input_dimensions, type};
4213   ScopedTensorDescriptor dest_desc{output_dimensions, type};
4214   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
4215 
4216   uint8* workspace_ptr = 0;
4217   DeviceMemory<uint8> workspace;
4218   PoolingWorkspaceDescriptor* pdesc = 0;
4219 
4220   size_t workspace_size_in_bytes = 0;
4221   auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
4222       pooling_desc.handle(), dest_desc.handle(), &workspace_size_in_bytes);
4223   if (status != miopenStatusSuccess) {
4224     LOG(ERROR)
4225         << "failed to obtain workspace size for backward pooling on stream: "
4226         << ToString(status);
4227     return false;
4228   }
4229 
4230   // Allocate the workspace.
4231   if (workspace_size_in_bytes > 0) {
4232     bool cache_hit = m_pooling_cache_allowed &&
4233                      m_pooling_cache.find(input_data.opaque(), input_dimensions,
4234                                           output_dimensions, pooling_dimensions,
4235                                           type, pdesc);
4236     if (cache_hit) {
4237       assert(pdesc != 0);
4238       workspace_ptr = reinterpret_cast<uint8*>(
4239           pdesc->workspace->mutable_device_memory()->opaque());
4240       VLOG(1) << "Pooling cache hit";
4241     } else {
4242       VLOG(1) << "Pooling cache miss";
4243       assert(workspace_allocator);
4244       auto allocated =
4245           workspace_allocator->AllocateBytes(workspace_size_in_bytes);
4246       if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
4247         LOG(ERROR) << "Failed to allocate backward pooling workspace";
4248         return false;
4249       }
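      // On a cache miss the forward pooling must be re-run into a scratch
      // output (dest2) purely to regenerate the workspace (pooling indices)
      // that miopenPoolingBackward requires.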
4250       DeviceMemory<uint8> dest2;  // duplicated dest from forward:
4251       int64 dest2_size = 0;
4252 
4253       // miopen requires the strides and dims to be ordered as BDYX.
4254       std::vector<int64> dims64 =
4255           output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
4256       // MIOpen does not use strides and requires a 4D tensor.
4257       // std::vector<int> dims(pooling_dimensions.ndims() + 2);
4258 
4259       dest2_size = sizeof(T);
4260       for (auto& x : dims64) dest2_size *= x;
4261 
4262       if (dest2_size > 0) {
4263         assert(workspace_allocator);
4264         auto allocated = workspace_allocator->AllocateBytes(dest2_size);
4265         if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
4266           LOG(ERROR) << "Failed to allocate backward pooling workspace";
4267           return false;
4268         }
4269       } else {
4270         LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
4271                       "backward pooling";
4272       }
4273 
4274       status = wrap::miopenPoolingForward(
4275           miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4276           input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
4277           workspace.opaque(), workspace_size_in_bytes);
4278 
4279       if (status != miopenStatusSuccess) {
4280         LOG(ERROR)
4281             << "failed to enqueue forward pooling (before backward) on stream: "
4282             << ToString(status);
4283         return false;
4284       }
4285       workspace_ptr = reinterpret_cast<uint8*>(workspace.opaque());
4286     }
4287   }
4288   status = wrap::miopenPoolingBackward(
4289       miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
4290       output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
4291       src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
4292       output_diff_data->opaque(), workspace_ptr);
4293 
4294   if (status != miopenStatusSuccess) {
4295     LOG(ERROR) << "failed to enqueue backward pooling on stream: "
4296                << ToString(status);
4297     return false;
4298   }
4299 
4300   return true;
4301 }
4302 
4303 bool MIOpenSupport::DoPoolBackward(
4304     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4305     const dnn::BatchDescriptor& input_dimensions,
4306     const DeviceMemory<double>& input_data,
4307     const dnn::BatchDescriptor& output_dimensions,
4308     const DeviceMemory<double>& output_data,
4309     const DeviceMemory<double>& input_diff_data,
4310     DeviceMemory<double>* output_diff_data,
4311     ScratchAllocator* workspace_allocator) {
4312   LOG(ERROR) << "MIOpen does not support backward pooling on the double type yet";
4313   return false;
4314 }
4315 
4316 bool MIOpenSupport::DoPoolBackward(
4317     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4318     const dnn::BatchDescriptor& input_dimensions,
4319     const DeviceMemory<float>& input_data,
4320     const dnn::BatchDescriptor& output_dimensions,
4321     const DeviceMemory<float>& output_data,
4322     const DeviceMemory<float>& input_diff_data,
4323     DeviceMemory<float>* output_diff_data,
4324     ScratchAllocator* workspace_allocator) {
4325   return DoPoolBackwardImpl(stream, pooling_dimensions, input_dimensions,
4326                             input_data, output_dimensions, output_data,
4327                             input_diff_data, output_diff_data,
4328                             workspace_allocator);
4329 }
4330 
4331 bool MIOpenSupport::DoPoolBackward(
4332     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4333     const dnn::BatchDescriptor& input_dimensions,
4334     const DeviceMemory<Eigen::half>& input_data,
4335     const dnn::BatchDescriptor& output_dimensions,
4336     const DeviceMemory<Eigen::half>& output_data,
4337     const DeviceMemory<Eigen::half>& input_diff_data,
4338     DeviceMemory<Eigen::half>* output_diff_data,
4339     ScratchAllocator* workspace_allocator) {
4340   return DoPoolBackwardImpl(stream, pooling_dimensions, input_dimensions,
4341                             input_data, output_dimensions, output_data,
4342                             input_diff_data, output_diff_data,
4343                             workspace_allocator);
4344 }
4345 
4346 bool MIOpenSupport::DoNormalizeWithDimensions(
4347     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
4348     const dnn::BatchDescriptor& dimensions,
4349     const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
4350   // Check for unsupported modes.
4351   if (normalize_descriptor.wrap_around()) {
4352     LOG(ERROR) << "MIOpen LRN does not support wrap-around mode";
4353     return false;
4354   }
4355   if (normalize_descriptor.segment_size()) {
4356     LOG(ERROR) << "MIOpen LRN does not support segmentation";
4357     return false;
4358   }
4359 
4360   auto miopen = miopen_->GetHandle(parent_, stream);
4361 
4362   // Launch the normalization.
4363   ScopedTensorDescriptor dims{dimensions, miopenFloat};
4364   ScopedNormalizeDescriptor normalize{normalize_descriptor};
4365 
4366   // Alpha is the scaling factor for input.
4367   float alpha = 1.0f;
4368   // Beta is the scaling factor for output.
4369   float beta = 0.0f;
4370 
4371   auto status = wrap::miopenLRNForward(
4372       miopen.handle(), normalize.handle(), &alpha, dims.handle(),
4373       input_data.opaque(), &beta, dims.handle(), output_data->opaque(), false,
4374       nullptr);
4375   if (status != miopenStatusSuccess) {
4376     LOG(ERROR) << "failed to run miopenLRNForward";
4377     return false;
4378   }
4379   return true;
4380 }
4381 
4382 bool MIOpenSupport::DoNormalizeBackwardWithDimensions(
4383     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
4384     const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
4385     const DeviceMemory<float>& normalized_data,
4386     const DeviceMemory<float>& normalized_variable_gradient,
4387     DeviceMemory<float>* raw_variable_gradient,
4388     ScratchAllocator* workspace_allocator) {
4389   // Check for unsupported modes.
4390   if (normalize_descriptor.wrap_around()) {
4391     LOG(ERROR) << "MIOpen LRN does not support wrap-around mode";
4392     return false;
4393   }
4394   if (normalize_descriptor.segment_size()) {
4395     LOG(ERROR) << "MIOpen LRN does not support segmentation";
4396     return false;
4397   }
4398 
4399   auto miopen = miopen_->GetHandle(parent_, stream);
4400 
4401   ScopedTensorDescriptor dims{dimensions, miopenFloat};
4402   ScopedNormalizeDescriptor normalize{normalize_descriptor};
4403 
4404   float alpha = 1.0f;
4405   float beta = 0.0f;
4406 
4407   DeviceMemory<uint8> workspace;
4408   size_t workspace_size_in_bytes = 0;
4409   auto status =
4410       wrap::miopenLRNGetWorkSpaceSize(dims.handle(), &workspace_size_in_bytes);
4411 
4412   if (status != miopenStatusSuccess) {
4413     LOG(ERROR) << "failed to obtain workspace size for miopenLRNBackward";
4414     return false;
4415   }
4416 
4417   // Allocate the workspace.
4418   if (workspace_size_in_bytes > 0) {
4419     assert(workspace_allocator);
4420     auto allocated =
4421         workspace_allocator->AllocateBytes(workspace_size_in_bytes);
4422     if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
4423       LOG(ERROR) << "Failed to allocate backward LRN workspace";
4424       return false;
4425     }
4426   }
4427 
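  // miopenLRNBackward needs the workspace populated by a forward LRN pass, so
  // the forward pass is re-run below into a scratch tensor before the backward
  // call is issued.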
4428   DeviceMemory<uint8> dest2;  // Scratch output for the re-run forward LRN pass.
4429   int dest2_size = 0;
4430 
4431   // miopen requires the strides and dims to be ordered as BDYX.
4432   std::vector<int64> dims64 =
4433       dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
4434 
4435   // MIOpen does not use strides and requires a 4D tensor.
4436   std::vector<int> dimsint(4);
4437 
4438   std::transform(dims64.cbegin(), dims64.cend(), dimsint.begin(),
4439                  &CheckedNarrowing<int64, int>);
4440 
4441   dest2_size =
4442       dimsint[0] * dimsint[1] * dimsint[2] * dimsint[3] * sizeof(float);
4443 
4444   if (dest2_size > 0) {
4445     assert(workspace_allocator);
4446     auto allocated = workspace_allocator->AllocateBytes(dest2_size);
4447     if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
4448       LOG(ERROR)
4449           << "Failed to allocate tensor to chain forward and backward LRN";
4450       return false;
4451     }
4452   } else {
4453     LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
4454                   "backward LRN";
4455   }
4456 
4457   status = wrap::miopenLRNForward(miopen.handle(), normalize.handle(), &alpha,
4458                                   dims.handle(), raw_data.opaque(), &beta,
4459                                   dims.handle(), dest2.opaque(), true,
4460                                   workspace.opaque());
4461 
4462   if (status != miopenStatusSuccess) {
4463     LOG(ERROR) << "failed to run miopenLRNForward";
4464     return false;
4465   }
4466 
4467   status = wrap::miopenLRNBackward(
4468       miopen.handle(), normalize.handle(), &alpha, dims.handle(),
4469       normalized_data.opaque(), dims.handle(),
4470       normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
4471       &beta, dims.handle(), raw_variable_gradient->opaque(),
4472       workspace.opaque());
4473 
4474   if (status != miopenStatusSuccess) {
4475     LOG(ERROR) << "failed to run miopenLRNBackward";
4476     return false;
4477   }
4478   return true;
4479 }
4480 
4481 bool MIOpenSupport::DoDepthConcatenate(
4482     Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
4483     port::ArraySlice<const DeviceMemory<float>*> input_data,
4484     DeviceMemory<float>* output_data) {
4485   CHECK_EQ(input_dimensions.size(), input_data.size());
4486 
4487   for (const auto& dimensions : input_dimensions) {
4488     if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4489       LOG(ERROR) << "MIOpenSupport::DoDepthConcatenate currently only "
4490                     "supports the kBatchDepthYX layout.";
4491       return false;
4492     }
4493   }
4494 
4495   if (input_dimensions.empty()) {
4496     return true;  // Nothing to do.
4497   }
4498 
4499   dnn::BatchDescriptor output_dimensions =
4500       dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
4501 
4502   const int64 area = output_dimensions.width() * output_dimensions.height();
4503   const auto index = [area](int64 batch, int64 depth, int64 yx,
4504                             int64 max_depth) {
4505     return (batch * max_depth + depth) * area + yx;
4506   };
4507 
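  // The concatenation is performed on the host: each input is copied device to
  // host, scattered into the output buffer at its depth offset, and the result
  // is copied back to the device once all inputs have been processed.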
4508   std::vector<float> output_host(output_dimensions.ElementCount());
4509   std::vector<float> tmp;
4510   int64 depth_sum = 0;
4511   for (size_t i = 0; i < input_data.size(); ++i) {
4512     const auto& dimensions = input_dimensions[i];
4513     tmp.resize(dimensions.ElementCount());
4514     stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
4515     port::Status block_status = stream->BlockHostUntilDone();
4516     if (!block_status.ok()) {
4517       LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
4518       return false;
4519     }
4520 
4521     for (int64 batch = 0; batch < output_dimensions.count(); ++batch) {
4522       for (int64 yx = 0; yx < area; ++yx) {
4523         for (int64 depth = 0; depth < dimensions.feature_map_count(); ++depth) {
4524           LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' '
4525                     << yx << ' ' << depth;
4526           output_host[index(batch, depth + depth_sum, yx,
4527                             output_dimensions.feature_map_count())] =
4528               tmp[index(batch, depth, yx, dimensions.feature_map_count())];
4529         }
4530       }
4531     }
4532     depth_sum += dimensions.feature_map_count();
4533   }
4534   stream->ThenMemcpyH2D<float>(output_host, output_data);
4535   return true;
4536 }
4537 
4538 bool MIOpenSupport::DoElementwiseOperate(
4539     Stream* stream, dnn::ElementwiseOperation operation,
4540     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
4541     port::ArraySlice<const DeviceMemory<float>*> input_data,
4542     const dnn::BatchDescriptor& output_dimensions,
4543     DeviceMemory<float>* output_data) {
4544   LOG(FATAL) << "not yet implemented";  // TODO(leary)
4545   return false;
4546 }
4547 
4548 bool MIOpenSupport::DoXYPad(Stream* stream,
4549                             const dnn::BatchDescriptor& dimensions,
4550                             const DeviceMemory<float>& input_data,
4551                             int64 left_pad, int64 right_pad, int64 top_pad,
4552                             int64 bottom_pad,
4553                             DeviceMemory<float>* output_data) {
4554   LOG(FATAL) << "not yet implemented";  // TODO(leary)
4555   return false;
4556 }
4557 
4558 bool MIOpenSupport::DoXYSlice(Stream* stream,
4559                               const dnn::BatchDescriptor& dimensions,
4560                               const DeviceMemory<float>& input_data,
4561                               int64 left_trim, int64 right_trim, int64 top_trim,
4562                               int64 bottom_trim,
4563                               DeviceMemory<float>* output_data) {
4564   LOG(FATAL) << "not yet implemented";  // TODO(leary)
4565   return false;
4566 }
4567 
4568 bool MIOpenSupport::DoMemcpyD2HQuantized(
4569     Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
4570     dnn::QuantizedActivationMode mode, void* host_dst, int64 size) {
4571   LOG(ERROR) << "quantized memcpy not supported by MIOpen";
4572   return false;
4573 }
4574 
4575 bool MIOpenSupport::DoMemcpyH2DQuantized(
4576     Stream* stream, const void* host_src, int64 size,
4577     dnn::QuantizedActivationMode mode,
4578     DeviceMemory<float>* gpu_unquantized_dst) {
4579   LOG(ERROR) << "quantized memcpy not supported by MIOpen";
4580   return false;
4581 }
4582 
4583 bool MIOpenSupport::DeriveOutputBatchDescriptor(
4584     const BatchDescriptor& batch_descriptor,
4585     const FilterDescriptor& filter_descriptor,
4586     const dnn::ConvolutionDescriptor& convolution_descriptor,
4587     dnn::BatchDescriptor* output_batch_descriptor) {
4588   ScopedTensorDescriptor input_nd{batch_descriptor, miopenFloat};
4589   ScopedFilterDescriptor filter{filter_descriptor, batch_descriptor,
4590                                 miopenFloat};
4591   ScopedConvolutionDescriptor conv{convolution_descriptor, miopenFloat};
4592 
4593   int dn = batch_descriptor.ndims() + 2;
4594   std::vector<int> dims(dn);  // in BDYX
4595   auto status = wrap::miopenGetConvolutionNdForwardOutputDim(
4596       conv.handle(), input_nd.handle(), filter.handle(), &dn, dims.data());
4597   if (status != miopenStatusSuccess) {
4598     LOG(ERROR) << "could not get output tensor for convolution: "
4599                << ToString(status);
4600     return false;
4601   }
4602 
4603   output_batch_descriptor->set_count(dims[0])
4604       .set_feature_map_count(dims[1])
4605       .set_layout(batch_descriptor.layout());
4606 
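  // dims is ordered batch, depth, then spatial dims from outermost to
  // innermost, while dnn::DimIndex counts spatial dims from the innermost (X)
  // outward, hence the reverse iteration below.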
4607   for (int i = 0; i < batch_descriptor.ndims(); i++) {
4608     output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
4609                                              dims.rbegin()[i]);
4610   }
4611 
4612   return true;
4613 }
4614 
4615 template <typename T>
4616 bool MIOpenSupport::DoFusedConvolutionBiasActivationImpl(
4617     Stream* stream,
4618     int miopen_type,  // Actually miopenDataType_t.
4619     const dnn::BatchDescriptor& conv_input_descriptor,
4620     const DeviceMemory<T>& conv_input_data,
4621     const dnn::FilterDescriptor& filter_descriptor,
4622     const DeviceMemory<T>& filter_data,
4623     const dnn::ConvolutionDescriptor& convolution_descriptor,
4624     const dnn::BatchDescriptor& bias_descriptor,
4625     const DeviceMemory<T>& bias_data, dnn::ActivationMode activation_mode,
4626     const dnn::BatchDescriptor& output_descriptor, DeviceMemory<T>* output_data,
4627     dnn::ProfileResult* output_profile_result) {
4628   auto miopen = miopen_->GetHandle(parent_, stream);
4629 
4630   ScopedTensorDescriptor conv_input_nd{
4631       conv_input_descriptor, static_cast<miopenDataType_t>(miopen_type)};
4632 
4633   ScopedTensorDescriptor bias_nd{bias_descriptor,
4634                                  static_cast<miopenDataType_t>(miopen_type)};
4635 
4636   ScopedTensorDescriptor output_nd{output_descriptor,
4637                                    static_cast<miopenDataType_t>(miopen_type)};
4638 
4639   ScopedConvolutionDescriptor conv{convolution_descriptor,
4640                                    static_cast<miopenDataType_t>(miopen_type)};
4641 
4642   ScopedFilterDescriptor filter{filter_descriptor, conv_input_descriptor,
4643                                 static_cast<miopenDataType_t>(miopen_type)};
4644 
4645   ScopedActivationDescriptor activation_desc{activation_mode};
4646 
4647   ScopedFusionPlanConvolutionBiasActivation fusion_plan{
4648       miopen.handle(), conv_input_nd.handle(), filter.handle(),
4649       conv.handle(),   bias_nd.handle(),       activation_desc};
4650 
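  // The fusion plan is compiled when the scoped wrapper above is constructed.
  // If compilation fails (e.g. the fused kernel is not supported for this
  // configuration), this function returns false and the caller is expected to
  // fall back to the unfused path.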
4651   bool retval = false;
4652 
4653   if (fusion_plan.CompilationSucceeded()) {
4654     const bool is_profiling = output_profile_result != nullptr;
4655 
4656     std::unique_ptr<GpuTimer> timer;
4657     if (is_profiling) {
4658       timer.reset(new GpuTimer(parent_));
4659       timer->Init();
4660       timer->Start(AsGpuStream(stream));
4661     }
4662 
4663     miopenStatus_t status = miopenStatusSuccess;
4664 
4665     if (status == miopenStatusSuccess) {
4666       status = fusion_plan.SetConvolutionArgs(filter_data.opaque());
4667     }
4668 
4669     if (status == miopenStatusSuccess) {
4670       status = fusion_plan.SetBiasArgs(bias_data.opaque());
4671     }
4672 
4673     if (status == miopenStatusSuccess) {
4674       status = fusion_plan.SetActivationForwardArgs(activation_desc);
4675     }
4676 
4677     if (status == miopenStatusSuccess) {
4678       status =
4679           fusion_plan.Execute(conv_input_nd.handle(), conv_input_data.opaque(),
4680                               output_nd.handle(), output_data->opaque());
4681     }
4682 
4683     if (is_profiling) {
4684       timer->Stop(AsGpuStream(stream));
4685       if (status == miopenStatusSuccess) {
4686         output_profile_result->set_elapsed_time_in_ms(
4687             timer->GetElapsedMilliseconds());
4688       }
4689       timer->Destroy();
4690     }
4691 
4692     if (status != miopenStatusSuccess) {
4693       // Silently return when we are profiling.
4694       if (!is_profiling) {
4695         LOG(FATAL) << "failed to enqueue fused-convolution on stream: "
4696                    << ToString(status);
4697       }
4698     }
4699 
4700     retval = true;
4701   }
4702 
4703   return retval;
4704 }
4705 
4706 bool MIOpenSupport::DoFusedConvolutionBiasActivation(
4707     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
4708     const DeviceMemory<float>& conv_input_data,
4709     const dnn::FilterDescriptor& filter_descriptor,
4710     const DeviceMemory<float>& filter_data,
4711     const dnn::ConvolutionDescriptor& convolution_descriptor,
4712     const dnn::BatchDescriptor& bias_descriptor,
4713     const DeviceMemory<float>& bias_data, dnn::ActivationMode activation_mode,
4714     const dnn::BatchDescriptor& output_descriptor,
4715     DeviceMemory<float>* output_data,
4716     dnn::ProfileResult* output_profile_result) {
4717   return DoFusedConvolutionBiasActivationImpl<float>(
4718       stream, miopenFloat, conv_input_descriptor, conv_input_data,
4719       filter_descriptor, filter_data, convolution_descriptor, bias_descriptor,
4720       bias_data, activation_mode, output_descriptor, output_data,
4721       output_profile_result);
4722 }
4723 
4724 template <typename T, typename U>
4725 bool MIOpenSupport::DoFusedBatchNormActivationInferenceImpl(
4726     Stream* stream,
4727     int miopen_type,  // Actually miopenDataType_t.
4728     const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
4729     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4730     const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
4731     const DeviceMemory<U>& mean_data, const DeviceMemory<U>& variance_data,
4732     double epsilon, dnn::ActivationMode activation_mode,
4733     DeviceMemory<T>* y_data, dnn::ProfileResult* output_profile_result) {
4734   auto miopen = miopen_->GetHandle(parent_, stream);
4735 
4736   ScopedTensorDescriptor x_nd{x_descriptor,
4737                               static_cast<miopenDataType_t>(miopen_type)};
4738 
4739   ScopedTensorDescriptor scale_offset_mean_variance_nd{
4740       scale_offset_mean_variance_descriptor,
4741       static_cast<miopenDataType_t>(miopen_type)};
4742 
4743   ScopedActivationDescriptor activation_desc{activation_mode};
4744 
4745   ScopedFusionPlanBatchNormActivationInference fusion_plan{
4746       miopen.handle(), x_nd.handle(), scale_offset_mean_variance_nd.handle(),
4747       activation_desc};
4748 
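  // Inference-mode batchnorm: the fusion plan consumes the supplied mean and
  // variance (typically the moving averages) rather than computing batch
  // statistics.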
4749   bool retval = false;
4750 
4751   if (fusion_plan.CompilationSucceeded()) {
4752     const bool is_profiling = output_profile_result != nullptr;
4753 
4754     std::unique_ptr<GpuTimer> timer;
4755     if (is_profiling) {
4756       timer.reset(new GpuTimer(parent_));
4757       timer->Init();
4758       timer->Start(AsGpuStream(stream));
4759     }
4760 
4761     miopenStatus_t status = miopenStatusSuccess;
4762 
4763     if (status == miopenStatusSuccess) {
4764       status = fusion_plan.SetBatchNormInferenceArgs(
4765           scale_data.opaque(), offset_data.opaque(), mean_data.opaque(),
4766           variance_data.opaque(), epsilon);
4767     }
4768 
4769     if (status == miopenStatusSuccess) {
4770       status = fusion_plan.SetActivationForwardArgs(activation_desc);
4771     }
4772 
4773     if (status == miopenStatusSuccess) {
4774       status = fusion_plan.Execute(x_nd.handle(), x_data.opaque(),
4775                                    x_nd.handle(), y_data->opaque());
4776     }
4777 
4778     if (is_profiling) {
4779       timer->Stop(AsGpuStream(stream));
4780       if (status == miopenStatusSuccess) {
4781         output_profile_result->set_elapsed_time_in_ms(
4782             timer->GetElapsedMilliseconds());
4783       }
4784       timer->Destroy();
4785     }
4786 
4787     if (status != miopenStatusSuccess) {
4788       // Silently return when we are profiling.
4789       if (!is_profiling) {
4790         LOG(FATAL) << "failed to enqueue fused batchnorm-activation inference on stream: "
4791                    << ToString(status);
4792       }
4793     }
4794 
4795     retval = true;
4796   }
4797 
4798   return retval;
4799 }
4800 
4801 bool MIOpenSupport::DoFusedBatchNormActivationInference(
4802     Stream* stream, const dnn::BatchDescriptor& x_descriptor,
4803     const DeviceMemory<float>& x_data,
4804     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4805     const DeviceMemory<float>& scale_data,
4806     const DeviceMemory<float>& offset_data,
4807     const DeviceMemory<float>& mean_data,
4808     const DeviceMemory<float>& variance_data, double epsilon,
4809     dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
4810     dnn::ProfileResult* output_profile_result) {
4811   return DoFusedBatchNormActivationInferenceImpl<float, float>(
4812       stream, miopenFloat, x_descriptor, x_data,
4813       scale_offset_mean_variance_descriptor, scale_data, offset_data, mean_data,
4814       variance_data, epsilon, activation_mode, y_data, output_profile_result);
4815 }
4816 
4817 bool MIOpenSupport::DoFusedBatchNormActivationInference(
4818     Stream* stream, const dnn::BatchDescriptor& x_descriptor,
4819     const DeviceMemory<Eigen::half>& x_data,
4820     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4821     const DeviceMemory<float>& scale_data,
4822     const DeviceMemory<float>& offset_data,
4823     const DeviceMemory<float>& mean_data,
4824     const DeviceMemory<float>& variance_data, double epsilon,
4825     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
4826     dnn::ProfileResult* output_profile_result) {
4827   return DoFusedBatchNormActivationInferenceImpl<Eigen::half, float>(
4828       stream, miopenHalf, x_descriptor, x_data,
4829       scale_offset_mean_variance_descriptor, scale_data, offset_data, mean_data,
4830       variance_data, epsilon, activation_mode, y_data, output_profile_result);
4831 }
4832 
4833 template <typename T, typename U>
4834 bool MIOpenSupport::DoFusedBatchNormActivationForwardImpl(
4835     Stream* stream,
4836     int miopen_type,  // Actually miopenDataType_t.
4837     const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
4838     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4839     const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
4840     double epsilon, dnn::ActivationMode activation_mode,
4841     DeviceMemory<T>* y_data, DeviceMemory<U>* batch_mean_data,
4842     DeviceMemory<U>* batch_var_data, DeviceMemory<U>* saved_mean_data,
4843     DeviceMemory<U>* saved_var_data,
4844     dnn::ProfileResult* output_profile_result) {
4845   auto miopen = miopen_->GetHandle(parent_, stream);
4846 
4847   ScopedTensorDescriptor x_nd{x_descriptor,
4848                               static_cast<miopenDataType_t>(miopen_type)};
4849 
4850   ScopedTensorDescriptor scale_offset_mean_variance_nd{
4851       scale_offset_mean_variance_descriptor,
4852       static_cast<miopenDataType_t>(miopen_type)};
4853 
4854   ScopedActivationDescriptor activation_desc{activation_mode};
4855 
4856   ScopedFusionPlanBatchNormActivationForward fusion_plan{
4857       miopen.handle(), x_nd.handle(), scale_offset_mean_variance_nd.handle(),
4858       activation_desc};
4859 
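  // Training-mode batchnorm: in addition to y, the fusion plan writes the
  // batch mean/variance and the saved mean/variance that the corresponding
  // backward pass consumes.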
4860   bool retval = false;
4861 
4862   if (fusion_plan.CompilationSucceeded()) {
4863     const bool is_profiling = output_profile_result != nullptr;
4864 
4865     std::unique_ptr<GpuTimer> timer;
4866     if (is_profiling) {
4867       timer.reset(new GpuTimer(parent_));
4868       timer->Init();
4869       timer->Start(AsGpuStream(stream));
4870     }
4871 
4872     miopenStatus_t status = miopenStatusSuccess;
4873 
4874     if (status == miopenStatusSuccess) {
4875       status = fusion_plan.SetBatchNormForwardArgs(
4876           scale_data.opaque(), offset_data.opaque(), batch_mean_data->opaque(),
4877           batch_var_data->opaque(), saved_mean_data->opaque(),
4878           saved_var_data->opaque(), epsilon);
4879     }
4880 
4881     if (status == miopenStatusSuccess) {
4882       status = fusion_plan.SetActivationForwardArgs(activation_desc);
4883     }
4884 
4885     if (status == miopenStatusSuccess) {
4886       status = fusion_plan.Execute(x_nd.handle(), x_data.opaque(),
4887                                    x_nd.handle(), y_data->opaque());
4888     }
4889 
4890     if (is_profiling) {
4891       timer->Stop(AsGpuStream(stream));
4892       if (status == miopenStatusSuccess) {
4893         output_profile_result->set_elapsed_time_in_ms(
4894             timer->GetElapsedMilliseconds());
4895       }
4896       timer->Destroy();
4897     }
4898 
4899     if (status != miopenStatusSuccess) {
4900       // Silently return when we are profiling.
4901       if (!is_profiling) {
4902         LOG(FATAL) << "failed to enqueue fused batchnorm-activation forward on stream: "
4903                    << ToString(status);
4904       }
4905     }
4906 
4907     retval = true;
4908   }
4909 
4910   return retval;
4911 }
4912 
4913 bool MIOpenSupport::DoFusedBatchNormActivationForward(
4914     Stream* stream, const dnn::BatchDescriptor& x_descriptor,
4915     const DeviceMemory<float>& x_data,
4916     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4917     const DeviceMemory<float>& scale_data,
4918     const DeviceMemory<float>& offset_data, double epsilon,
4919     dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
4920     DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
4921     DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
4922     dnn::ProfileResult* output_profile_result) {
4923   return DoFusedBatchNormActivationForwardImpl<float, float>(
4924       stream, miopenFloat, x_descriptor, x_data,
4925       scale_offset_mean_variance_descriptor, scale_data, offset_data, epsilon,
4926       activation_mode, y_data, batch_mean_data, batch_var_data, saved_mean_data,
4927       saved_var_data, output_profile_result);
4928 }
4929 
4930 bool MIOpenSupport::DoFusedBatchNormActivationForward(
4931     Stream* stream, const dnn::BatchDescriptor& x_descriptor,
4932     const DeviceMemory<Eigen::half>& x_data,
4933     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4934     const DeviceMemory<float>& scale_data,
4935     const DeviceMemory<float>& offset_data, double epsilon,
4936     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
4937     DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
4938     DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
4939     dnn::ProfileResult* output_profile_result) {
4940   return DoFusedBatchNormActivationForwardImpl<Eigen::half, float>(
4941       stream, miopenHalf, x_descriptor, x_data,
4942       scale_offset_mean_variance_descriptor, scale_data, offset_data, epsilon,
4943       activation_mode, y_data, batch_mean_data, batch_var_data, saved_mean_data,
4944       saved_var_data, output_profile_result);
4945 }
4946 
4947 template <typename T, typename U>
4948 bool MIOpenSupport::DoFusedBatchNormActivationBackwardImpl(
4949     Stream* stream,
4950     int miopen_type,  // Actually miopenDataType_t.
4951     const dnn::BatchDescriptor& y_act_backprop_descriptor,
4952     const DeviceMemory<T>& y_act_backprop_data,
4953     const DeviceMemory<T>& y_act_data, dnn::ActivationMode activation_mode,
4954     const DeviceMemory<T>& x_bn_data,
4955     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4956     const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
4957     const DeviceMemory<U>& saved_mean_data,
4958     const DeviceMemory<U>& saved_var_data, DeviceMemory<T>* x_bn_backprop_data,
4959     DeviceMemory<U>* scale_backprop_data, DeviceMemory<U>* offset_backprop_data,
4960     dnn::ProfileResult* output_profile_result) {
4961   auto miopen = miopen_->GetHandle(parent_, stream);
4962 
4963   ScopedTensorDescriptor y_act_backprop_nd{
4964       y_act_backprop_descriptor, static_cast<miopenDataType_t>(miopen_type)};
4965 
4966   ScopedTensorDescriptor scale_offset_mean_variance_nd{
4967       scale_offset_mean_variance_descriptor,
4968       static_cast<miopenDataType_t>(miopen_type)};
4969 
4970   ScopedActivationDescriptor activation_desc{activation_mode};
4971 
4972   ScopedFusionPlanBatchNormActivationBackward fusion_plan{
4973       miopen.handle(), y_act_backprop_nd.handle(),
4974       scale_offset_mean_variance_nd.handle(), activation_desc};
4975 
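  // The backward plan consumes the forward-pass activations, the batchnorm
  // input, and the saved statistics recorded during training, and produces
  // gradients for x, scale, and offset.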
4976   bool retval = false;
4977 
4978   if (fusion_plan.CompilationSucceeded()) {
4979     const bool is_profiling = output_profile_result != nullptr;
4980 
4981     std::unique_ptr<GpuTimer> timer;
4982     if (is_profiling) {
4983       timer.reset(new GpuTimer(parent_));
4984       timer->Init();
4985       timer->Start(AsGpuStream(stream));
4986     }
4987 
4988     miopenStatus_t status = miopenStatusSuccess;
4989 
4990     if (status == miopenStatusSuccess) {
4991       status = fusion_plan.SetBatchNormBackwardArgs(
4992           x_bn_data.opaque(), scale_data.opaque(), offset_data.opaque(),
4993           saved_mean_data.opaque(), saved_var_data.opaque(),
4994           scale_backprop_data->opaque(), offset_backprop_data->opaque());
4995     }
4996 
4997     if (status == miopenStatusSuccess) {
4998       status = fusion_plan.SetActivationBackwardArgs(activation_desc,
4999                                                      y_act_data.opaque());
5000     }
5001 
5002     if (status == miopenStatusSuccess) {
5003       status = fusion_plan.Execute(
5004           y_act_backprop_nd.handle(), y_act_backprop_data.opaque(),
5005           y_act_backprop_nd.handle(), x_bn_backprop_data->opaque());
5006     }
5007 
5008     if (is_profiling) {
5009       timer->Stop(AsGpuStream(stream));
5010       if (status == miopenStatusSuccess) {
5011         output_profile_result->set_elapsed_time_in_ms(
5012             timer->GetElapsedMilliseconds());
5013       }
5014       timer->Destroy();
5015     }
5016 
5017     if (status != miopenStatusSuccess) {
5018       // Silently return when we are profiling.
5019       if (!is_profiling) {
5020         LOG(FATAL) << "failed to enqueue fused batchnorm-activation backward on stream: "
5021                    << ToString(status);
5022       }
5023     }
5024 
5025     retval = true;
5026   }
5027 
5028   return retval;
5029 }
5030 
5031 bool MIOpenSupport::DoFusedBatchNormActivationBackward(
5032     Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
5033     const DeviceMemory<float>& y_act_backprop_data,
5034     const DeviceMemory<float>& y_act_data, dnn::ActivationMode activation_mode,
5035     const DeviceMemory<float>& x_bn_data,
5036     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
5037     const DeviceMemory<float>& scale_data,
5038     const DeviceMemory<float>& offset_data,
5039     const DeviceMemory<float>& saved_mean_data,
5040     const DeviceMemory<float>& saved_var_data,
5041     DeviceMemory<float>* x_bn_backprop_data,
5042     DeviceMemory<float>* scale_backprop_data,
5043     DeviceMemory<float>* offset_backprop_data,
5044     dnn::ProfileResult* output_profile_result) {
5045   return DoFusedBatchNormActivationBackwardImpl<float, float>(
5046       stream, miopenFloat, y_act_backprop_descriptor, y_act_backprop_data,
5047       y_act_data, activation_mode, x_bn_data,
5048       scale_offset_mean_variance_descriptor, scale_data, offset_data,
5049       saved_mean_data, saved_var_data, x_bn_backprop_data, scale_backprop_data,
5050       offset_backprop_data, output_profile_result);
5051 }
5052 
5053 bool MIOpenSupport::DoFusedBatchNormActivationBackward(
5054     Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
5055     const DeviceMemory<Eigen::half>& y_act_backprop_data,
5056     const DeviceMemory<Eigen::half>& y_act_data,
5057     dnn::ActivationMode activation_mode,
5058     const DeviceMemory<Eigen::half>& x_bn_data,
5059     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
5060     const DeviceMemory<float>& scale_data,
5061     const DeviceMemory<float>& offset_data,
5062     const DeviceMemory<float>& saved_mean_data,
5063     const DeviceMemory<float>& saved_var_data,
5064     DeviceMemory<Eigen::half>* x_bn_backprop_data,
5065     DeviceMemory<float>* scale_backprop_data,
5066     DeviceMemory<float>* offset_backprop_data,
5067     dnn::ProfileResult* output_profile_result) {
5068   return DoFusedBatchNormActivationBackwardImpl<Eigen::half, float>(
5069       stream, miopenHalf, y_act_backprop_descriptor, y_act_backprop_data,
5070       y_act_data, activation_mode, x_bn_data,
5071       scale_offset_mean_variance_descriptor, scale_data, offset_data,
5072       saved_mean_data, saved_var_data, x_bn_backprop_data, scale_backprop_data,
5073       offset_backprop_data, output_profile_result);
5074 }
5075 
5076 }  // namespace gpu
5077 
5078 void initialize_miopen() {
5079   auto miopenAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
5080       rocm::kROCmPlatformId, PluginKind::kDnn, gpu::kMIOpenPlugin);
5081 
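  // Register the MIOpen factory and make it the default DNN plugin for the
  // ROCm platform; if a factory is already registered, this is a no-op.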
5082   if (!miopenAlreadyRegistered) {
5083     port::Status status =
5084         PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
5085             rocm::kROCmPlatformId, gpu::kMIOpenPlugin, "MIOpen",
5086             [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
5087               gpu::GpuExecutor* rocm_executor =
5088                   dynamic_cast<gpu::GpuExecutor*>(parent);
5089               if (rocm_executor == nullptr) {
5090                 LOG(ERROR)
5091                     << "Attempting to initialize an instance of the MIOpen "
5092                     << "support library with a non-ROCM StreamExecutor";
5093                 return nullptr;
5094               }
5095 
5096               gpu::MIOpenSupport* dnn = new gpu::MIOpenSupport(rocm_executor);
5097               if (!dnn->Init().ok()) {
5098                 // Note: Init() will log a more specific error.
5099                 delete dnn;
5100                 return nullptr;
5101               }
5102               return dnn;
5103             });
5104 
5105     if (!status.ok()) {
5106       LOG(ERROR) << "Unable to register MIOpen factory: "
5107                  << status.error_message();
5108     }
5109 
5110     PluginRegistry::Instance()->SetDefaultFactory(
5111         rocm::kROCmPlatformId, PluginKind::kDnn, gpu::kMIOpenPlugin);
5112   }
5113 }
5114 
5115 }  // namespace stream_executor
5116 
5117 REGISTER_MODULE_INITIALIZER(register_miopen,
5118                             { stream_executor::initialize_miopen(); });
5119