1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/rocm/rocm_dnn.h"
17 
18 #include <functional>
19 #include <memory>
20 
21 #include "absl/strings/str_cat.h"
22 #include "third_party/eigen3/Eigen/Core"
23 #include "rocm/include/miopen/miopen.h"
24 #include "tensorflow/core/lib/hash/hash.h"
25 #include "tensorflow/stream_executor/dnn.h"
26 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
27 #include "tensorflow/stream_executor/gpu/gpu_driver.h"
28 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
29 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
30 #include "tensorflow/stream_executor/gpu/gpu_timer.h"
31 #include "tensorflow/stream_executor/lib/env.h"
32 #include "tensorflow/stream_executor/lib/error.h"
33 #include "tensorflow/stream_executor/lib/initialize.h"
34 #include "tensorflow/stream_executor/lib/threadpool.h"
35 #include "tensorflow/stream_executor/platform/dso_loader.h"
36 #include "tensorflow/stream_executor/platform/logging.h"
37 #include "tensorflow/stream_executor/plugin_registry.h"
38 #include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
39 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
40 #include "tensorflow/stream_executor/scratch_allocator.h"
41 #include "tensorflow/stream_executor/stream.h"
42 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
43 
namespace {

// Narrows `wide` to NarrowT and CHECK-fails unless the narrowed value still
// compares equal to the original, i.e. the conversion lost nothing.
template <typename WideT, typename NarrowT>
NarrowT CheckedNarrowing(const WideT& wide) {
  NarrowT result = wide;
  CHECK_EQ(result, wide)
      << "checked narrowing failed; values not equal post-conversion";
  return result;
}

}  // namespace
57 
58 namespace stream_executor {
59 
60 using dnn::BatchDescriptor;
61 using dnn::ConvolutionDescriptor;
62 using dnn::FilterDescriptor;
63 using dnn::NormalizeDescriptor;
64 using dnn::PoolingDescriptor;
65 
66 namespace gpu {
67 
// Defines the plugin ID `kMIOpenPlugin` for this MIOpen DNN implementation.
// NOTE(review): presumably the ID the MIOpen DNN factory is registered under
// with the PluginRegistry (plugin_registry.h) — confirm against the
// registration code.
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kMIOpenPlugin);
69 
ToString(miopenStatus_t status)70 string ToString(miopenStatus_t status) {
71   switch (status) {
72     case miopenStatusSuccess:
73       return "miopenStatusSuccess";
74     case miopenStatusNotInitialized:
75       return "miopenStatusNotInitialized";
76     case miopenStatusAllocFailed:
77       return "miopenStatusAllocFailed";
78     case miopenStatusBadParm:
79       return "miopenStatusBadParm";
80     case miopenStatusInternalError:
81       return "miopenStatusInternalError";
82     case miopenStatusInvalidValue:
83       return "miopenStatusInvalidValue";
84     case miopenStatusNotImplemented:
85       return "miopenStatusNotImplemented";
86     case miopenStatusUnknownError:
87       return "miopenStatusUnknownError";
88     default:
89       return absl::StrCat("<unknown miopen status: ", static_cast<int>(status),
90                           ">");
91   }
92 }
93 
94 // RAII wrapper for all calls to MIOpen with a MIOpen handle argument.
95 //
96 // See MIOpenAccess::GetHandle() for details.
97 class MIOpenHandle {
98  public:
99   // Takes ownership of the executor context and the lock to access MIOpen
100   // using handle.
MIOpenHandle(gpu::ScopedActivateExecutorContext context,mutex_lock lock,miopenHandle_t handle)101   MIOpenHandle(gpu::ScopedActivateExecutorContext context, mutex_lock lock,
102                miopenHandle_t handle)
103       : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
104 
105   // Returns MIOpen handle. To be passed directly to MIOpen APIs, don't keep
106   // a copy.
handle() const107   miopenHandle_t handle() const { return handle_; }
108 
109  private:
110   gpu::ScopedActivateExecutorContext context_;
111   mutex_lock lock_;
112   miopenHandle_t handle_;  // Not owned.
113 };
114 
// Call wrappers for the MIOpen C API. Every routine named in
// MIOPEN_DNN_ROUTINE_EACH below becomes a global functor `wrap::<name>` whose
// call operator forwards its arguments to the MIOpen function of the same name
// and returns its miopenStatus_t.
namespace wrap {

#ifdef PLATFORM_GOOGLE

// Direct-call variant: the functor calls the MIOpen function `::<name>`
// directly, with no dynamic symbol lookup.
#define STREAM_EXECUTOR_MIOPEN_WRAP(__name)      \
  struct WrapperShim__##__name {                 \
    template <typename... Args>                  \
    miopenStatus_t operator()(Args... args) {    \
      miopenStatus_t retval = ::__name(args...); \
      return retval;                             \
    }                                            \
  } __name;

#else

// Dynamic-loading variant: on first call the symbol named `<name>` is looked
// up in the MIOpen DSO (handle obtained via CachedDsoLoader and ValueOrDie())
// and the resulting function pointer is cached in a function-local static.
// Failing to find the symbol is fatal (CHECK).
#define STREAM_EXECUTOR_MIOPEN_WRAP(__name)                               \
  struct DynLoadShim__##__name {                                          \
    static const char* kName;                                             \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
    static void* GetDsoHandle() {                                         \
      auto s = internal::CachedDsoLoader::GetMiopenDsoHandle();           \
      return s.ValueOrDie();                                              \
    }                                                                     \
    static FuncPtrT LoadOrDie() {                                         \
      void* f;                                                            \
      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
                                                          kName, &f);     \
      CHECK(s.ok()) << "could not find " << kName                         \
                    << " in miopen DSO; dlerror: " << s.error_message();  \
      return reinterpret_cast<FuncPtrT>(f);                               \
    }                                                                     \
    static FuncPtrT DynLoad() {                                           \
      static FuncPtrT f = LoadOrDie();                                    \
      return f;                                                           \
    }                                                                     \
    template <typename... Args>                                           \
    miopenStatus_t operator()(Args... args) {                             \
      return DynLoad()(args...);                                          \
    }                                                                     \
  } __name;                                                               \
  const char* DynLoadShim__##__name::kName = #__name;

#endif

// List of every MIOpen entry point that gets a wrap:: functor. A MIOpen
// routine must be added here before it can be called as wrap::<name>.
// clang-format off
#define MIOPEN_DNN_ROUTINE_EACH(__macro)                   \
  __macro(miopenBatchNormalizationBackward)                \
  __macro(miopenBatchNormalizationForwardInference)        \
  __macro(miopenBatchNormalizationForwardTraining)         \
  __macro(miopenGetConvolutionForwardOutputDim)            \
  __macro(miopenFindConvolutionForwardAlgorithm)           \
  __macro(miopenCreateTensorDescriptor)                    \
  __macro(miopenDestroyTensorDescriptor)                   \
  __macro(miopenSet2dPoolingDescriptor)                    \
  __macro(miopenSetLRNDescriptor)                          \
  __macro(miopenLRNGetWorkSpaceSize)                       \
  __macro(miopenCreateConvolutionDescriptor)               \
  __macro(miopenCreatePoolingDescriptor)                   \
  __macro(miopenDestroyPoolingDescriptor)                  \
  __macro(miopenCreateLRNDescriptor)                       \
  __macro(miopenDestroyLRNDescriptor)                      \
  __macro(miopenDestroyConvolutionDescriptor)              \
  __macro(miopenCreateWithStream)                          \
  __macro(miopenDestroy)                                   \
  __macro(miopenSetStream)                                 \
  __macro(miopenSetAllocator)                              \
  __macro(miopenActivationForward)                         \
  __macro(miopenConvolutionForward)                        \
  __macro(miopenConvolutionBackwardBias)                   \
  __macro(miopenConvolutionForwardGetWorkSpaceSize)        \
  __macro(miopenInitConvolutionDescriptor)                 \
  __macro(miopenGetConvolutionDescriptor)                  \
  __macro(miopenSetConvolutionGroupCount)                  \
  __macro(miopenSet4dTensorDescriptor)                     \
  __macro(miopenGetTensorDescriptor)                       \
  __macro(miopenSetTensorDescriptor)                       \
  __macro(miopenGetTensorDescriptorSize)                   \
  __macro(miopenPoolingForward)                            \
  __macro(miopenPoolingGetWorkSpaceSize)                   \
  __macro(miopenPoolingBackward)                           \
  __macro(miopenLRNForward)                                \
  __macro(miopenLRNBackward)                               \
  __macro(miopenOpTensor)                                  \
  __macro(miopenConvolutionBackwardData)                   \
  __macro(miopenConvolutionBackwardWeights)                \
  __macro(miopenConvolutionBackwardWeightsGetWorkSpaceSize)\
  __macro(miopenFindConvolutionBackwardDataAlgorithm)      \
  __macro(miopenFindConvolutionBackwardWeightsAlgorithm)   \
  __macro(miopenConvolutionBackwardDataGetWorkSpaceSize)   \
  __macro(miopenCreateRNNDescriptor)                       \
  __macro(miopenSetRNNDescriptor)                          \
  __macro(miopenDestroyRNNDescriptor)                      \
  __macro(miopenGetRNNParamsSize)                          \
  __macro(miopenGetRNNLayerParam)                          \
  __macro(miopenGetRNNLayerBias)                           \
  __macro(miopenGetRNNWorkspaceSize)                       \
  __macro(miopenGetRNNTrainingReserveSize)                 \
  __macro(miopenRNNForwardInference)                       \
  __macro(miopenRNNForwardTraining)                        \
  __macro(miopenRNNBackwardData)                           \
  __macro(miopenRNNBackwardWeights)                        \
  __macro(miopenGetRNNLayerParamOffset)                    \
  __macro(miopenGetRNNLayerParamSize)                      \
  __macro(miopenGetRNNLayerBiasOffset)                     \
  __macro(miopenGetRNNLayerBiasSize)                       \
  __macro(miopenGetRNNParamsDescriptor)                    \
  __macro(miopenCreateActivationDescriptor)                \
  __macro(miopenSetActivationDescriptor)                   \
  __macro(miopenGetActivationDescriptor)                   \
  __macro(miopenDestroyActivationDescriptor)               \
  __macro(miopenCreateFusionPlan)                          \
  __macro(miopenCreateOpConvForward)                       \
  __macro(miopenCreateOpBiasForward)                       \
  __macro(miopenCreateOpActivationForward)                 \
  __macro(miopenCreateOpActivationBackward)                \
  __macro(miopenCreateOpBatchNormInference)                \
  __macro(miopenCreateOpBatchNormForward)                  \
  __macro(miopenCreateOpBatchNormBackward)                 \
  __macro(miopenCompileFusionPlan)                         \
  __macro(miopenFusionPlanGetOp)                           \
  __macro(miopenCreateOperatorArgs)                        \
  __macro(miopenSetOpArgsConvForward)                      \
  __macro(miopenSetOpArgsBiasForward)                      \
  __macro(miopenSetOpArgsActivForward)                     \
  __macro(miopenSetOpArgsActivBackward)                    \
  __macro(miopenSetOpArgsBatchNormInference)               \
  __macro(miopenSetOpArgsBatchNormForward)                 \
  __macro(miopenSetOpArgsBatchNormBackward)                \
  __macro(miopenExecuteFusionPlan)                         \
  __macro(miopenDestroyOperatorArgs)                       \
  __macro(miopenDestroyFusionPlan)

// clang-format on

// Instantiate one wrapper functor per routine listed above.
MIOPEN_DNN_ROUTINE_EACH(STREAM_EXECUTOR_MIOPEN_WRAP)

#undef MIOPEN_DNN_ROUTINE_EACH

}  // namespace wrap
254 
255 namespace {
256 
// These routines should ideally be provided as an MIOpen API.
// They are called for *every* _ROCMmFusedOp*::Compute call, and they need to be
// efficient! Instead of calculating the hash value by querying the MIOpen Get*
// APIs for the descriptor components, it would be a lot more efficient if
// MIOpen calculated the hash value when creating the descriptor, stored it in
// the descriptor data structure, and provided an API routine to query it.
263 
// Capacity of the fixed-size dims/strides buffers used when querying a MIOpen
// tensor descriptor for hashing.
constexpr int kMaxMIOpenTensorSize = 5;
265 
GetHashValue(miopenTensorDescriptor_t tensor_desc)266 uint64 GetHashValue(miopenTensorDescriptor_t tensor_desc) {
267   miopenDataType_t datatype = miopenFloat;
268   int dims[kMaxMIOpenTensorSize] = {0};
269   int strides[kMaxMIOpenTensorSize] = {0};
270   wrap::miopenGetTensorDescriptor(tensor_desc, &datatype, dims, strides);
271 
272   uint64 hash_value = tensorflow::hash<int>()(datatype);
273   for (int dim : dims)
274     hash_value =
275         tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(dim));
276   for (int stride : strides)
277     hash_value =
278         tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(stride));
279 
280   return hash_value;
281 }
282 
GetHashValue(miopenConvolutionDescriptor_t conv_desc)283 uint64 GetHashValue(miopenConvolutionDescriptor_t conv_desc) {
284   miopenConvolutionMode_t c_mode = miopenConvolution;
285   int pad_h = 0, pad_w = 0, u = 0, v = 0, dilation_h = 0, dilation_w = 0;
286   wrap::miopenGetConvolutionDescriptor(conv_desc, &c_mode, &pad_h, &pad_w, &u,
287                                        &v, &dilation_h, &dilation_w);
288 
289   uint64 hash_value = tensorflow::hash<int>()(c_mode);
290   hash_value =
291       tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(pad_h));
292   hash_value =
293       tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(pad_w));
294   hash_value =
295       tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(u));
296   hash_value =
297       tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(v));
298   hash_value = tensorflow::Hash64Combine(hash_value,
299                                          tensorflow::hash<int>()(dilation_h));
300   hash_value = tensorflow::Hash64Combine(hash_value,
301                                          tensorflow::hash<int>()(dilation_w));
302 
303   return hash_value;
304 }
305 
306 // Class to implement a cache of compiled fusion plans.
307 class CachedFusionPlans {
308  public:
309   // Check if we already have a fusion_plan corresponding to the given hash
310   // value.
311   // If we do, then
312   //   return true (+ the cached fusion plan via given pointer)
313   // Else
314   //   create a new fusion plan descriptor,
315   //   associate it with the given hash value in the cache
316   //   return false (+ newly created fusion plan via given pointer)
FindOrCreate(uint64 hash,miopenFusionPlanDescriptor_t * fusion_plan,miopenFusionDirection_t fusion_direction,miopenTensorDescriptor_t input_descriptor)317   static bool FindOrCreate(uint64 hash,
318                            miopenFusionPlanDescriptor_t* fusion_plan,
319                            miopenFusionDirection_t fusion_direction,
320                            miopenTensorDescriptor_t input_descriptor) {
321     mutex_lock lock{cached_plans_mutex};
322 
323     bool found_cached_plan = false;
324 
325     auto it = cached_plans.find(hash);
326     if (it != cached_plans.end()) {
327       *fusion_plan = it->second;
328       found_cached_plan = true;
329     } else {
330       auto status = wrap::miopenCreateFusionPlan(fusion_plan, fusion_direction,
331                                                  input_descriptor);
332       if (status != miopenStatusSuccess) {
333         LOG(FATAL) << "call to miopenCreateFusionPlan failed: "
334                    << ToString(status);
335       } else {
336         cached_plans[hash] = *fusion_plan;
337       }
338     }
339 
340     return found_cached_plan;
341   }
342 
343   // Need to figure out the right place to call this routine.
Clear()344   static void Clear() {
345     mutex_lock lock{cached_plans_mutex};
346 
347     for (auto it : cached_plans) {
348       auto status = wrap::miopenDestroyFusionPlan(it.second);
349       if (status != miopenStatusSuccess) {
350         LOG(FATAL) << "call to miopenDestroyFusionPlan failed: "
351                    << ToString(status);
352       }
353     }
354 
355     cached_plans.clear();
356 
357     unsupported_plans.clear();
358   }
359 
360   // Is the Fusion plan corresponding to this hash unsupported.
IsUnsupportedFusionPlan(uint64 hash)361   static bool IsUnsupportedFusionPlan(uint64 hash) {
362     mutex_lock lock{cached_plans_mutex};
363     return unsupported_plans.count(hash) > 0;
364   }
365 
366   // Mark the given hash value as corresponding to an unsupported fusion plan.
MarkFusionPlanUnsupported(uint64 hash)367   static void MarkFusionPlanUnsupported(uint64 hash) {
368     mutex_lock lock{cached_plans_mutex};
369     unsupported_plans.insert(hash);
370   }
371 
372  private:
373   // Mutex to guard access to all data within this class.
374   static mutex cached_plans_mutex;
375 
376   // Map of hash-value to MIOpen Fusion plan descriptors.
377   // Need to be able share this across more than one stream and hence static.
378   static std::map<uint64, miopenFusionPlanDescriptor_t> cached_plans;
379 
380   // Set of hash-values that correspond to MIOpen Fusion plans that will fail
381   // compile and hence are not supported.
382   static std::set<uint64> unsupported_plans;
383 };
384 
385 mutex CachedFusionPlans::cached_plans_mutex;
386 std::map<uint64, miopenFusionPlanDescriptor_t> CachedFusionPlans::cached_plans;
387 std::set<uint64> CachedFusionPlans::unsupported_plans;
388 
ToHandle(void * opaque_handle)389 miopenHandle_t ToHandle(void* opaque_handle) {
390   return static_cast<miopenHandle_t>(opaque_handle);
391 }
392 
ToConvForwardAlgo(dnn::AlgorithmDesc algorithm)393 miopenConvFwdAlgorithm_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
394   miopenConvFwdAlgorithm_t algo = miopenConvFwdAlgorithm_t(algorithm.algo_id());
395   switch (algo) {
396     case miopenConvolutionFwdAlgoGEMM:
397     case miopenConvolutionFwdAlgoDirect:
398     case miopenConvolutionFwdAlgoFFT:
399     case miopenConvolutionFwdAlgoWinograd:
400       return algo;
401     default:
402       LOG(FATAL) << "Unsupported MIOpen convolution forward algorithm: "
403                  << algorithm.algo_id();
404   }
405 }
406 
ToConvBackwardDataAlgo(dnn::AlgorithmDesc algorithm)407 miopenConvBwdDataAlgorithm_t ToConvBackwardDataAlgo(
408     dnn::AlgorithmDesc algorithm) {
409   miopenConvBwdDataAlgorithm_t algo =
410       miopenConvBwdDataAlgorithm_t(algorithm.algo_id());
411   switch (algo) {
412     case miopenConvolutionBwdDataAlgoGEMM:
413     case miopenConvolutionBwdDataAlgoDirect:
414     case miopenConvolutionBwdDataAlgoFFT:
415     case miopenConvolutionBwdDataAlgoWinograd:
416       return algo;
417     default:
418       LOG(FATAL)
419           << "Unsupported MIOpen convolution backward algorithm for data: "
420           << algorithm.algo_id();
421   }
422 }
423 
ToConvBackwardFilterAlgo(dnn::AlgorithmDesc algorithm)424 miopenConvBwdWeightsAlgorithm_t ToConvBackwardFilterAlgo(
425     dnn::AlgorithmDesc algorithm) {
426   miopenConvBwdWeightsAlgorithm_t algo =
427       miopenConvBwdWeightsAlgorithm_t(algorithm.algo_id());
428   switch (algo) {
429     case miopenConvolutionBwdWeightsAlgoGEMM:
430     case miopenConvolutionBwdWeightsAlgoDirect:
431       return algo;
432     default:
433       LOG(FATAL)
434           << "Unsupported MIOpen convolution backward algorithm for filter: "
435           << algorithm.algo_id();
436   }
437 }
438 
439 }  // namespace
440 
441 // Wraps a MIOpen handle and provides access to it through miopenHandle_t
442 // instances, which also locks a mutex, acquires the ROCm context, and sets
443 // the stream that MIOpen should use to enqueue any work.
444 //
445 // Note: MIOpenSupport::miopen_ should be the only instantiation of this class.
446 class MIOpenAccess {
447  public:
448   // Takes ownership of the handle.
MIOpenAccess(miopenHandle_t handle)449   explicit MIOpenAccess(miopenHandle_t handle) : handle_(handle) {}
450 
~MIOpenAccess()451   ~MIOpenAccess() {
452     mutex_lock lock(mutex_);
453     wrap::miopenDestroy(handle_);
454   }
455 
456   // Creates a MIOpenHandle instance for stream.
457   //
458   // MIOpen API calls using the same handle instance need to be serialized
459   // across threads. This is guaranteed by MIOpenHandle instances locking the
460   // mutex owned by this class.
461   //
462   // Most MIOpen APIs taking a handle perform work on a HIP stream. The
463   // MIOpenHandle instance acquires the executor's ROCm context and sets MIOpen
464   // to use the provided stream.
465   //
466   // The stream argument may be null, which translates to the null stream.
467   // The null stream synchronizes with all other streams and it is
468   // therefore a bad idea (performance wise) to call any MIOpen APIs that
469   // enqueue work in the stream.
GetHandle(GpuExecutor * executor,Stream * stream)470   MIOpenHandle GetHandle(GpuExecutor* executor, Stream* stream) {
471     mutex_lock lock(mutex_);
472     gpu::ScopedActivateExecutorContext context(executor);
473     hipStream_t hip_stream = stream ? AsGpuStreamValue(stream) : nullptr;
474     auto status = wrap::miopenSetStream(handle_, hip_stream);
475     CHECK_EQ(status, miopenStatusSuccess) << "Failed to set MIOpen stream.";
476     return MIOpenHandle(std::move(context), std::move(lock), handle_);
477   }
478 
479  private:
480   // Guards the enqueueing of MIOpen operations via the handle_ below.
481   mutex mutex_;
482 
483   // MIOpen library handle.
484   miopenHandle_t handle_ GUARDED_BY(mutex_);  // Owned.
485 };
486 
MIOpenSupport(GpuExecutor * parent)487 MIOpenSupport::MIOpenSupport(GpuExecutor* parent) : parent_(parent) {}
488 
Init()489 port::Status MIOpenSupport::Init() {
490   ScopedActivateExecutorContext context(parent_);
491   miopenHandle_t miopen_handle = nullptr;
492   auto status = wrap::miopenCreateWithStream(
493       reinterpret_cast<miopenHandle_t*>(&miopen_handle), (hipStream_t)(0));
494   if (status == miopenStatusSuccess) {
495     miopen_.reset(new MIOpenAccess(miopen_handle));
496     return port::Status::OK();
497   }
498 
499   CHECK_EQ(miopen_handle, nullptr);
500   LOG(ERROR) << "could not create miopen handle: " << ToString(status);
501   if (status == miopenStatusNotInitialized) {
502     auto result = rocm::Diagnostician::FindKernelDriverVersion();
503     if (!result.ok()) {
504       LOG(ERROR) << "error retrieving driver version: "
505                  << rocm::DriverVersionStatusToString(result);
506     } else {
507       const auto& version = result.ValueOrDie();
508       LOG(INFO) << "possibly insufficient driver version: "
509                 << rocm::DriverVersionToString(version);
510     }
511   }
512 
513   return port::Status{port::error::INTERNAL,
514                       absl::StrCat("miopen library could not create a handle: ",
515                                    ToString(status))};
516 }
517 
518 port::StatusOr<perftools::gputools::dnn::VersionInfo>
GetVersion()519 MIOpenSupport::GetVersion() {
520   // ROCM TODO: retrieve MIOpen version with its API
521   return perftools::gputools::dnn::VersionInfo(1, 3, 0);
522 }
523 
524 // Turns a BatchDescriptor structure into a miopen tensor handle within a scope.
525 class ScopedTensorDescriptor {
526  public:
ScopedTensorDescriptor(const BatchDescriptor & batch_descriptor,miopenDataType_t elem_type)527   ScopedTensorDescriptor(const BatchDescriptor& batch_descriptor,
528                          miopenDataType_t elem_type)
529       : handle_(nullptr) {
530     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
531     if (status != miopenStatusSuccess) {
532       LOG(FATAL) << "could not create miopen tensor descriptor: "
533                  << ToString(status);
534     }
535 
536     switch (batch_descriptor.layout()) {
537       case dnn::DataLayout::kBatchYXDepth:
538       case dnn::DataLayout::kBatchDepthYX: {
539         const int nd = batch_descriptor.ndims() + 2;
540         if (nd != 4) {
541           LOG(FATAL) << "miopen only supports 4D tensors, dim=" << nd
542                      << " not allowed";
543         }
544 
545         // MIOpen requires the strides and dims to be ordered as BDYX.
546         std::vector<int64> strides64 =
547             batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
548         std::vector<int64> dims64 =
549             batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
550 
551         // MIOpen requires arrays of ints.
552         std::vector<int> strides(nd);
553         std::vector<int> dims(nd);
554         std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
555                        &CheckedNarrowing<int64, int>);
556         std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
557                        &CheckedNarrowing<int64, int>);
558         status = wrap::miopenSet4dTensorDescriptor(handle_, elem_type, dims[0],
559                                                    dims[1], dims[2], dims[3]);
560 
561         if (status != miopenStatusSuccess) {
562           LOG(FATAL) << "could not convert BatchDescriptor "
563                      << batch_descriptor.ToString()
564                      << " to miopen tensor descriptor: " << ToString(status);
565         }
566       } break;
567       default:
568         LOG(FATAL) << "Unsupported tensor format "
569                    << DataLayoutString(batch_descriptor.layout());
570         break;
571     }
572   }
573 
~ScopedTensorDescriptor()574   ~ScopedTensorDescriptor() {
575     auto status = wrap::miopenDestroyTensorDescriptor(handle_);
576     if (status != miopenStatusSuccess) {
577       LOG(ERROR) << "could not destroy miopen tensor descriptor: "
578                  << ToString(status);
579     }
580   }
581 
handle() const582   miopenTensorDescriptor_t handle() const { return handle_; }
583 
584  private:
585   miopenTensorDescriptor_t handle_;  // Owned.
586 
587   SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor);
588 };
589 
590 // Turns a FilterDescriptor structure into a miopen filter handle within a
591 // scope.
592 class ScopedFilterDescriptor {
593  public:
ScopedFilterDescriptor(const FilterDescriptor & filter_descriptor,const BatchDescriptor & batch_descriptor,miopenDataType_t elem_type)594   ScopedFilterDescriptor(const FilterDescriptor& filter_descriptor,
595                          const BatchDescriptor& batch_descriptor,
596                          miopenDataType_t elem_type)
597       : handle_(nullptr) {
598     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
599     if (status != miopenStatusSuccess) {
600       LOG(FATAL) << "could not create miopen filter descriptor: "
601                  << ToString(status);
602     }
603 
604     const int nd = batch_descriptor.ndims() + 2;
605 
606     if (nd != 4) {
607       LOG(FATAL) << "miopen only supports 4D filters, dim=" << nd
608                  << "not allowed" << ToString(status);
609     }
610 
611     std::vector<int> dims(2 + filter_descriptor.ndims());
612     dims[0] = filter_descriptor.output_feature_map_count();
613     dims[1] = filter_descriptor.input_feature_map_count();
614     const auto& spatial_dims = filter_descriptor.input_filter_dims();
615     std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
616 
617     status = wrap::miopenSet4dTensorDescriptor(handle_, elem_type, dims[0],
618                                                dims[1], dims[2], dims[3]);
619     if (status != miopenStatusSuccess) {
620       LOG(FATAL) << "could not set miopen filter descriptor: "
621                  << ToString(status);
622     }
623   }
624 
~ScopedFilterDescriptor()625   ~ScopedFilterDescriptor() {
626     auto status = wrap::miopenDestroyTensorDescriptor(handle_);
627     if (status != miopenStatusSuccess) {
628       LOG(ERROR) << "could not destroy miopen filter descriptor: "
629                  << ToString(status);
630     }
631   }
632 
handle() const633   miopenTensorDescriptor_t handle() const { return handle_; }
634 
635  private:
636   // miopen filter descriptor this object creates. Owned.
637   miopenTensorDescriptor_t handle_;
638 
639   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor);
640 };
641 
642 // Turns a ConvolutionDescriptor structure into a miopen convolution handle
643 // within a scope.
644 class ScopedConvolutionDescriptor {
645  public:
ScopedConvolutionDescriptor(const ConvolutionDescriptor & convolution_descriptor,miopenDataType_t data_type)646   ScopedConvolutionDescriptor(
647       const ConvolutionDescriptor& convolution_descriptor,
648       miopenDataType_t data_type)
649       : handle_(nullptr) {
650     auto status = wrap::miopenCreateConvolutionDescriptor(&handle_);
651     if (status != miopenStatusSuccess) {
652       LOG(FATAL) << "could not create miopen convolution descriptor: "
653                  << ToString(status);
654     }
655     const auto& strides64 = convolution_descriptor.strides();
656     const auto& padding64 = convolution_descriptor.padding();
657     if (convolution_descriptor.pad_alignment() ==
658         dnn::PadAlignment::kTensorFlowPadding) {
659       LOG(ERROR) << "TensorFlow padding alignment is not supported.";
660     }
661 
662     // MIOpen requires arrays of ints.
663     std::vector<int> strides(convolution_descriptor.ndims());
664     std::vector<int> padding(convolution_descriptor.ndims());
665     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
666                    &CheckedNarrowing<int64, int>);
667     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
668                    &CheckedNarrowing<int64, int>);
669     std::vector<int> upscale(convolution_descriptor.ndims(), 1);
670 
671     status = wrap::miopenInitConvolutionDescriptor(
672         handle_, miopenConvolution, padding[0], padding[1], strides[0],
673         strides[1], upscale[0], upscale[1]);
674     if (status != miopenStatusSuccess) {
675       LOG(FATAL) << "could not set miopen convolution descriptor: "
676                  << ToString(status);
677     }
678 
679     VLOG(2) << "Requesting grouped convolution: "
680             << convolution_descriptor.group_count();
681     status = wrap::miopenSetConvolutionGroupCount(
682         handle_, convolution_descriptor.group_count());
683     if (status != miopenStatusSuccess) {
684       LOG(FATAL) << "could not set miopen convolution group count: "
685                  << ToString(status);
686     }
687   }
~ScopedConvolutionDescriptor()688   ~ScopedConvolutionDescriptor() {
689     auto status = wrap::miopenDestroyConvolutionDescriptor(handle_);
690     if (status != miopenStatusSuccess) {
691       LOG(ERROR) << "could not destroy miopen convolution descriptor: "
692                  << ToString(status);
693     }
694   }
695 
handle() const696   miopenConvolutionDescriptor_t handle() const { return handle_; }
697 
698  private:
699   miopenConvolutionDescriptor_t handle_;  // Owned.
700 
701   SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
702 };
703 
704 // Turns a PoolingDescriptor structure into a miopen pooling descriptor handle
705 // within a scope.
706 class ScopedPoolingDescriptor {
707  public:
ScopedPoolingDescriptor(const PoolingDescriptor & pooling_descriptor)708   ScopedPoolingDescriptor(const PoolingDescriptor& pooling_descriptor)
709       : handle_(nullptr) {
710     auto status = wrap::miopenCreatePoolingDescriptor(&handle_);
711     if (status != miopenStatusSuccess) {
712       LOG(FATAL) << "could not create miopen pooling descriptor: "
713                  << ToString(status);
714     }
715 
716     absl::Span<const int64> strides64 = pooling_descriptor.strides();
717     absl::Span<const int64> padding64 = pooling_descriptor.padding();
718     absl::Span<const int64> shape64 = pooling_descriptor.window();
719 
720     const int nd = pooling_descriptor.ndims();
721     std::vector<int> shape(nd);
722     std::vector<int> padding(nd);
723     std::vector<int> strides(nd);
724     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
725                    &CheckedNarrowing<int64, int>);
726     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
727                    &CheckedNarrowing<int64, int>);
728     std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
729                    &CheckedNarrowing<int64, int>);
730 
731     if (nd != 2) {
732       LOG(FATAL) << "miopen requires pooling dimensions be 2"
733                  << ToString(status);
734     }
735 
736     status = wrap::miopenSet2dPoolingDescriptor(
737         handle_,
738         (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
739              ? miopenPoolingMax
740              : miopenPoolingAverage),
741         shape[0], shape[1], padding[0], padding[1], strides[0], strides[1]);
742     if (status != miopenStatusSuccess) {
743       LOG(FATAL) << "could not set miopen pooling descriptor: "
744                  << ToString(status);
745     }
746   }
~ScopedPoolingDescriptor()747   ~ScopedPoolingDescriptor() {
748     auto status = wrap::miopenDestroyPoolingDescriptor(handle_);
749     if (status != miopenStatusSuccess) {
750       LOG(ERROR) << "could not destroy miopen pooling descriptor: "
751                  << ToString(status);
752     }
753   }
754 
handle() const755   miopenPoolingDescriptor_t handle() const { return handle_; }
756 
757  private:
758   miopenPoolingDescriptor_t handle_;  // Owned.
759 
760   SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
761 };
762 
763 // Turns a NormalizeDescriptor structure into a miopen LRN descriptor handle.
764 class ScopedNormalizeDescriptor {
765  public:
ScopedNormalizeDescriptor(const NormalizeDescriptor & normalize_descriptor)766   ScopedNormalizeDescriptor(const NormalizeDescriptor& normalize_descriptor)
767       : handle_(nullptr) {
768     auto status = wrap::miopenCreateLRNDescriptor(&handle_);
769     if (status != miopenStatusSuccess) {
770       LOG(FATAL) << "could not create miopen LRN descriptor: "
771                  << ToString(status);
772     }
773 
774     // The range specifies that the indices in the closed range
775     // [i - range, i + range] should be included in the normalization for index
776     // i. The lrnN value is the total number of elements in the range, so
777     // lrnN = 2*range + 1.
778     unsigned lrn_N = 2 * normalize_descriptor.range() + 1;
779 
780     // Note that SE defines the normalization operation as
781     //
782     //  U_i = V_i / ((bias +  alpha      * (sum_j V_j^2)) ^ beta)
783     //
784     // but MIOpen defines it as
785     //
786     //  U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
787     //
788     // i.e. there is a factor of n difference between the meaning of the alphas
789     // in the two contexts. The MIOpen alpha is n times the SE alpha.
790     double lrn_alpha = lrn_N * normalize_descriptor.alpha();
791 
792     double lrn_beta = normalize_descriptor.beta();
793     double lrn_k = normalize_descriptor.bias();
794     status = wrap::miopenSetLRNDescriptor(handle_, miopenLRNCrossChannel, lrn_N,
795                                           lrn_alpha, lrn_beta, lrn_k);
796     if (status != miopenStatusSuccess) {
797       LOG(FATAL) << "could not set miopen LRN descriptor: " << ToString(status);
798     }
799   }
800 
~ScopedNormalizeDescriptor()801   ~ScopedNormalizeDescriptor() {
802     auto status = wrap::miopenDestroyLRNDescriptor(handle_);
803     if (status != miopenStatusSuccess) {
804       LOG(ERROR) << "could not destroy miopen LRN descriptor: "
805                  << ToString(status);
806     }
807   }
808 
handle() const809   miopenLRNDescriptor_t handle() const { return handle_; }
810 
811  private:
812   miopenLRNDescriptor_t handle_;  // Owned.
813 
814   SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor);
815 };
816 
817 // Turns a activation mode into a miopen activation mode descriptor with a scope
818 // around it
819 class ScopedActivationDescriptor {
820  public:
ScopedActivationDescriptor(dnn::ActivationMode activation_mode)821   ScopedActivationDescriptor(dnn::ActivationMode activation_mode)
822       : handle_(nullptr),
823         miopen_activation_mode_(miopenActivationPASTHRU),
824         alpha_(0.0),
825         beta_(0.0),
826         gamma_(0.0) {
827     auto status = wrap::miopenCreateActivationDescriptor(&handle_);
828     if (status != miopenStatusSuccess) {
829       LOG(FATAL) << "call to miopenCreateActivationDescriptor failed: "
830                  << ToString(status);
831     } else {
832       switch (activation_mode) {
833         case dnn::ActivationMode::kNone:
834           miopen_activation_mode_ = miopenActivationPASTHRU;
835           break;
836 
837         case dnn::ActivationMode::kSigmoid:
838           miopen_activation_mode_ = miopenActivationLOGISTIC;
839           break;
840 
841         case dnn::ActivationMode::kRelu:
842           miopen_activation_mode_ = miopenActivationRELU;
843           break;
844 
845         case dnn::ActivationMode::kRelu6:
846           miopen_activation_mode_ = miopenActivationRELU;
847           alpha_ = 6.0;
848           break;
849 
850         case dnn::ActivationMode::kTanh:
851           miopen_activation_mode_ = miopenActivationTANH;
852           break;
853 
854         default:
855           LOG(FATAL) << "Activation mode ("
856                      << dnn::ActivationModeString(activation_mode)
857                      << ") not yet implemented";
858           break;
859       }
860 
861       status = wrap::miopenSetActivationDescriptor(
862           handle_, miopen_activation_mode_, alpha_, beta_, gamma_);
863       if (status != miopenStatusSuccess) {
864         LOG(FATAL) << "call to miopenSetActivationDescriptor failed: "
865                    << ToString(status);
866       }
867     }
868   }
869 
~ScopedActivationDescriptor()870   ~ScopedActivationDescriptor() {
871     auto status = wrap::miopenDestroyActivationDescriptor(handle_);
872     if (status != miopenStatusSuccess) {
873       LOG(FATAL) << "call to miopenDestroyActivationDescriptor failed: "
874                  << ToString(status);
875     }
876   }
877 
handle() const878   miopenActivationDescriptor_t handle() const { return handle_; }
879 
GetHashValue()880   uint64 GetHashValue() {
881     uint64 hash_value = tensorflow::hash<int>()(miopen_activation_mode_);
882     hash_value = tensorflow::Hash64Combine(hash_value,
883                                            tensorflow::hash<double>()(alpha_));
884     hash_value = tensorflow::Hash64Combine(hash_value,
885                                            tensorflow::hash<double>()(beta_));
886     hash_value = tensorflow::Hash64Combine(hash_value,
887                                            tensorflow::hash<double>()(gamma_));
888 
889     return hash_value;
890   }
891 
892  private:
893   miopenActivationDescriptor_t handle_;  // Owned.
894 
895   SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor);
896 
897  public:
898   // caching these values here to avoid calling miopenGetActivationDescriptor
899   // to do the same. miopenGetActivationDescriptor gets called twice during each
900   // call to execute a fusion plan (that involves the activation op)...once call
901   // during calculating hashvalue for the fusion op, and another before calling
902   // SetOpArgs for the activation op
903   miopenActivationMode_t miopen_activation_mode_;
904   double alpha_;
905   double beta_;
906   double gamma_;
907 };
908 
909 // base class for all fusion plan implementations to derive from
910 class ScopedFusionPlanBase {
911  public:
ScopedFusionPlanBase(miopenHandle_t miopen_handle,const miopenFusionDirection_t fuse_direction,const miopenTensorDescriptor_t input_descriptor)912   ScopedFusionPlanBase(miopenHandle_t miopen_handle,
913                        const miopenFusionDirection_t fuse_direction,
914                        const miopenTensorDescriptor_t input_descriptor)
915       : miopen_handle_(miopen_handle),
916         fusion_plan_(nullptr),
917         fusion_args_(nullptr),
918         fusion_plan_compiled_(false) {
919     auto status = wrap::miopenCreateOperatorArgs(&fusion_args_);
920     if (status != miopenStatusSuccess) {
921       LOG(FATAL) << "call to miopenCreateOperatorArgs failed: "
922                  << ToString(status);
923     }
924   }
925 
~ScopedFusionPlanBase()926   virtual ~ScopedFusionPlanBase() {
927     auto status = wrap::miopenDestroyOperatorArgs(fusion_args_);
928     if (status != miopenStatusSuccess) {
929       LOG(FATAL) << "call to miopenDestroyoperatorArgs failed: "
930                  << ToString(status);
931     }
932   }
933 
Execute(miopenTensorDescriptor_t input_descriptor,const void * input_data,miopenTensorDescriptor_t output_descriptor,void * output_data)934   miopenStatus_t Execute(miopenTensorDescriptor_t input_descriptor,
935                          const void* input_data,
936                          miopenTensorDescriptor_t output_descriptor,
937                          void* output_data) {
938     auto status = wrap::miopenExecuteFusionPlan(
939         miopen_handle_, fusion_plan_, input_descriptor, input_data,
940         output_descriptor, output_data, fusion_args_);
941     if (status != miopenStatusSuccess) {
942       LOG(FATAL) << "call to miopenExecuteFusionPlan failed: "
943                  << ToString(status);
944     }
945 
946     return status;
947   }
948 
CompilationSucceeded()949   bool CompilationSucceeded() { return fusion_plan_compiled_; }
950 
951  protected:
SetConvolutionArgs(const int op_idx,const float * alpha,const float * beta,const void * data)952   miopenStatus_t SetConvolutionArgs(const int op_idx, const float* alpha,
953                                     const float* beta, const void* data) {
954     miopenFusionOpDescriptor_t conv_op;
955     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &conv_op);
956     if (status != miopenStatusSuccess) {
957       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
958                  << ToString(status);
959     }
960 
961     status = wrap::miopenSetOpArgsConvForward(fusion_args_, conv_op, alpha,
962                                               beta, data);
963     if (status != miopenStatusSuccess) {
964       LOG(FATAL) << "call to miopenSetOpArgsConvForward failed: "
965                  << ToString(status);
966     }
967     return status;
968   }
969 
SetBiasArgs(const int op_idx,const float * alpha,const float * beta,const void * data)970   miopenStatus_t SetBiasArgs(const int op_idx, const float* alpha,
971                              const float* beta, const void* data) {
972     miopenFusionOpDescriptor_t bias_op;
973     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &bias_op);
974     if (status != miopenStatusSuccess) {
975       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
976                  << ToString(status);
977     }
978 
979     status = wrap::miopenSetOpArgsBiasForward(fusion_args_, bias_op, alpha,
980                                               beta, data);
981     if (status != miopenStatusSuccess) {
982       LOG(FATAL) << "call to miopenSetOpArgsBiasForward failed: "
983                  << ToString(status);
984     }
985     return status;
986   }
987 
SetBatchNormInferenceArgs(const int op_idx,const float * alpha,const float * beta,const void * scale,const void * offset,const void * mean,const void * variance,double epsilon)988   miopenStatus_t SetBatchNormInferenceArgs(const int op_idx, const float* alpha,
989                                            const float* beta, const void* scale,
990                                            const void* offset, const void* mean,
991                                            const void* variance,
992                                            double epsilon) {
993     miopenFusionOpDescriptor_t batchnorm_op;
994     auto status =
995         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
996     if (status != miopenStatusSuccess) {
997       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
998                  << ToString(status);
999     }
1000 
1001     status = wrap::miopenSetOpArgsBatchNormInference(fusion_args_, batchnorm_op,
1002                                                      alpha, beta, scale, offset,
1003                                                      mean, variance, epsilon);
1004     if (status != miopenStatusSuccess) {
1005       LOG(FATAL) << "call to miopenSetOpArgsBatchNormInference failed: "
1006                  << ToString(status);
1007     }
1008     return status;
1009   }
1010 
SetBatchNormForwardArgs(const int op_idx,const float * alpha,const float * beta,const void * scale,const void * offset,void * running_mean,void * running_variance,void * saved_mean,void * saved_inv_variance,double epsilon)1011   miopenStatus_t SetBatchNormForwardArgs(const int op_idx, const float* alpha,
1012                                          const float* beta, const void* scale,
1013                                          const void* offset, void* running_mean,
1014                                          void* running_variance,
1015                                          void* saved_mean,
1016                                          void* saved_inv_variance,
1017                                          double epsilon) {
1018     miopenFusionOpDescriptor_t batchnorm_op;
1019     auto status =
1020         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
1021     if (status != miopenStatusSuccess) {
1022       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1023                  << ToString(status);
1024     }
1025 
1026     double exp_avg_factor = 1.0;
1027 
1028     status = wrap::miopenSetOpArgsBatchNormForward(
1029         fusion_args_, batchnorm_op, alpha, beta, scale, offset, saved_mean,
1030         saved_inv_variance, running_mean, running_variance, exp_avg_factor,
1031         epsilon);
1032     if (status != miopenStatusSuccess) {
1033       LOG(FATAL) << "call to miopenSetOpArgsBatchNormForward failed: "
1034                  << ToString(status);
1035     }
1036     return status;
1037   }
1038 
SetBatchNormBackwardArgs(const int op_idx,const float * alpha,const float * beta,const void * x,const void * scale,const void * offset,void * scale_grad,void * offset_grad,const void * saved_mean,const void * saved_inv_variance)1039   miopenStatus_t SetBatchNormBackwardArgs(const int op_idx, const float* alpha,
1040                                           const float* beta, const void* x,
1041                                           const void* scale, const void* offset,
1042                                           void* scale_grad, void* offset_grad,
1043                                           const void* saved_mean,
1044                                           const void* saved_inv_variance) {
1045     miopenFusionOpDescriptor_t batchnorm_op;
1046     auto status =
1047         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
1048     if (status != miopenStatusSuccess) {
1049       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1050                  << ToString(status);
1051     }
1052 
1053     status = wrap::miopenSetOpArgsBatchNormBackward(
1054         fusion_args_, batchnorm_op, alpha, beta, x, scale, offset, scale_grad,
1055         offset_grad, saved_mean, saved_inv_variance);
1056     if (status != miopenStatusSuccess) {
1057       LOG(FATAL) << "call to miopenSetOpArgsBatchNormBackward failed: "
1058                  << ToString(status);
1059     }
1060     return status;
1061   }
1062 
SetActivationForwardArgs(const int op_idx,const float * alpha,const float * beta,double activ_alpha,double activ_beta,double activ_gamma)1063   miopenStatus_t SetActivationForwardArgs(const int op_idx, const float* alpha,
1064                                           const float* beta, double activ_alpha,
1065                                           double activ_beta,
1066                                           double activ_gamma) {
1067     miopenFusionOpDescriptor_t actv_op;
1068     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &actv_op);
1069     if (status != miopenStatusSuccess) {
1070       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1071                  << ToString(status);
1072     }
1073 
1074     status =
1075         wrap::miopenSetOpArgsActivForward(fusion_args_, actv_op, alpha, beta,
1076                                           activ_alpha, activ_beta, activ_gamma);
1077     if (status != miopenStatusSuccess) {
1078       LOG(FATAL) << "call to miopenSetOpArgsActivForward failed: "
1079                  << ToString(status);
1080     }
1081     return status;
1082   }
1083 
SetActivationBackwardArgs(const int op_idx,const float * alpha,const float * beta,const void * y,double activ_alpha,double activ_beta,double activ_gamma)1084   miopenStatus_t SetActivationBackwardArgs(const int op_idx, const float* alpha,
1085                                            const float* beta, const void* y,
1086                                            double activ_alpha,
1087                                            double activ_beta,
1088                                            double activ_gamma) {
1089     miopenFusionOpDescriptor_t actv_op;
1090     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &actv_op);
1091     if (status != miopenStatusSuccess) {
1092       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1093                  << ToString(status);
1094     }
1095 
1096     status = wrap::miopenSetOpArgsActivBackward(fusion_args_, actv_op, alpha,
1097                                                 beta, y, nullptr, activ_alpha,
1098                                                 activ_beta, activ_gamma);
1099     if (status != miopenStatusSuccess) {
1100       LOG(FATAL) << "call to miopenSetOpArgsActivBackward failed: "
1101                  << ToString(status);
1102     }
1103     return status;
1104   }
1105 
1106   miopenHandle_t miopen_handle_;
1107   miopenFusionPlanDescriptor_t fusion_plan_;
1108   miopenOperatorArgs_t fusion_args_;  // Owned.
1109   bool fusion_plan_compiled_;
1110 
1111   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBase);
1112 };
1113 
1114 // class to represent the Convolution+Bias+Activation fusion plan
1115 class ScopedFusionPlanConvolutionBiasActivation : public ScopedFusionPlanBase {
1116  public:
ScopedFusionPlanConvolutionBiasActivation(miopenHandle_t miopen_handle,miopenTensorDescriptor_t input_descriptor,miopenTensorDescriptor_t filter_descriptor,miopenConvolutionDescriptor_t conv_descriptor,miopenTensorDescriptor_t bias_descriptor,ScopedActivationDescriptor & activation_descriptor)1117   ScopedFusionPlanConvolutionBiasActivation(
1118       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1119       miopenTensorDescriptor_t filter_descriptor,
1120       miopenConvolutionDescriptor_t conv_descriptor,
1121       miopenTensorDescriptor_t bias_descriptor,
1122       ScopedActivationDescriptor& activation_descriptor)
1123       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1124                              input_descriptor) {
1125     uint64 hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1126                                        filter_descriptor, conv_descriptor,
1127                                        bias_descriptor, activation_descriptor);
1128 
1129     bool is_compiled = CachedFusionPlans::FindOrCreate(
1130         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1131     if (!is_compiled) {
1132       miopenFusionOpDescriptor_t conv_op;
1133       auto status = wrap::miopenCreateOpConvForward(
1134           fusion_plan_, &conv_op, conv_descriptor, filter_descriptor);
1135       if (status != miopenStatusSuccess) {
1136         LOG(FATAL) << "call to miopenCreateOpConvForward failed: "
1137                    << ToString(status);
1138       }
1139 
1140       miopenFusionOpDescriptor_t bias_op;
1141       status = wrap::miopenCreateOpBiasForward(fusion_plan_, &bias_op,
1142                                                bias_descriptor);
1143       if (status != miopenStatusSuccess) {
1144         LOG(FATAL) << "call to miopenCreateOpBiasForward failed: "
1145                    << ToString(status);
1146       }
1147 
1148       miopenFusionOpDescriptor_t actv_op;
1149       status = wrap::miopenCreateOpActivationForward(
1150           fusion_plan_, &actv_op,
1151           activation_descriptor.miopen_activation_mode_);
1152       if (status != miopenStatusSuccess) {
1153         LOG(FATAL) << "call to miopenCreateOpActivationForward failed: "
1154                    << ToString(status);
1155       }
1156 
1157       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1158       if (status != miopenStatusSuccess) {
1159         VLOG(2) << "call to miopenCompileFusionPlan (CBA) failed: "
1160                 << ToString(status);
1161 
1162         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1163       } else {
1164         VLOG(2) << "Fusion Plan compile succedded (CBA) ";
1165         fusion_plan_compiled_ = true;
1166       }
1167     } else {
1168       // fusion plan was already compiled...check whether it failed to compile
1169       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1170     }
1171   }
1172 
SetConvolutionArgs(const void * filter_data)1173   miopenStatus_t SetConvolutionArgs(const void* filter_data) {
1174     float alpha = 1.0;
1175     float beta = 0.0;
1176     return ScopedFusionPlanBase::SetConvolutionArgs(k_conv_op_idx, &alpha,
1177                                                     &beta, filter_data);
1178   }
1179 
SetBiasArgs(const void * bias_data)1180   miopenStatus_t SetBiasArgs(const void* bias_data) {
1181     float alpha = 1.0;
1182     float beta = 0.0;
1183     return ScopedFusionPlanBase::SetBiasArgs(k_bias_op_idx, &alpha, &beta,
1184                                              bias_data);
1185   }
1186 
SetActivationForwardArgs(ScopedActivationDescriptor & activation_descriptor)1187   miopenStatus_t SetActivationForwardArgs(
1188       ScopedActivationDescriptor& activation_descriptor) {
1189     float alpha = 1.0;
1190     float beta = 0.0;
1191 
1192     return ScopedFusionPlanBase::SetActivationForwardArgs(
1193         k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_,
1194         activation_descriptor.beta_, activation_descriptor.gamma_);
1195   }
1196 
GetFusionOpHashValue(miopenHandle_t miopen_handle,miopenTensorDescriptor_t input_descriptor,miopenTensorDescriptor_t filter_descriptor,miopenConvolutionDescriptor_t conv_descriptor,miopenTensorDescriptor_t bias_descriptor,ScopedActivationDescriptor & activation_descriptor)1197   uint64 GetFusionOpHashValue(
1198       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1199       miopenTensorDescriptor_t filter_descriptor,
1200       miopenConvolutionDescriptor_t conv_descriptor,
1201       miopenTensorDescriptor_t bias_descriptor,
1202       ScopedActivationDescriptor& activation_descriptor) {
1203     uint64 hash_value = tensorflow::Hash64("ConvolutionBiasActivation");
1204 
1205     hash_value = tensorflow::Hash64Combine(
1206         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1207 
1208     hash_value =
1209         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1210     hash_value =
1211         tensorflow::Hash64Combine(hash_value, GetHashValue(filter_descriptor));
1212     hash_value =
1213         tensorflow::Hash64Combine(hash_value, GetHashValue(conv_descriptor));
1214     hash_value =
1215         tensorflow::Hash64Combine(hash_value, GetHashValue(bias_descriptor));
1216     hash_value = tensorflow::Hash64Combine(
1217         hash_value, activation_descriptor.GetHashValue());
1218     return hash_value;
1219   }
1220 
1221  private:
1222   const int k_conv_op_idx = 0;
1223   const int k_bias_op_idx = 1;
1224   const int k_actv_op_idx = 2;
1225 
1226   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanConvolutionBiasActivation);
1227 };
1228 
1229 // class to represent the BatchNorm+Activation (inference) fusion plan
1230 class ScopedFusionPlanBatchNormActivationInference
1231     : public ScopedFusionPlanBase {
1232  public:
ScopedFusionPlanBatchNormActivationInference(miopenHandle_t miopen_handle,miopenTensorDescriptor_t input_descriptor,miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,ScopedActivationDescriptor & activation_descriptor)1233   ScopedFusionPlanBatchNormActivationInference(
1234       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1235       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1236       ScopedActivationDescriptor& activation_descriptor)
1237       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1238                              input_descriptor) {
1239     uint64 hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1240                                        scale_offset_mean_variance_descriptor,
1241                                        activation_descriptor);
1242 
1243     bool is_compiled = CachedFusionPlans::FindOrCreate(
1244         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1245 
1246     if (!is_compiled) {
1247       miopenFusionOpDescriptor_t batchnorm_op;
1248       auto status = wrap::miopenCreateOpBatchNormInference(
1249           fusion_plan_, &batchnorm_op, miopenBNSpatial,
1250           scale_offset_mean_variance_descriptor);
1251 
1252       if (status != miopenStatusSuccess) {
1253         LOG(FATAL) << "call to miopenCreateOpBatchNormInference failed: "
1254                    << ToString(status);
1255       }
1256 
1257       miopenFusionOpDescriptor_t actv_op;
1258       status = wrap::miopenCreateOpActivationForward(
1259           fusion_plan_, &actv_op,
1260           activation_descriptor.miopen_activation_mode_);
1261       if (status != miopenStatusSuccess) {
1262         LOG(FATAL) << "call to miopenCreateOpActivationForward failed: "
1263                    << ToString(status);
1264       }
1265 
1266       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1267       if (status != miopenStatusSuccess) {
1268         VLOG(2) << "call to miopenCompileFusionPlan (BnA inference) failed: "
1269                 << ToString(status);
1270 
1271         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1272       } else {
1273         VLOG(2) << "Fusion Plan compile succedded (BnA inference) ";
1274         fusion_plan_compiled_ = true;
1275       }
1276     } else {
1277       // fusion plan was already compiled...check whether it failed to compile
1278       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1279     }
1280   }
1281 
SetBatchNormInferenceArgs(const void * scale,const void * offset,const void * mean,const void * variance,double epsilon)1282   miopenStatus_t SetBatchNormInferenceArgs(const void* scale,
1283                                            const void* offset, const void* mean,
1284                                            const void* variance,
1285                                            double epsilon) {
1286     float alpha = 1.0;
1287     float beta = 0.0;
1288     return ScopedFusionPlanBase::SetBatchNormInferenceArgs(
1289         k_batchnorm_op_idx, &alpha, &beta, scale, offset, mean, variance,
1290         epsilon);
1291   }
1292 
SetActivationForwardArgs(ScopedActivationDescriptor & activation_descriptor)1293   miopenStatus_t SetActivationForwardArgs(
1294       ScopedActivationDescriptor& activation_descriptor) {
1295     float alpha = 1.0;
1296     float beta = 0.0;
1297 
1298     return ScopedFusionPlanBase::SetActivationForwardArgs(
1299         k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_,
1300         activation_descriptor.beta_, activation_descriptor.gamma_);
1301   }
1302 
GetFusionOpHashValue(miopenHandle_t miopen_handle,miopenTensorDescriptor_t input_descriptor,miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,ScopedActivationDescriptor & activation_descriptor)1303   uint64 GetFusionOpHashValue(
1304       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1305       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1306       ScopedActivationDescriptor& activation_descriptor) {
1307     uint64 hash_value = tensorflow::Hash64("BatchNormActivationInference");
1308 
1309     hash_value = tensorflow::Hash64Combine(
1310         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1311 
1312     hash_value =
1313         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1314 
1315     hash_value = tensorflow::Hash64Combine(
1316         hash_value, GetHashValue(scale_offset_mean_variance_descriptor));
1317 
1318     hash_value = tensorflow::Hash64Combine(
1319         hash_value, activation_descriptor.GetHashValue());
1320     return hash_value;
1321   }
1322 
1323  private:
1324   const int k_batchnorm_op_idx = 0;
1325   const int k_actv_op_idx = 1;
1326 
1327   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBatchNormActivationInference);
1328 };
1329 
1330 // class to represent the BatchNorm+Activation (training-forward) fusion plan
1331 class ScopedFusionPlanBatchNormActivationForward : public ScopedFusionPlanBase {
1332  public:
ScopedFusionPlanBatchNormActivationForward(miopenHandle_t miopen_handle,miopenTensorDescriptor_t input_descriptor,miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,ScopedActivationDescriptor & activation_descriptor)1333   ScopedFusionPlanBatchNormActivationForward(
1334       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1335       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1336       ScopedActivationDescriptor& activation_descriptor)
1337       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1338                              input_descriptor) {
1339     uint64 hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1340                                        scale_offset_mean_variance_descriptor,
1341                                        activation_descriptor);
1342 
1343     bool is_compiled = CachedFusionPlans::FindOrCreate(
1344         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1345 
1346     if (!is_compiled) {
1347       miopenFusionOpDescriptor_t batchnorm_op;
1348       auto status = wrap::miopenCreateOpBatchNormForward(
1349           fusion_plan_, &batchnorm_op, miopenBNSpatial,
1350           true /* runningMeanVariance */);
1351 
1352       if (status != miopenStatusSuccess) {
1353         LOG(FATAL) << "call to miopenCreateOpBatchNormForward failed: "
1354                    << ToString(status);
1355       }
1356 
1357       miopenFusionOpDescriptor_t actv_op;
1358       status = wrap::miopenCreateOpActivationForward(
1359           fusion_plan_, &actv_op,
1360           activation_descriptor.miopen_activation_mode_);
1361       if (status != miopenStatusSuccess) {
1362         LOG(FATAL) << "call to miopenCreateOpActivationForward failed: "
1363                    << ToString(status);
1364       }
1365 
1366       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1367       if (status != miopenStatusSuccess) {
1368         VLOG(2) << "call to miopenCompileFusionPlan (BnA forward) failed: "
1369                 << ToString(status);
1370 
1371         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1372       } else {
1373         VLOG(2) << "Fusion Plan compile succedded (BnA forward) ";
1374         fusion_plan_compiled_ = true;
1375       }
1376     } else {
1377       // fusion plan was already compiled...check whether it failed to compile
1378       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1379     }
1380   }
1381 
SetBatchNormForwardArgs(const void * scale,const void * offset,void * batch_mean,void * batch_var,void * saved_mean,void * saved_var,double epsilon)1382   miopenStatus_t SetBatchNormForwardArgs(const void* scale, const void* offset,
1383                                          void* batch_mean, void* batch_var,
1384                                          void* saved_mean, void* saved_var,
1385                                          double epsilon) {
1386     float alpha = 1.0;
1387     float beta = 0.0;
1388     return ScopedFusionPlanBase::SetBatchNormForwardArgs(
1389         k_batchnorm_op_idx, &alpha, &beta, scale, offset, batch_mean, batch_var,
1390         saved_mean, saved_var, epsilon);
1391   }
1392 
SetActivationForwardArgs(ScopedActivationDescriptor & activation_descriptor)1393   miopenStatus_t SetActivationForwardArgs(
1394       ScopedActivationDescriptor& activation_descriptor) {
1395     float alpha = 1.0;
1396     float beta = 0.0;
1397 
1398     return ScopedFusionPlanBase::SetActivationForwardArgs(
1399         k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_,
1400         activation_descriptor.beta_, activation_descriptor.gamma_);
1401   }
1402 
GetFusionOpHashValue(miopenHandle_t miopen_handle,miopenTensorDescriptor_t input_descriptor,miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,ScopedActivationDescriptor & activation_descriptor)1403   uint64 GetFusionOpHashValue(
1404       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1405       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1406       ScopedActivationDescriptor& activation_descriptor) {
1407     uint64 hash_value = tensorflow::Hash64("BatchNormActivationForward");
1408 
1409     hash_value = tensorflow::Hash64Combine(
1410         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1411 
1412     hash_value =
1413         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1414 
1415     hash_value = tensorflow::Hash64Combine(
1416         hash_value, GetHashValue(scale_offset_mean_variance_descriptor));
1417 
1418     hash_value = tensorflow::Hash64Combine(
1419         hash_value, activation_descriptor.GetHashValue());
1420     return hash_value;
1421   }
1422 
1423  private:
1424   const int k_batchnorm_op_idx = 0;
1425   const int k_actv_op_idx = 1;
1426 
1427   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBatchNormActivationForward);
1428 };
1429 
1430 // class to represent the BatchNorm+Activation (training-backward) fusion plan
1431 class ScopedFusionPlanBatchNormActivationBackward
1432     : public ScopedFusionPlanBase {
1433  public:
ScopedFusionPlanBatchNormActivationBackward(miopenHandle_t miopen_handle,miopenTensorDescriptor_t input_descriptor,miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,ScopedActivationDescriptor & activation_descriptor)1434   ScopedFusionPlanBatchNormActivationBackward(
1435       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1436       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1437       ScopedActivationDescriptor& activation_descriptor)
1438       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1439                              input_descriptor) {
1440     uint64 hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1441                                        scale_offset_mean_variance_descriptor,
1442                                        activation_descriptor);
1443 
1444     bool is_compiled = CachedFusionPlans::FindOrCreate(
1445         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1446 
1447     if (!is_compiled) {
1448       miopenFusionOpDescriptor_t batchnorm_op;
1449       auto status = wrap::miopenCreateOpBatchNormBackward(
1450           fusion_plan_, &batchnorm_op, miopenBNSpatial);
1451 
1452       if (status != miopenStatusSuccess) {
1453         LOG(FATAL) << "call to miopenCreateOpBatchNormBackward failed: "
1454                    << ToString(status);
1455       }
1456 
1457       miopenFusionOpDescriptor_t actv_op;
1458       status = wrap::miopenCreateOpActivationBackward(
1459           fusion_plan_, &actv_op,
1460           activation_descriptor.miopen_activation_mode_);
1461       if (status != miopenStatusSuccess) {
1462         LOG(FATAL) << "call to miopenCreateOpActivationBackward failed: "
1463                    << ToString(status);
1464       }
1465 
1466       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1467       if (status != miopenStatusSuccess) {
1468         VLOG(2) << "call to miopenCompileFusionPlan (BnA backward) failed: "
1469                 << ToString(status);
1470 
1471         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1472       } else {
1473         VLOG(2) << "Fusion Plan compile succedded (BnA backward) ";
1474         fusion_plan_compiled_ = true;
1475       }
1476     } else {
1477       // fusion plan was already compiled...check whether it failed to compile
1478       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1479     }
1480   }
1481 
SetBatchNormBackwardArgs(const void * x,const void * scale,const void * offset,const void * saved_mean,const void * saved_var,void * scale_grad,void * offset_grad)1482   miopenStatus_t SetBatchNormBackwardArgs(const void* x, const void* scale,
1483                                           const void* offset,
1484                                           const void* saved_mean,
1485                                           const void* saved_var,
1486                                           void* scale_grad, void* offset_grad) {
1487     float alpha = 1.0;
1488     float beta = 0.0;
1489 
1490     return ScopedFusionPlanBase::SetBatchNormBackwardArgs(
1491         k_batchnorm_op_idx, &alpha, &beta, x, scale, offset, scale_grad,
1492         offset_grad, saved_mean, saved_var);
1493   }
1494 
SetActivationBackwardArgs(ScopedActivationDescriptor & activation_descriptor,const void * y)1495   miopenStatus_t SetActivationBackwardArgs(
1496       ScopedActivationDescriptor& activation_descriptor, const void* y) {
1497     float alpha = 1.0;
1498     float beta = 0.0;
1499 
1500     return ScopedFusionPlanBase::SetActivationBackwardArgs(
1501         k_actv_op_idx, &alpha, &beta, y, activation_descriptor.alpha_,
1502         activation_descriptor.beta_, activation_descriptor.gamma_);
1503   }
1504 
GetFusionOpHashValue(miopenHandle_t miopen_handle,miopenTensorDescriptor_t input_descriptor,miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,ScopedActivationDescriptor & activation_descriptor)1505   uint64 GetFusionOpHashValue(
1506       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1507       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1508       ScopedActivationDescriptor& activation_descriptor) {
1509     uint64 hash_value = tensorflow::Hash64("BatchNormActivationBackward");
1510 
1511     hash_value = tensorflow::Hash64Combine(
1512         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1513 
1514     hash_value =
1515         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1516 
1517     hash_value = tensorflow::Hash64Combine(
1518         hash_value, GetHashValue(scale_offset_mean_variance_descriptor));
1519 
1520     hash_value = tensorflow::Hash64Combine(
1521         hash_value, activation_descriptor.GetHashValue());
1522     return hash_value;
1523   }
1524 
1525  private:
1526   const int k_batchnorm_op_idx = 0;
1527   const int k_actv_op_idx = 1;
1528 
1529   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBatchNormActivationBackward);
1530 };
1531 
1532 namespace {
ToMIOpenDataType(dnn::DataType data_type,dnn::DataLayout data_layout=dnn::DataLayout::kBatchDepthYX)1533 miopenDataType_t ToMIOpenDataType(
1534     dnn::DataType data_type,
1535     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
1536   switch (data_type) {
1537     case dnn::DataType::kFloat:
1538       return miopenFloat;
1539     case dnn::DataType::kHalf:
1540       return miopenHalf;
1541     case dnn::DataType::kDouble:
1542     default:
1543       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
1544   }
1545 }
1546 
ToMIOpenDataType(dnn::DataType data_type,dnn::FilterLayout filter_layout)1547 miopenDataType_t ToMIOpenDataType(dnn::DataType data_type,
1548                                   dnn::FilterLayout filter_layout) {
1549   return ToMIOpenDataType(data_type);
1550 }
1551 
ToMIOpenRnnInputMode(dnn::RnnInputMode input_mode)1552 miopenRNNInputMode_t ToMIOpenRnnInputMode(dnn::RnnInputMode input_mode) {
1553   switch (input_mode) {
1554     case dnn::RnnInputMode::kRnnLinearSkip:
1555       return miopenRNNlinear;
1556     case dnn::RnnInputMode::kRnnSkipInput:
1557       return miopenRNNskip;
1558     default:
1559       LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
1560   }
1561 }
1562 
ToMIOpenRnnDirectionMode(dnn::RnnDirectionMode direction_mode)1563 miopenRNNDirectionMode_t ToMIOpenRnnDirectionMode(
1564     dnn::RnnDirectionMode direction_mode) {
1565   switch (direction_mode) {
1566     case dnn::RnnDirectionMode::kRnnUnidirectional:
1567       return miopenRNNunidirection;
1568     case dnn::RnnDirectionMode::kRnnBidirectional:
1569       return miopenRNNbidirection;
1570     default:
1571       LOG(FATAL) << "Invalid RNN direction mode: "
1572                  << static_cast<int>(direction_mode);
1573   }
1574 }
1575 
ToMIOpenRnnMode(dnn::RnnMode rnn_mode)1576 miopenRNNMode_t ToMIOpenRnnMode(dnn::RnnMode rnn_mode) {
1577   switch (rnn_mode) {
1578     case dnn::RnnMode::kRnnRelu:
1579       return miopenRNNRELU;
1580     case dnn::RnnMode::kRnnTanh:
1581       return miopenRNNTANH;
1582     case dnn::RnnMode::kRnnLstm:
1583       return miopenLSTM;
1584     case dnn::RnnMode::kRnnGru:
1585       return miopenGRU;
1586     default:
1587       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1588   }
1589 }
1590 
MIOpenDataTypeToByteSize(miopenDataType_t data_type)1591 int MIOpenDataTypeToByteSize(miopenDataType_t data_type) {
1592   switch (data_type) {
1593     case miopenFloat:
1594       return sizeof(float);
1595     case miopenHalf:
1596       return sizeof(Eigen::half);
1597     default:
1598       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
1599   }
1600 }
1601 
// MixinBase<Base> simply derives publicly from Base. The void specialization
// is an empty class, so MIOpenDescriptorCommon<void> (used below by
// MIOpenRnnParamsDescriptor) can be instantiated without any base class.
template <typename Base>
class MixinBase : public Base {};
template <>
class MixinBase<void> {};
1606 
GetConvAccumulatorType(dnn::DataType data_type)1607 dnn::DataType GetConvAccumulatorType(dnn::DataType data_type) {
1608   switch (data_type) {
1609     case dnn::DataType::kFloat:
1610     case dnn::DataType::kDouble:
1611       return data_type;
1612     case dnn::DataType::kHalf:
1613       // FIXME: Check if MIOpen can switch dynamically change accumulator type
1614       return dnn::DataType::kFloat;
1615     case dnn::DataType::kInt8:
1616     case dnn::DataType::kInt32:
1617       return dnn::DataType::kInt32;
1618     default:
1619       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
1620   }
1621 }
1622 
1623 }  // namespace
1624 
// If STATUS is not miopenStatusSuccess, builds a message from the status and
// the extra arguments, records it through SetFailure(), logs it and returns
// from the enclosing (void) function. Only usable inside members of classes
// that provide SetFailure(const port::Status&).
// STATUS is evaluated more than once, so pass a plain variable.
// Wrapped in do { } while (false) so a use followed by ';' is exactly one
// statement (safe inside unbraced if/else).
#define RETURN_IF_MIOPEN_ERROR(STATUS, ...)                                \
  do {                                                                     \
    if (!SE_PREDICT_TRUE((STATUS) == miopenStatusSuccess)) {               \
      string error_msg = absl::StrCat(ToString(STATUS), " ", __VA_ARGS__); \
      SetFailure(port::Status(port::error::UNKNOWN, error_msg));           \
      LOG(ERROR) << error_msg;                                             \
      return;                                                              \
    }                                                                      \
  } while (false)
1632 
1633 template <typename Base>
1634 class MIOpenDescriptorCommon : public MixinBase<Base> {
1635  public:
ok() const1636   bool ok() const { return status_.ok(); }
Status() const1637   port::Status Status() const { return status_; }
1638 
1639  protected:
SetFailure(const port::Status & status)1640   void SetFailure(const port::Status& status) { status_.Update(status); }
1641   port::Status status_;
1642 };
1643 
1644 class MIOpenRnnParamsDescriptor : public MIOpenDescriptorCommon<void> {
1645  public:
1646   typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion;
1647   typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1648   MIOpenRnnParamsDescriptor(miopenHandle_t miopen_handle,
1649                             const MIOpenRnnDescriptor& rnn_desc);
~MIOpenRnnParamsDescriptor()1650   ~MIOpenRnnParamsDescriptor() {
1651     auto status = wrap::miopenDestroyTensorDescriptor(handle_);
1652     RETURN_IF_MIOPEN_ERROR(status, "Failed to destroy RNN tensor descriptor");
1653   }
handle() const1654   miopenTensorDescriptor_t handle() const {
1655     if (!ok()) return nullptr;
1656     return handle_;
1657   }
params_size_in_bytes() const1658   int64 params_size_in_bytes() const { return params_size_in_bytes_; }
params_weights() const1659   ParamsRegions params_weights() const {
1660     if (!ok()) return ParamsRegions();
1661     return weights_;
1662   }
params_biases() const1663   ParamsRegions params_biases() const {
1664     if (!ok()) return ParamsRegions();
1665     return biases_;
1666   }
1667 
1668  private:
1669   int GetRegionCountPerLayer() const;
1670   miopenTensorDescriptor_t handle_;
1671   const MIOpenRnnDescriptor* rnn_desc_;
1672   int64 params_size_in_bytes_;
1673   ParamsRegions weights_;
1674   ParamsRegions biases_;
1675   port::Status status_;
1676   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnParamsDescriptor);
1677 };
1678 
1679 class MIOpenRnnDescriptor : public MIOpenDescriptorCommon<dnn::RnnDescriptor> {
1680  public:
MIOpenRnnDescriptor(miopenHandle_t miopen_handle,int num_layers,int hidden_size,int input_size,miopenRNNInputMode_t input_mode,miopenRNNDirectionMode_t direction_mode,miopenRNNMode_t rnn_mode,miopenDataType_t data_type,float dropout,uint64 seed,ScratchAllocator * state_allocator)1681   MIOpenRnnDescriptor(miopenHandle_t miopen_handle, int num_layers,
1682                       int hidden_size, int input_size,
1683                       miopenRNNInputMode_t input_mode,
1684                       miopenRNNDirectionMode_t direction_mode,
1685                       miopenRNNMode_t rnn_mode, miopenDataType_t data_type,
1686                       float dropout, uint64 seed,
1687                       ScratchAllocator* state_allocator)
1688       : rnn_desc_(nullptr),
1689         num_layers_(num_layers),
1690         hidden_size_(hidden_size),
1691         input_size_(input_size),
1692         input_mode_(input_mode),
1693         direction_mode_(direction_mode),
1694         rnn_mode_(rnn_mode),
1695         data_type_(data_type) {
1696     // Create the RNN handle
1697     auto status = wrap::miopenCreateRNNDescriptor(&rnn_desc_);
1698     RETURN_IF_MIOPEN_ERROR(status, "Unable to create RNN descriptor");
1699     status = wrap::miopenSetRNNDescriptor(
1700         rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
1701         num_layers /*numLayers*/, input_mode /*inputMode*/,
1702         direction_mode /*direction*/, rnn_mode /*mode*/,
1703         miopenRNNwithBias /*biasMode*/, miopenRNNdefault /*algo*/,
1704         data_type /*dataType*/);
1705     RETURN_IF_MIOPEN_ERROR(status, "Unable to update RNN descriptor");
1706     // Create the params handle.
1707     miopen_params_desc_.reset(
1708         new MIOpenRnnParamsDescriptor(miopen_handle, *this));
1709     if (!miopen_params_desc_->ok()) {
1710       SetFailure(miopen_params_desc_->Status());
1711       return;
1712     }
1713   }
~MIOpenRnnDescriptor()1714   ~MIOpenRnnDescriptor() override {
1715     if (rnn_desc_) {
1716       auto status = wrap::miopenDestroyRNNDescriptor(rnn_desc_);
1717       RETURN_IF_MIOPEN_ERROR(status, "Unable to destroy RNN descriptor");
1718     }
1719   }
handle() const1720   miopenRNNDescriptor_t handle() const {
1721     if (!ok()) return nullptr;
1722     return rnn_desc_;
1723   }
num_layers() const1724   int num_layers() const { return num_layers_; }
hidden_size() const1725   int hidden_size() const { return hidden_size_; }
input_size() const1726   int input_size() const { return input_size_; }
input_mode() const1727   miopenRNNInputMode_t input_mode() const { return input_mode_; }
direction_mode() const1728   miopenRNNDirectionMode_t direction_mode() const { return direction_mode_; }
rnn_mode() const1729   miopenRNNMode_t rnn_mode() const { return rnn_mode_; }
data_type() const1730   miopenDataType_t data_type() const { return data_type_; }
ParamsSizeInBytes() const1731   int64 ParamsSizeInBytes() const override {
1732     return miopen_params_desc_->params_size_in_bytes();
1733   }
params_handle() const1734   miopenTensorDescriptor_t params_handle() const {
1735     if (!miopen_params_desc_) return nullptr;
1736     return miopen_params_desc_->handle();
1737   }
ParamsWeightRegions() const1738   ParamsRegions ParamsWeightRegions() const override {
1739     if (!ok()) return ParamsRegions();
1740     return miopen_params_desc_->params_weights();
1741   }
ParamsBiasRegions() const1742   ParamsRegions ParamsBiasRegions() const override {
1743     if (!ok()) return ParamsRegions();
1744     return miopen_params_desc_->params_biases();
1745   }
1746 
1747  private:
1748   miopenRNNDescriptor_t rnn_desc_;
1749   int num_layers_;
1750   int hidden_size_;
1751   int input_size_;
1752   miopenRNNInputMode_t input_mode_;
1753   miopenRNNDirectionMode_t direction_mode_;
1754   miopenRNNMode_t rnn_mode_;
1755   miopenDataType_t data_type_;
1756   port::Status status_;
1757   // no dropout in MIOpen.
1758   // std::unique_ptr<miopenDropoutDescriptor> miopen_dropout_desc_;
1759   std::unique_ptr<MIOpenRnnParamsDescriptor> miopen_params_desc_;
1760   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnDescriptor);
1761 };
1762 
1763 // Get ID of the internal parameter tensor.
1764 //
GetRegionCountPerLayer() const1765 int MIOpenRnnParamsDescriptor::GetRegionCountPerLayer() const {
1766   auto rnn_mode = rnn_desc_->rnn_mode();
1767   switch (rnn_mode) {
1768     case miopenRNNRELU:
1769     case miopenRNNTANH:
1770       return 2;
1771     case miopenLSTM:
1772       return 8;
1773     case miopenGRU:
1774       return 6;
1775     default:
1776       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1777   }
1778 }
1779 
1780 class MIOpenRnnSequenceTensorDescriptor
1781     : public MIOpenDescriptorCommon<dnn::RnnSequenceTensorDescriptor> {
1782  public:
MIOpenRnnSequenceTensorDescriptor(int seq_length,int batch_size,int data_size,miopenDataType_t data_type)1783   MIOpenRnnSequenceTensorDescriptor(int seq_length, int batch_size,
1784                                     int data_size, miopenDataType_t data_type)
1785       : seq_length_(seq_length),
1786         batch_size_(batch_size),
1787         data_size_(data_size),
1788         data_type_(data_type) {
1789     miopenTensorDescriptor_t handle = nullptr;
1790     if (seq_length <= 0) {
1791       string error_msg =
1792           absl::StrCat("sequence length must be positive: ", seq_length);
1793       LOG(ERROR) << error_msg;
1794       SetFailure(port::Status(port::error::UNKNOWN, error_msg));
1795       return;
1796     }
1797     auto status = wrap::miopenCreateTensorDescriptor(&handle);
1798     RETURN_IF_MIOPEN_ERROR(status, "Failed to create tensor descriptor");
1799     std::array<int, 2> dims = {{batch_size, data_size}};
1800     status = wrap::miopenSetTensorDescriptor(
1801         handle /*tensorDesc*/, data_type /*dataType*/, 2 /*nbDims*/,
1802         dims.data() /*dimA*/, nullptr /*strideA*/);
1803     RETURN_IF_MIOPEN_ERROR(status, "Failed to update tensor descriptor");
1804     // Replicate handle across the number of steps.
1805     handles_.assign(seq_length, handle);
1806   }
1807 
~MIOpenRnnSequenceTensorDescriptor()1808   ~MIOpenRnnSequenceTensorDescriptor() override {
1809     // Only the first one needs to be destroyed. All others are the same.
1810     auto status = wrap::miopenDestroyTensorDescriptor(handles_[0]);
1811     RETURN_IF_MIOPEN_ERROR(status,
1812                            "Failed to destroy sequence tensor descriptor");
1813   }
1814 
handles() const1815   const miopenTensorDescriptor_t* handles() const {
1816     if (!ok()) return nullptr;
1817     CHECK(!handles_.empty()) << "handles cannot be empty";
1818     return handles_.data();
1819   }
1820 
seq_length() const1821   int seq_length() const { return seq_length_; }
batch_size() const1822   int batch_size() const { return batch_size_; }
data_size() const1823   int data_size() const { return data_size_; }
1824 
1825  private:
1826   int seq_length_;
1827   int batch_size_;
1828   int data_size_;
1829   miopenDataType_t data_type_;
1830   std::vector<miopenTensorDescriptor_t> handles_;
1831   port::Status status_;
1832   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnSequenceTensorDescriptor);
1833 };
1834 
1835 class MIOpenRnnStateTensorDescriptor
1836     : public MIOpenDescriptorCommon<dnn::RnnStateTensorDescriptor> {
1837  public:
MIOpenRnnStateTensorDescriptor(int num_layers,int batch_size,int data_size,miopenDataType_t data_type)1838   MIOpenRnnStateTensorDescriptor(int num_layers, int batch_size, int data_size,
1839                                  miopenDataType_t data_type)
1840       : handle_(nullptr),
1841         num_layers_(num_layers),
1842         batch_size_(batch_size),
1843         data_size_(data_size),
1844         data_type_(data_type) {
1845     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
1846     RETURN_IF_MIOPEN_ERROR(status, "Failed to create tensor descriptor");
1847     std::array<int, 3> dims = {{num_layers, batch_size, data_size}};
1848     status = wrap::miopenSetTensorDescriptor(
1849         handle_ /*tensorDesc*/, data_type /*dataType*/, 3 /*nbDims*/,
1850         dims.data() /*dimA*/, nullptr /*strideA*/);
1851     RETURN_IF_MIOPEN_ERROR(status, "Failed to update tensor descriptor");
1852   }
1853 
~MIOpenRnnStateTensorDescriptor()1854   ~MIOpenRnnStateTensorDescriptor() override {
1855     if (!handle_) {
1856       auto status = wrap::miopenDestroyTensorDescriptor(handle_);
1857       RETURN_IF_MIOPEN_ERROR(status, "Unable to destroy RNN state tensor");
1858     }
1859   }
1860 
handle() const1861   miopenTensorDescriptor_t handle() const {
1862     if (!ok()) return nullptr;
1863     return handle_;
1864   }
num_layers() const1865   int num_layers() const { return num_layers_; }
batch_size() const1866   int batch_size() const { return batch_size_; }
data_size() const1867   int data_size() const { return data_size_; }
1868 
1869  private:
1870   miopenTensorDescriptor_t handle_;
1871   int num_layers_;
1872   int batch_size_;
1873   int data_size_;
1874   port::Status status_;
1875   miopenDataType_t data_type_;
1876   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnStateTensorDescriptor);
1877 };
1878 
1879 namespace {
1880 
// Plain bundle of RNN model dimensions, filled in by
// ExtractAndCheckRnnForward. Every field starts out as zero.
struct RnnModelDims {
  int num_layers{0};
  int batch_size{0};
  int seq_length{0};
  int hidden_size{0};
  int input_size{0};
  // Number of directions: 2 for bidirectional RNNs, 1 otherwise.
  int dir_count{0};
};
1889 
1890 template <class T>
ExtractAndCheckRnnForward(const MIOpenRnnDescriptor & rnn_desc,const MIOpenRnnSequenceTensorDescriptor & input_desc,const DeviceMemory<T> & input_data,const MIOpenRnnStateTensorDescriptor & input_h_desc,const DeviceMemory<T> & input_h_data,const MIOpenRnnStateTensorDescriptor & input_c_desc,const DeviceMemory<T> & input_c_data,const DeviceMemory<T> & params,const MIOpenRnnSequenceTensorDescriptor & output_desc,const DeviceMemory<T> & output_data,const MIOpenRnnStateTensorDescriptor & output_h_desc,const DeviceMemory<T> & output_h_data,const MIOpenRnnStateTensorDescriptor & output_c_desc,const DeviceMemory<T> & output_c_data,RnnModelDims * model_dims)1891 bool ExtractAndCheckRnnForward(
1892     const MIOpenRnnDescriptor& rnn_desc,
1893     const MIOpenRnnSequenceTensorDescriptor& input_desc,
1894     const DeviceMemory<T>& input_data,
1895     const MIOpenRnnStateTensorDescriptor& input_h_desc,
1896     const DeviceMemory<T>& input_h_data,
1897     const MIOpenRnnStateTensorDescriptor& input_c_desc,
1898     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
1899     const MIOpenRnnSequenceTensorDescriptor& output_desc,
1900     const DeviceMemory<T>& output_data,
1901     const MIOpenRnnStateTensorDescriptor& output_h_desc,
1902     const DeviceMemory<T>& output_h_data,
1903     const MIOpenRnnStateTensorDescriptor& output_c_desc,
1904     const DeviceMemory<T>& output_c_data, RnnModelDims* model_dims) {
1905   // extract model parameters
1906   model_dims->num_layers = rnn_desc.num_layers();
1907   model_dims->batch_size = input_desc.batch_size();
1908   model_dims->seq_length = input_desc.seq_length();
1909   model_dims->hidden_size = rnn_desc.hidden_size();
1910   model_dims->input_size = input_desc.data_size();
1911   model_dims->dir_count =
1912       (rnn_desc.direction_mode() == miopenRNNbidirection) ? 2 : 1;
1913 
1914   // check parameters
1915   if (!(input_h_desc.num_layers() ==
1916             model_dims->num_layers * model_dims->dir_count &&
1917         input_h_desc.batch_size() == model_dims->batch_size &&
1918         input_h_desc.data_size() == model_dims->hidden_size)) {
1919     LOG(ERROR) << "Invalid input_h shape";
1920     return false;
1921   }
1922   if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
1923         input_h_desc.batch_size() == input_c_desc.batch_size() &&
1924         input_h_desc.data_size() == input_c_desc.data_size())) {
1925     LOG(ERROR) << "Invalid input_c shape";
1926     return false;
1927   }
1928   if (!(output_desc.seq_length() == model_dims->seq_length &&
1929         output_desc.batch_size() == model_dims->batch_size &&
1930         output_desc.data_size() ==
1931             model_dims->hidden_size * model_dims->dir_count)) {
1932     LOG(ERROR) << "Invalid output shape";
1933     return false;
1934   }
1935   if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
1936         input_h_desc.batch_size() == output_h_desc.batch_size() &&
1937         input_h_desc.data_size() == output_h_desc.data_size())) {
1938     LOG(ERROR) << "Invalid output_h shape";
1939     return false;
1940   }
1941   if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
1942         input_h_desc.batch_size() == output_c_desc.batch_size() &&
1943         input_h_desc.data_size() == output_c_desc.data_size())) {
1944     LOG(ERROR) << "Invalid output_h shape";
1945     return false;
1946   }
1947 
1948   return true;
1949 }
1950 
CheckRNNParameterSize(miopenHandle_t miopen_handle,const MIOpenRnnDescriptor & rnn_desc,const MIOpenRnnSequenceTensorDescriptor & input_desc)1951 bool CheckRNNParameterSize(
1952     miopenHandle_t miopen_handle, const MIOpenRnnDescriptor& rnn_desc,
1953     const MIOpenRnnSequenceTensorDescriptor& input_desc) {
1954   size_t params_size_in_bytes = 0;
1955   auto status = wrap::miopenGetRNNParamsSize(
1956       miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
1957       input_desc.handles()[0] /*xDesc*/, &params_size_in_bytes /*sizeInBytes*/,
1958       rnn_desc.data_type() /*dataType*/);
1959   if (status != miopenStatusSuccess) {
1960     LOG(ERROR) << "Unable to check RNN param size: " << ToString(status);
1961     return false;
1962   }
1963   return static_cast<int64>(params_size_in_bytes) ==
1964          rnn_desc.ParamsSizeInBytes();
1965 }
1966 
CreateRnnWorkspace(Stream * stream,miopenHandle_t miopen_handle,const MIOpenRnnDescriptor & rnn_desc,const MIOpenRnnSequenceTensorDescriptor & input_desc,ScratchAllocator * workspace_allocator,DeviceMemory<uint8> * workspace)1967 bool CreateRnnWorkspace(Stream* stream, miopenHandle_t miopen_handle,
1968                         const MIOpenRnnDescriptor& rnn_desc,
1969                         const MIOpenRnnSequenceTensorDescriptor& input_desc,
1970                         ScratchAllocator* workspace_allocator,
1971                         DeviceMemory<uint8>* workspace) {
1972   // Query the workspace size.
1973   size_t workspace_size_in_bytes = 0;
1974   auto status = wrap::miopenGetRNNWorkspaceSize(
1975       miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
1976       input_desc.seq_length() /*seqLength*/, input_desc.handles() /*xDesc*/,
1977       &workspace_size_in_bytes /*sizeInBytes*/);
1978   if (status != miopenStatusSuccess) {
1979     LOG(ERROR) << "Unable to query workspace size: " << ToString(status);
1980     return false;
1981   }
1982   // Allocate the workspace.
1983   if (workspace_size_in_bytes > 0) {
1984     auto allocated =
1985         workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
1986     if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) {
1987       LOG(ERROR) << "Failed to allocate RNN workspace";
1988 
1989       return false;
1990     }
1991     stream->ThenMemZero(workspace, workspace_size_in_bytes);
1992   } else {
1993     *workspace = DeviceMemory<uint8>();
1994   }
1995   return true;
1996 }
1997 
1998 }  // namespace
1999 
2000 template <class T>
DoRnnForwardImpl(Stream * stream,const MIOpenRnnDescriptor & rnn_desc,const MIOpenRnnSequenceTensorDescriptor & input_desc,const DeviceMemory<T> & input_data,const MIOpenRnnStateTensorDescriptor & input_h_desc,const DeviceMemory<T> & input_h_data,const MIOpenRnnStateTensorDescriptor & input_c_desc,const DeviceMemory<T> & input_c_data,const DeviceMemory<T> & params,const MIOpenRnnSequenceTensorDescriptor & output_desc,DeviceMemory<T> * output_data,const MIOpenRnnStateTensorDescriptor & output_h_desc,DeviceMemory<T> * output_h_data,const MIOpenRnnStateTensorDescriptor & output_c_desc,DeviceMemory<T> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)2001 bool MIOpenSupport::DoRnnForwardImpl(
2002     Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
2003     const MIOpenRnnSequenceTensorDescriptor& input_desc,
2004     const DeviceMemory<T>& input_data,
2005     const MIOpenRnnStateTensorDescriptor& input_h_desc,
2006     const DeviceMemory<T>& input_h_data,
2007     const MIOpenRnnStateTensorDescriptor& input_c_desc,
2008     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2009     const MIOpenRnnSequenceTensorDescriptor& output_desc,
2010     DeviceMemory<T>* output_data,
2011     const MIOpenRnnStateTensorDescriptor& output_h_desc,
2012     DeviceMemory<T>* output_h_data,
2013     const MIOpenRnnStateTensorDescriptor& output_c_desc,
2014     DeviceMemory<T>* output_c_data, bool is_training,
2015     ScratchAllocator* reserve_space_allocator,
2016     ScratchAllocator* workspace_allocator) {
2017   // extract model parameters
2018   RnnModelDims model_dims;
2019   bool res = ExtractAndCheckRnnForward(
2020       rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
2021       input_c_desc, input_c_data, params, output_desc, *output_data,
2022       output_h_desc, *output_h_data, output_c_desc, *output_c_data,
2023       &model_dims);
2024   if (!res) {
2025     LOG(ERROR) << "Invalid parameters for RNN Model";
2026     return false;
2027   }
2028 
2029   auto miopen = miopen_->GetHandle(parent_, stream);
2030 
2031   // check params size
2032 
2033   if (!CheckRNNParameterSize(miopen.handle(), rnn_desc, input_desc)) {
2034     LOG(ERROR) << "Invalid parameters";
2035     return false;
2036   }
2037 
2038   // create the workspace
2039   DeviceMemory<uint8> workspace;
2040   if (!CreateRnnWorkspace(stream, miopen.handle(), rnn_desc, input_desc,
2041                           workspace_allocator, &workspace)) {
2042     LOG(ERROR) << "Unable to create rnn workspace";
2043 
2044     return false;
2045   }
2046 
2047   // query the reserve space size
2048   // allocate the reserve space
2049   DeviceMemory<uint8> reserve_space;
2050   if (is_training) {
2051     size_t reserve_space_size_in_bytes = 0;
2052     auto status = wrap::miopenGetRNNTrainingReserveSize(
2053         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2054         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2055         &reserve_space_size_in_bytes /*sizeInBytes*/);
2056     if (status != miopenStatusSuccess) {
2057       LOG(ERROR) << "Unable to query reserve space size: " << ToString(status);
2058       return false;
2059     }
2060 
2061     if (reserve_space_size_in_bytes > 0) {
2062       auto allocated = reserve_space_allocator->AllocateBytes(
2063           stream, reserve_space_size_in_bytes);
2064       if (!allocated.ok() ||
2065           (reserve_space = allocated.ValueOrDie()) == nullptr) {
2066         LOG(ERROR) << "Fail to allocate RNN reserve space";
2067         return false;
2068       }
2069       stream->ThenMemZero(&reserve_space, reserve_space_size_in_bytes);
2070     }
2071   }
2072 
2073   // make the forward call
2074   if (!is_training) {
2075     auto status = wrap::miopenRNNForwardInference(
2076         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2077         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2078         input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
2079         input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/,
2080         input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/,
2081         params.opaque() /*w*/, output_desc.handles() /*yDesc*/,
2082         output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/,
2083         output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/,
2084         output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/,
2085         workspace.size() /*workSpaceSizeInBytes*/);
2086 
2087     if (status != miopenStatusSuccess) {
2088       LOG(ERROR) << "Failed to call miopenRNNForwardInference: "
2089                  << ToString(status);
2090       return false;
2091     }
2092   } else {
2093     auto status = wrap::miopenRNNForwardTraining(
2094         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2095         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2096         input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
2097         input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/,
2098         input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/,
2099         params.opaque() /*w*/, output_desc.handles() /*yDesc*/,
2100         output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/,
2101         output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/,
2102         output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/,
2103         workspace.size() /*workSpaceSizeInBytes*/,
2104         reserve_space.opaque() /*reserveSpace*/,
2105         reserve_space.size() /*reserveSpaceSizeInBytes*/);
2106     if (status != miopenStatusSuccess) {
2107       LOG(ERROR) << "Failed to call miopenRNNForwardTraining"
2108                  << ToString(status);
2109       return false;
2110     }
2111   }
2112   return true;
2113 }
2114 
2115 template <class T>
DoRnnBackwardImpl(Stream * stream,const MIOpenRnnDescriptor & rnn_desc,const MIOpenRnnSequenceTensorDescriptor & input_desc,const DeviceMemory<T> & input_data,const MIOpenRnnStateTensorDescriptor & input_h_desc,const DeviceMemory<T> & input_h_data,const MIOpenRnnStateTensorDescriptor & input_c_desc,const DeviceMemory<T> & input_c_data,const DeviceMemory<T> & params,const MIOpenRnnSequenceTensorDescriptor & output_desc,const DeviceMemory<T> & output_data,const MIOpenRnnStateTensorDescriptor & output_h_desc,const DeviceMemory<T> & output_h_data,const MIOpenRnnStateTensorDescriptor & output_c_desc,const DeviceMemory<T> & output_c_data,const DeviceMemory<T> & output_backprop_data,const DeviceMemory<T> & output_h_backprop_data,const DeviceMemory<T> & output_c_backprop_data,DeviceMemory<T> * input_backprop_data,DeviceMemory<T> * input_h_backprop_data,DeviceMemory<T> * input_c_backprop_data,DeviceMemory<T> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)2116 bool MIOpenSupport::DoRnnBackwardImpl(
2117     Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
2118     const MIOpenRnnSequenceTensorDescriptor& input_desc,
2119     const DeviceMemory<T>& input_data,
2120     const MIOpenRnnStateTensorDescriptor& input_h_desc,
2121     const DeviceMemory<T>& input_h_data,
2122     const MIOpenRnnStateTensorDescriptor& input_c_desc,
2123     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2124     const MIOpenRnnSequenceTensorDescriptor& output_desc,
2125     const DeviceMemory<T>& output_data,
2126     const MIOpenRnnStateTensorDescriptor& output_h_desc,
2127     const DeviceMemory<T>& output_h_data,
2128     const MIOpenRnnStateTensorDescriptor& output_c_desc,
2129     const DeviceMemory<T>& output_c_data,
2130     const DeviceMemory<T>& output_backprop_data,
2131     const DeviceMemory<T>& output_h_backprop_data,
2132     const DeviceMemory<T>& output_c_backprop_data,
2133     DeviceMemory<T>* input_backprop_data,
2134     DeviceMemory<T>* input_h_backprop_data,
2135     DeviceMemory<T>* input_c_backprop_data,
2136     DeviceMemory<T>* params_backprop_data,
2137     DeviceMemory<uint8>* reserve_space_data,
2138     ScratchAllocator* workspace_allocator) {
2139   // extract model parameters
2140   RnnModelDims model_dims;
2141   bool res = ExtractAndCheckRnnForward(
2142       rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
2143       input_c_desc, input_c_data, params, output_desc, output_data,
2144       output_h_desc, output_h_data, output_c_desc, output_c_data, &model_dims);
2145   if (!res) {
2146     LOG(ERROR) << "Invalid parameters for RNN Model";
2147     return false;
2148   }
2149 
2150   auto miopen = miopen_->GetHandle(parent_, stream);
2151 
2152   // check params size
2153 
2154   if (!CheckRNNParameterSize(miopen.handle(), rnn_desc, input_desc)) {
2155     LOG(ERROR) << "Invalid parameters";
2156     return false;
2157   }
2158 
2159   // create the workspace
2160   DeviceMemory<uint8> workspace;
2161   if (!CreateRnnWorkspace(stream, miopen.handle(), rnn_desc, input_desc,
2162                           workspace_allocator, &workspace)) {
2163     LOG(ERROR) << "Unable to create rnn workspace";
2164     return false;
2165   }
2166 
2167   // workaround for missing initialization support in MIOpen.
2168   // TODO: remove this when MIOpen is ready.
2169   auto size_data = input_desc.seq_length() * input_desc.batch_size() *
2170                    input_desc.data_size();
2171   if ((size_data > 0) && (input_backprop_data->opaque() != nullptr))
2172     stream->ThenMemZero(input_backprop_data, size_data * sizeof(float));
2173 
2174   size_data = input_h_desc.num_layers() * input_h_desc.batch_size() *
2175               input_h_desc.data_size();
2176   if ((size_data > 0) && (input_h_backprop_data->opaque() != nullptr))
2177     stream->ThenMemZero(input_h_backprop_data, size_data * sizeof(float));
2178 
2179   size_data = input_c_desc.num_layers() * input_c_desc.batch_size() *
2180               input_c_desc.data_size();
2181   if ((size_data > 0) && (input_c_backprop_data->opaque() != nullptr))
2182     stream->ThenMemZero(input_c_backprop_data, size_data * sizeof(float));
2183 
2184   // make the backward data call
2185   auto status = wrap::miopenRNNBackwardData(
2186       miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2187       model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/,
2188       output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/,
2189       output_backprop_data.opaque() /*dy*/, output_h_desc.handle() /*dhyDesc*/,
2190       output_h_backprop_data.opaque() /*dhy*/,
2191       output_c_desc.handle() /*dcyDesc*/,
2192       output_c_backprop_data.opaque() /*dcy*/,
2193       rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
2194       input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
2195       input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
2196       input_desc.handles() /*dxDesc*/, input_backprop_data->opaque() /*dx*/,
2197       input_h_desc.handle() /*dhxDesc*/,
2198       input_h_backprop_data->opaque() /*dhx*/,
2199       input_c_desc.handle() /*dcxDesc*/,
2200       input_c_backprop_data->opaque() /*dcx*/, workspace.opaque() /*workspace*/,
2201       workspace.size() /*workSpaceSizeInBytes*/,
2202       reserve_space_data->opaque() /*reserveSpace*/,
2203       reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
2204   if (status != miopenStatusSuccess) {
2205     LOG(ERROR) << "Failed to call miopenRNNBackwardData: " << ToString(status);
2206     return false;
2207   }
2208 
2209   if (params_backprop_data != nullptr) {
2210     // Clear the dw to zeros.
2211     stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
2212     // make the backward weight call
2213     status = wrap::miopenRNNBackwardWeights(
2214         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2215         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2216         input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
2217         input_h_data.opaque() /*hx*/, output_desc.handles() /*yDesc*/,
2218         output_data.opaque() /*y*/, rnn_desc.params_handle() /*dwDesc*/,
2219         params_backprop_data->opaque() /*dw*/, workspace.opaque() /*workspace*/,
2220         workspace.size() /*workSpaceSizeInBytes*/,
2221         reserve_space_data->opaque() /*reserveSpace*/,
2222         reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
2223     if (status != miopenStatusSuccess) {
2224       LOG(ERROR) << "Failed to call miopenRNNBackwardWeights: "
2225                  << ToString(status);
2226       return false;
2227     }
2228   }
2229 
2230   return true;
2231 }
2232 
MIOpenRnnParamsDescriptor(miopenHandle_t miopen_handle,const MIOpenRnnDescriptor & rnn_desc)2233 MIOpenRnnParamsDescriptor::MIOpenRnnParamsDescriptor(
2234     miopenHandle_t miopen_handle, const MIOpenRnnDescriptor& rnn_desc)
2235     : handle_(nullptr), rnn_desc_(&rnn_desc), params_size_in_bytes_(0) {
2236   miopenTensorDescriptor_t input_desc = nullptr;
2237   {
2238     // Query the params size.
2239     auto status = wrap::miopenCreateTensorDescriptor(&input_desc);
2240     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to create tensor descriptor");
2241     std::array<int, 2> dims = {{1, rnn_desc.input_size()}};
2242     status = wrap::miopenSetTensorDescriptor(
2243         input_desc /*tensorDesc*/, rnn_desc.data_type() /*dataType*/,
2244         2 /*nbDims*/, dims.data() /*dimA*/, nullptr /*strideA*/);
2245     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to set tensor descriptor");
2246 
2247     size_t params_size = 0;
2248     status = wrap::miopenGetRNNParamsSize(
2249         miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2250         input_desc /*xDesc*/, &params_size /*sizeInBytes*/,
2251         rnn_desc.data_type() /*dataType*/);
2252     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to get RNN parameter size");
2253     params_size_in_bytes_ = static_cast<int64>(params_size);
2254   }
2255 
2256   {
2257     // Create the params descriptor.
2258     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
2259     RETURN_IF_MIOPEN_ERROR(status,
2260                            "MIOpen fails to create RNN params descriptor");
2261     status = wrap::miopenGetRNNParamsDescriptor(miopen_handle,
2262                                                 rnn_desc.handle(), input_desc,
2263                                                 handle_, rnn_desc.data_type());
2264     RETURN_IF_MIOPEN_ERROR(status,
2265                            "MIOpen fails to update RNN filter descriptor");
2266   }
2267   {
2268     // Release the dummy input tensor descriptor.
2269     auto status = wrap::miopenDestroyTensorDescriptor(input_desc);
2270     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to destroy tensor descriptor");
2271   }
2272 }
2273 
2274 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
createRnnDescriptor(int num_layers,int hidden_size,int input_size,int batch_size,dnn::RnnInputMode input_mode,dnn::RnnDirectionMode direction_mode,dnn::RnnMode rnn_mode,dnn::DataType data_type,const dnn::AlgorithmConfig & algorithm_config,float dropout,uint64 seed,ScratchAllocator * state_allocator)2275 MIOpenSupport::createRnnDescriptor(
2276     int num_layers, int hidden_size, int input_size, int batch_size,
2277     dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
2278     dnn::RnnMode rnn_mode, dnn::DataType data_type,
2279     const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
2280     ScratchAllocator* state_allocator) {
2281   // ROCM TODO: batch_size is ignored for now
2282 
2283   auto miopen = miopen_->GetHandle(parent_, nullptr);
2284   std::unique_ptr<MIOpenRnnDescriptor> rnn_desc(new MIOpenRnnDescriptor(
2285       miopen.handle(), num_layers, hidden_size, input_size,
2286       ToMIOpenRnnInputMode(input_mode),
2287       ToMIOpenRnnDirectionMode(direction_mode), ToMIOpenRnnMode(rnn_mode),
2288       ToMIOpenDataType(data_type), dropout, seed, state_allocator));
2289   if (!rnn_desc->ok()) {
2290     return rnn_desc->Status();
2291   }
2292   return port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>(
2293       std::move(rnn_desc));
2294 }
2295 
2296 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
createRnnSequenceTensorDescriptor(int seq_length,int batch_size,int data_size,dnn::DataType data_type)2297 MIOpenSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
2298                                                  int data_size,
2299                                                  dnn::DataType data_type) {
2300   std::unique_ptr<MIOpenRnnSequenceTensorDescriptor> seq_desc(
2301       new MIOpenRnnSequenceTensorDescriptor(seq_length, batch_size, data_size,
2302                                             ToMIOpenDataType(data_type)));
2303   if (!seq_desc->ok()) {
2304     return seq_desc->Status();
2305   }
2306   return port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>(
2307       std::move(seq_desc));
2308 }
2309 
2310 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
createRnnStateTensorDescriptor(int num_layer,int batch_size,int data_size,dnn::DataType data_type)2311 MIOpenSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
2312                                               int data_size,
2313                                               dnn::DataType data_type) {
2314   std::unique_ptr<MIOpenRnnStateTensorDescriptor> state_desc(
2315       new MIOpenRnnStateTensorDescriptor(num_layer, batch_size, data_size,
2316                                          ToMIOpenDataType(data_type)));
2317   if (!state_desc->ok()) {
2318     return state_desc->Status();
2319   }
2320   return port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>(
2321       std::move(state_desc));
2322 }
2323 
DoRnnForward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2324 bool MIOpenSupport::DoRnnForward(
2325     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2326     const dnn::RnnSequenceTensorDescriptor& input_desc,
2327     const DeviceMemory<Eigen::half>& input_data,
2328     const dnn::RnnStateTensorDescriptor& input_h_desc,
2329     const DeviceMemory<Eigen::half>& input_h_data,
2330     const dnn::RnnStateTensorDescriptor& input_c_desc,
2331     const DeviceMemory<Eigen::half>& input_c_data,
2332     const DeviceMemory<Eigen::half>& params,
2333     const dnn::RnnSequenceTensorDescriptor& output_desc,
2334     DeviceMemory<Eigen::half>* output_data,
2335     const dnn::RnnStateTensorDescriptor& output_h_desc,
2336     DeviceMemory<Eigen::half>* output_h_data,
2337     const dnn::RnnStateTensorDescriptor& output_c_desc,
2338     DeviceMemory<Eigen::half>* output_c_data, bool is_training,
2339     ScratchAllocator* reserve_space_allocator,
2340     ScratchAllocator* workspace_allocator,
2341     dnn::ProfileResult* output_profile_result) {
2342   // ROCM TODO: output_profile_result is ignore for now
2343 
2344   const MIOpenRnnDescriptor& miopen_rnn_desc =
2345       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2346   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2347       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2348   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2349       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2350   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2351       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2352   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2353       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2354   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2355       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2356   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2357       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2358 
2359   return DoRnnForwardImpl<Eigen::half>(
2360       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2361       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2362       params, miopen_output_desc, output_data, miopen_output_h_desc,
2363       output_h_data, miopen_output_c_desc, output_c_data, is_training,
2364       reserve_space_allocator, workspace_allocator);
2365 }
2366 
DoRnnForward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2367 bool MIOpenSupport::DoRnnForward(
2368     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2369     const dnn::RnnSequenceTensorDescriptor& input_desc,
2370     const DeviceMemory<float>& input_data,
2371     const dnn::RnnStateTensorDescriptor& input_h_desc,
2372     const DeviceMemory<float>& input_h_data,
2373     const dnn::RnnStateTensorDescriptor& input_c_desc,
2374     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2375     const dnn::RnnSequenceTensorDescriptor& output_desc,
2376     DeviceMemory<float>* output_data,
2377     const dnn::RnnStateTensorDescriptor& output_h_desc,
2378     DeviceMemory<float>* output_h_data,
2379     const dnn::RnnStateTensorDescriptor& output_c_desc,
2380     DeviceMemory<float>* output_c_data, bool is_training,
2381     ScratchAllocator* reserve_space_allocator,
2382     ScratchAllocator* workspace_allocator,
2383     dnn::ProfileResult* output_profile_result) {
2384   // ROCM TODO: output_profile_result is ignore for now
2385 
2386   const MIOpenRnnDescriptor& miopen_rnn_desc =
2387       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2388   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2389       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2390   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2391       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2392   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2393       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2394   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2395       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2396   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2397       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2398   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2399       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2400 
2401   return DoRnnForwardImpl<float>(
2402       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2403       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2404       params, miopen_output_desc, output_data, miopen_output_h_desc,
2405       output_h_data, miopen_output_c_desc, output_c_data, is_training,
2406       reserve_space_allocator, workspace_allocator);
2407 }
2408 
DoRnnForward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2409 bool MIOpenSupport::DoRnnForward(
2410     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2411     const dnn::RnnSequenceTensorDescriptor& input_desc,
2412     const DeviceMemory<double>& input_data,
2413     const dnn::RnnStateTensorDescriptor& input_h_desc,
2414     const DeviceMemory<double>& input_h_data,
2415     const dnn::RnnStateTensorDescriptor& input_c_desc,
2416     const DeviceMemory<double>& input_c_data,
2417     const DeviceMemory<double>& params,
2418     const dnn::RnnSequenceTensorDescriptor& output_desc,
2419     DeviceMemory<double>* output_data,
2420     const dnn::RnnStateTensorDescriptor& output_h_desc,
2421     DeviceMemory<double>* output_h_data,
2422     const dnn::RnnStateTensorDescriptor& output_c_desc,
2423     DeviceMemory<double>* output_c_data, bool is_training,
2424     ScratchAllocator* reserve_space_allocator,
2425     ScratchAllocator* workspace_allocator,
2426     dnn::ProfileResult* output_profile_result) {
2427   LOG(ERROR) << "miopen does not support double type RNN fwd yet";
2428   return false;
2429 }
2430 
DoRnnBackward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2431 bool MIOpenSupport::DoRnnBackward(
2432     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2433     const dnn::RnnSequenceTensorDescriptor& input_desc,
2434     const DeviceMemory<Eigen::half>& input_data,
2435     const dnn::RnnStateTensorDescriptor& input_h_desc,
2436     const DeviceMemory<Eigen::half>& input_h_data,
2437     const dnn::RnnStateTensorDescriptor& input_c_desc,
2438     const DeviceMemory<Eigen::half>& input_c_data,
2439     const DeviceMemory<Eigen::half>& params,
2440     const dnn::RnnSequenceTensorDescriptor& output_desc,
2441     const DeviceMemory<Eigen::half>& output_data,
2442     const dnn::RnnStateTensorDescriptor& output_h_desc,
2443     const DeviceMemory<Eigen::half>& output_h_data,
2444     const dnn::RnnStateTensorDescriptor& output_c_desc,
2445     const DeviceMemory<Eigen::half>& output_c_data,
2446     const DeviceMemory<Eigen::half>& output_backprop_data,
2447     const DeviceMemory<Eigen::half>& output_h_backprop_data,
2448     const DeviceMemory<Eigen::half>& output_c_backprop_data,
2449     DeviceMemory<Eigen::half>* input_backprop_data,
2450     DeviceMemory<Eigen::half>* input_h_backprop_data,
2451     DeviceMemory<Eigen::half>* input_c_backprop_data,
2452     DeviceMemory<Eigen::half>* params_backprop_data,
2453     DeviceMemory<uint8>* reserve_space_data,
2454     ScratchAllocator* workspace_allocator,
2455     dnn::ProfileResult* output_profile_result) {
2456   // ROCM TODO: output_profile_result is ignore for now
2457 
2458   const MIOpenRnnDescriptor& miopen_rnn_desc =
2459       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2460   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2461       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2462   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2463       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2464   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2465       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2466   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2467       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2468   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2469       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2470   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2471       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2472 
2473   return DoRnnBackwardImpl<Eigen::half>(
2474       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2475       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2476       params, miopen_output_desc, output_data, miopen_output_h_desc,
2477       output_h_data, miopen_output_c_desc, output_c_data, output_backprop_data,
2478       output_h_backprop_data, output_c_backprop_data, input_backprop_data,
2479       input_h_backprop_data, input_c_backprop_data, params_backprop_data,
2480       reserve_space_data, workspace_allocator);
2481 }
2482 
DoRnnBackward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2483 bool MIOpenSupport::DoRnnBackward(
2484     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2485     const dnn::RnnSequenceTensorDescriptor& input_desc,
2486     const DeviceMemory<float>& input_data,
2487     const dnn::RnnStateTensorDescriptor& input_h_desc,
2488     const DeviceMemory<float>& input_h_data,
2489     const dnn::RnnStateTensorDescriptor& input_c_desc,
2490     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2491     const dnn::RnnSequenceTensorDescriptor& output_desc,
2492     const DeviceMemory<float>& output_data,
2493     const dnn::RnnStateTensorDescriptor& output_h_desc,
2494     const DeviceMemory<float>& output_h_data,
2495     const dnn::RnnStateTensorDescriptor& output_c_desc,
2496     const DeviceMemory<float>& output_c_data,
2497     const DeviceMemory<float>& output_backprop_data,
2498     const DeviceMemory<float>& output_h_backprop_data,
2499     const DeviceMemory<float>& output_c_backprop_data,
2500     DeviceMemory<float>* input_backprop_data,
2501     DeviceMemory<float>* input_h_backprop_data,
2502     DeviceMemory<float>* input_c_backprop_data,
2503     DeviceMemory<float>* params_backprop_data,
2504     DeviceMemory<uint8>* reserve_space_data,
2505     ScratchAllocator* workspace_allocator,
2506     dnn::ProfileResult* output_profile_result) {
2507   // ROCM TODO: output_profile_result is ignore for now
2508 
2509   const MIOpenRnnDescriptor& miopen_rnn_desc =
2510       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2511   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2512       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2513   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2514       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2515   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2516       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2517   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2518       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2519   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2520       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2521   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2522       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2523 
2524   return DoRnnBackwardImpl<float>(
2525       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2526       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2527       params, miopen_output_desc, output_data, miopen_output_h_desc,
2528       output_h_data, miopen_output_c_desc, output_c_data, output_backprop_data,
2529       output_h_backprop_data, output_c_backprop_data, input_backprop_data,
2530       input_h_backprop_data, input_c_backprop_data, params_backprop_data,
2531       reserve_space_data, workspace_allocator);
2532 }
2533 
DoRnnBackward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2534 bool MIOpenSupport::DoRnnBackward(
2535     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2536     const dnn::RnnSequenceTensorDescriptor& input_desc,
2537     const DeviceMemory<double>& input_data,
2538     const dnn::RnnStateTensorDescriptor& input_h_desc,
2539     const DeviceMemory<double>& input_h_data,
2540     const dnn::RnnStateTensorDescriptor& input_c_desc,
2541     const DeviceMemory<double>& input_c_data,
2542     const DeviceMemory<double>& params,
2543     const dnn::RnnSequenceTensorDescriptor& output_desc,
2544     const DeviceMemory<double>& output_data,
2545     const dnn::RnnStateTensorDescriptor& output_h_desc,
2546     const DeviceMemory<double>& output_h_data,
2547     const dnn::RnnStateTensorDescriptor& output_c_desc,
2548     const DeviceMemory<double>& output_c_data,
2549     const DeviceMemory<double>& output_backprop_data,
2550     const DeviceMemory<double>& output_h_backprop_data,
2551     const DeviceMemory<double>& output_c_backprop_data,
2552     DeviceMemory<double>* input_backprop_data,
2553     DeviceMemory<double>* input_h_backprop_data,
2554     DeviceMemory<double>* input_c_backprop_data,
2555     DeviceMemory<double>* params_backprop_data,
2556     DeviceMemory<uint8>* reserve_space_data,
2557     ScratchAllocator* workspace_allocator,
2558     dnn::ProfileResult* output_profile_result) {
2559   LOG(ERROR) << "miopen does not support half type RNN bwd yet";
2560   return false;
2561 }
2562 
2563 // This is the context required to use the TF scratch allocator:
2564 struct MIOpenAllocatorContext {
MIOpenAllocatorContextstream_executor::gpu::MIOpenAllocatorContext2565   MIOpenAllocatorContext(ScratchAllocator* scratch_allocator, Stream* stream)
2566       : scratch_allocator_(scratch_allocator), stream_(stream) {}
2567 
2568   ScratchAllocator* scratch_allocator_;
2569   Stream* stream_;
2570 };
2571 
MIOpenAllocatorCallback(void * ctx,size_t size_in_bytes)2572 void* MIOpenAllocatorCallback(void* ctx, size_t size_in_bytes) {
2573   auto* mac = static_cast<MIOpenAllocatorContext*>(ctx);
2574   auto allocated =
2575       mac->scratch_allocator_->AllocateBytes(mac->stream_, size_in_bytes);
2576 
2577   DeviceMemory<uint8> scratch;
2578   if (allocated.ok()) {
2579     scratch = allocated.ValueOrDie();
2580     return scratch.opaque();
2581   } else {
2582     return nullptr;
2583   }
2584 }
2585 
// Deallocation hook registered through miopenSetAllocator. It deliberately
// does nothing: buffers handed out by MIOpenAllocatorCallback come from the
// TensorFlow scratch allocator, and TensorFlow reclaims that memory itself,
// so no deallocator is needed here.
void MIOpenDeallocatorCallback(void* ctx, void* mem) {
  (void)ctx;
  (void)mem;
}
2590 
DoPrepareForConvolution(dnn::ConvolutionKind kind,dnn::DataType element_type,Stream * stream,const dnn::BatchDescriptor & input_descriptor,DeviceMemoryBase input_data,const dnn::FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::AlgorithmConfig & algorithm_config,ScratchAllocator * scratch_allocator,dnn::AlgorithmDesc * algorithm_desc,DeviceMemory<uint8> * scratch_memory)2591 port::Status MIOpenSupport::DoPrepareForConvolution(
2592     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2593     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2594     const dnn::FilterDescriptor& filter_descriptor,
2595     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2596     DeviceMemoryBase output_data,
2597     const dnn::ConvolutionDescriptor& convolution_descriptor,
2598     const dnn::AlgorithmConfig& algorithm_config,
2599     ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
2600     DeviceMemory<uint8>* scratch_memory) {
2601   ScopedTensorDescriptor input_nd{
2602       input_descriptor,
2603       ToMIOpenDataType(element_type, input_descriptor.layout())};
2604   ScopedFilterDescriptor filter{
2605       filter_descriptor, input_descriptor,
2606       ToMIOpenDataType(element_type, filter_descriptor.layout())};
2607   ScopedTensorDescriptor output_nd{
2608       output_descriptor,
2609       ToMIOpenDataType(element_type, output_descriptor.layout())};
2610   ScopedConvolutionDescriptor conv{
2611       convolution_descriptor,
2612       ToMIOpenDataType(GetConvAccumulatorType(element_type))};
2613 
2614   auto miopen = miopen_->GetHandle(parent_, stream);
2615 
2616   absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
2617   size_t scratch_memory_size;
2618 
2619   if (!algo_desc.has_value()) {
2620     // With the default algorithm, use MIOpen's heuristics.
2621     assert(scratch_allocator);
2622 
2623     DeviceMemory<uint8> scratch_memory_temp;
2624     MIOpenAllocatorContext mac(scratch_allocator, stream);
2625     wrap::miopenSetAllocator(miopen.handle(), MIOpenAllocatorCallback,
2626                              MIOpenDeallocatorCallback, &mac);
2627     size_t size_in_bytes;
2628     miopenStatus_t status = miopenStatusSuccess;
2629 
2630     switch (kind) {
2631       case dnn::ConvolutionKind::FORWARD: {
2632         status = wrap::miopenConvolutionForwardGetWorkSpaceSize(
2633             miopen.handle(), /*filterDesc=*/filter.handle(),
2634             /*srcDesc=*/input_nd.handle(), /*convDesc=*/conv.handle(),
2635             /*destDesc=*/output_nd.handle(), /*sizeInBytes=*/&size_in_bytes);
2636         break;
2637       }
2638       case dnn::ConvolutionKind::BACKWARD_DATA: {
2639         status = wrap::miopenConvolutionBackwardDataGetWorkSpaceSize(
2640             miopen.handle(), /*diffDesc=*/output_nd.handle(),
2641             /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
2642             /*gradDesc=*/input_nd.handle(), /*sizeInBytes=*/&size_in_bytes);
2643         break;
2644       }
2645       case dnn::ConvolutionKind::BACKWARD_FILTER: {
2646         status = wrap::miopenConvolutionBackwardWeightsGetWorkSpaceSize(
2647             miopen.handle(), /*diffDesc=*/output_nd.handle(),
2648             /*srcDesc=*/input_nd.handle(), /*convDesc=*/conv.handle(),
2649             /*gradDesc=*/filter.handle(), /*sizeInBytes=*/&size_in_bytes);
2650         break;
2651       }
2652       default:
2653         return port::InternalError(absl::StrCat("Unexpected convolution kind ",
2654                                                 static_cast<int>(kind)));
2655     }
2656 
2657     if (status == miopenStatusSuccess && size_in_bytes != 0) {
2658       auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
2659       if (allocated.ok()) {
2660         scratch_memory_temp = allocated.ValueOrDie();
2661       }
2662     }
2663 
2664     miopenConvAlgoPerf_t preference;
2665     int returnedAlgoCount;
2666 
2667     switch (kind) {
2668       case dnn::ConvolutionKind::FORWARD: {
2669         auto status = wrap::miopenFindConvolutionForwardAlgorithm(
2670             miopen.handle(), input_nd.handle(), input_data.opaque(),
2671             filter.handle(), filter_data.opaque(), conv.handle(),
2672             output_nd.handle(), output_data.opaque(),
2673             /*requestAlgoCount=*/1, &returnedAlgoCount,
2674             /*preference=*/&preference,
2675             /*workspace*/ scratch_memory_temp.opaque(),
2676             /*WorkSpaceSize*/ scratch_memory_temp.size(),
2677             /*exhaustiveSearch*/ false);
2678         CHECK_EQ(status, miopenStatusSuccess) << "Unable to find a suitable "
2679                                                  "algorithm for doing forward "
2680                                                  "convolution";
2681         *algorithm_desc = dnn::AlgorithmDesc(preference.fwd_algo, false);
2682         break;
2683       }
2684       case dnn::ConvolutionKind::BACKWARD_DATA: {
2685         auto status = wrap::miopenFindConvolutionBackwardDataAlgorithm(
2686             miopen.handle(),
2687             /*diffDesc=*/output_nd.handle(), output_data.opaque(),
2688             /*filterDesc=*/filter.handle(), filter_data.opaque(),
2689             /*convDesc=*/conv.handle(),
2690             /*gradDesc=*/input_nd.handle(), input_data.opaque(),
2691             /*requestCount=*/1, /*returnedAlgoCount=*/&returnedAlgoCount,
2692             /*preference=*/&preference,
2693             /*WorkSpace=*/scratch_memory_temp.opaque(),
2694             /*WorkSpaceSize=*/scratch_memory_temp.size(),
2695             /*exhaustiveSearch=*/false);
2696         CHECK_EQ(status, miopenStatusSuccess) << "Unable to find a suitable "
2697                                                  "algorithm for doing backward "
2698                                                  "data convolution";
2699         *algorithm_desc = dnn::AlgorithmDesc(preference.bwd_data_algo, false);
2700         break;
2701       }
2702       case dnn::ConvolutionKind::BACKWARD_FILTER: {
2703         auto status = wrap::miopenFindConvolutionBackwardWeightsAlgorithm(
2704             miopen.handle(),
2705             /*diffDesc=*/output_nd.handle(), output_data.opaque(),
2706             /*srcDesc=*/input_nd.handle(), input_data.opaque(),
2707             /*convDesc=*/conv.handle(),
2708             /*gradDesc=*/filter.handle(), filter_data.opaque(),
2709             /*requestAlgoCount=*/1, /*returnedAlgoCount=*/&returnedAlgoCount,
2710             /*preference=*/&preference,
2711             /*WorkSpace=*/scratch_memory_temp.opaque(),
2712             /*WorkSpaceSize=*/scratch_memory_temp.size(),
2713             /*exhaustiveSearch=*/false);
2714         CHECK_EQ(status, miopenStatusSuccess) << "Unable to find a suitable "
2715                                                  "algorithm for doing backward "
2716                                                  "filter convolution";
2717         *algorithm_desc =
2718             dnn::AlgorithmDesc(preference.bwd_weights_algo, false);
2719         break;
2720       }
2721       default:
2722         return port::InternalError(absl::StrCat("Unexpected convolution kind ",
2723                                                 static_cast<int>(kind)));
2724     }
2725 
2726     // Restore default allocator, note mac is stack temp
2727     wrap::miopenSetAllocator(miopen.handle(), nullptr, nullptr, nullptr);
2728 
2729     scratch_memory_size = preference.memory;
2730   } else {
2731     // An algorithm has been specified.
2732     *algorithm_desc = *algo_desc;
2733     // commenting this line out for the upstream repo, since
2734     // AlgorithmConfig::scratch_size_ has been removed in the upstream repo but
2735     // is still used in the ROCM develop-upstream repo
2736     //
2737     // scratch_memory_size = *(algorithm_config.scratch_size());
2738     //
2739   }
2740 
2741   // allocate scratch memory
2742   if (scratch_memory_size != 0) {
2743     if (scratch_allocator == nullptr) {
2744       return port::InternalError(
2745           absl::StrCat("An allocator must be specified when scratch memory is "
2746                        "needed"));
2747     }
2748     auto allocated =
2749         scratch_allocator->AllocateBytes(stream, scratch_memory_size);
2750     if (!allocated.ok()) {
2751       return port::InternalError(absl::StrCat(
2752           "Failed to allocate scratch memory of size: ", scratch_memory_size));
2753     }
2754     if (allocated.ok()) {
2755       *scratch_memory = allocated.ValueOrDie();
2756     }
2757   }
2758 
2759   return port::Status::OK();
2760 }
2761 
2762 // NOTE(keveman): Temporary data layout transformation until MIOpen supports
2763 // kBatchYXDepth for backward pass. This function allocates temporary memory,
2764 // lays out the source data into the temporary but in the kBatchDepthXY
2765 // layout, and returns the temporary memory. The caller is responsible for
2766 // deallocating the temporary. Since the allocation is done using Stream's
2767 // AllocateTemporaryMemory, a later BlockHostUntilDone could be used for
2768 // deallocation.
2769 //
2770 // transform_scratch is populated with a legitimate temporary allocation iff
2771 // the original output data needs to be transformed.
MaybeTransformLayout(Stream * stream,miopenHandle_t handle_,int miopen_type,BatchDescriptor * output_descriptor,DeviceMemoryBase backward_output_data,std::unique_ptr<TemporaryDeviceMemory<uint8>> * transform_scratch)2772 static DeviceMemoryBase MaybeTransformLayout(
2773     Stream* stream, miopenHandle_t handle_,
2774     int miopen_type,  // Actually miopenDataType_t.
2775     BatchDescriptor* output_descriptor, DeviceMemoryBase backward_output_data,
2776     std::unique_ptr<TemporaryDeviceMemory<uint8>>* transform_scratch) {
2777   if (output_descriptor->layout() == dnn::DataLayout::kBatchDepthYX) {
2778     return backward_output_data;
2779   }
2780   CHECK(output_descriptor->layout() == dnn::DataLayout::kBatchYXDepth);
2781   *transform_scratch =
2782       stream->AllocateTemporaryArray<uint8>(backward_output_data.size())
2783           .ConsumeValueOrDie();
2784   BatchDescriptor transformed_output_descriptor;
2785   transformed_output_descriptor.CloneFrom(*output_descriptor);
2786   transformed_output_descriptor.set_layout(dnn::DataLayout::kBatchDepthYX);
2787   ScopedTensorDescriptor orig_out_back_nd{
2788       *output_descriptor, static_cast<miopenDataType_t>(miopen_type)};
2789   ScopedTensorDescriptor transformed_out_back_nd{
2790       transformed_output_descriptor,
2791       static_cast<miopenDataType_t>(miopen_type)};
2792 
2793   float alpha1 = 1.0f;
2794   float alpha2 = 0.0f;
2795   float beta = 0.0f;
2796   auto status = wrap::miopenOpTensor(
2797       handle_, miopenTensorOpAdd, &alpha1, orig_out_back_nd.handle(),
2798       backward_output_data.opaque(), &alpha2, orig_out_back_nd.handle(),
2799       backward_output_data.opaque(), &beta, transformed_out_back_nd.handle(),
2800       (*transform_scratch)->mutable_device_memory()->opaque());
2801 
2802   if (status != miopenStatusSuccess) {
2803     LOG(FATAL) << "Failed to transform the data layout.";
2804   }
2805   output_descriptor->set_layout(dnn::DataLayout::kBatchDepthYX);
2806   return (*transform_scratch)->device_memory();
2807 }
2808 
DoConvolve(dnn::ConvolutionKind kind,dnn::DataType element_type,Stream * stream,const dnn::BatchDescriptor & input_descriptor,DeviceMemoryBase input_data,const dnn::FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,dnn::AlgorithmDesc algorithm_desc,DeviceMemory<uint8> scratch_memory,dnn::ProfileResult * output_profile_result)2809 port::Status MIOpenSupport::DoConvolve(
2810     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2811     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2812     const dnn::FilterDescriptor& filter_descriptor,
2813     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2814     DeviceMemoryBase output_data,
2815     const dnn::ConvolutionDescriptor& convolution_descriptor,
2816     dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
2817     dnn::ProfileResult* output_profile_result) {
2818   auto miopen = miopen_->GetHandle(parent_, stream);
2819   ScopedTensorDescriptor input_nd{input_descriptor,
2820                                   ToMIOpenDataType(element_type)};
2821   ScopedTensorDescriptor output_nd{output_descriptor,
2822                                    ToMIOpenDataType(element_type)};
2823   ScopedFilterDescriptor filter{filter_descriptor, input_descriptor,
2824                                 ToMIOpenDataType(element_type)};
2825   ScopedConvolutionDescriptor conv{convolution_descriptor,
2826                                    ToMIOpenDataType(element_type)};
2827 
2828   // Alpha is the scaling factor for input.
2829   float alpha = 1.0;
2830   // Beta is the scaling factor for output.
2831   float beta = 0.0;
2832 
2833   const bool is_profiling = output_profile_result != nullptr;
2834 
2835   std::unique_ptr<GpuTimer> timer;
2836   if (is_profiling) {
2837     timer.reset(new GpuTimer(parent_));
2838     if (!timer->Init()) {
2839       return port::Status(port::error::INTERNAL, "Failed to init timer");
2840     }
2841     // The start and stop of the timer should be as close to the MIOpen call as
2842     // possible. It is still possible for other threads to issue workload on
2843     // to this stream. So it could take multiple profiling measurements.
2844     if (!timer->Start(AsGpuStream(stream))) {
2845       timer->Destroy();
2846       return port::Status(port::error::INTERNAL, "Failed to start timer");
2847     }
2848   }
2849 
2850   miopenStatus_t status = miopenStatusSuccess;
2851   switch (kind) {
2852     case dnn::ConvolutionKind::FORWARD: {
2853       status = wrap::miopenConvolutionForward(
2854           miopen.handle(),
2855           /*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(),
2856           /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
2857           /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
2858           /*algo=*/
2859           static_cast<miopenConvFwdAlgorithm_t>(algorithm_desc.algo_id()),
2860           /*beta=*/&beta, /*destDesc=*/output_nd.handle(),
2861           /*destData=*/output_data.opaque(),
2862           /*workSpace=*/scratch_memory.opaque(),
2863           /*workSpaceSizeInBytes=*/scratch_memory.size());
2864       break;
2865     }
2866     case dnn::ConvolutionKind::BACKWARD_DATA: {
2867       // TBD: remove once MIOpen supports kBatchYXDepth for backward pass.
2868       BatchDescriptor output_back_descriptor;
2869       output_back_descriptor.CloneFrom(output_descriptor);
2870       std::unique_ptr<TemporaryDeviceMemory<uint8>> transform_scratch;
2871       output_data = MaybeTransformLayout(
2872           stream, miopen.handle(), ToMIOpenDataType(element_type),
2873           &output_back_descriptor, output_data, &transform_scratch);
2874 
2875       status = wrap::miopenConvolutionBackwardData(
2876           miopen.handle(),
2877           /*alpha=*/&alpha,
2878           /*diffDesc=*/output_nd.handle(),
2879           /*diffData=*/output_data.opaque(),
2880           /*filterDesc=*/filter.handle(),
2881           /*filterData=*/filter_data.opaque(),
2882           /*convDesc=*/conv.handle(),
2883           /*algo=*/
2884           static_cast<miopenConvBwdDataAlgorithm_t>(algorithm_desc.algo_id()),
2885           /*beta=*/&beta,
2886           /*gradDesc=*/input_nd.handle(),
2887           /*gradData=*/input_data.opaque(),
2888           /*workSpace=*/scratch_memory.opaque(),
2889           /*workSpaceSizeInBytes=*/scratch_memory.size());
2890       break;
2891     }
2892     case dnn::ConvolutionKind::BACKWARD_FILTER: {
2893       // TBD: remove once MIOpen supports kBatchYXDepth for backward pass.
2894       BatchDescriptor output_back_descriptor;
2895       output_back_descriptor.CloneFrom(output_descriptor);
2896       std::unique_ptr<TemporaryDeviceMemory<uint8>> transform_scratch;
2897       output_data = MaybeTransformLayout(
2898           stream, miopen.handle(), ToMIOpenDataType(element_type),
2899           &output_back_descriptor, output_data, &transform_scratch);
2900 
2901       status = wrap::miopenConvolutionBackwardWeights(
2902           miopen.handle(),
2903           /*alpha=*/&alpha,
2904           /*diffDesc=*/output_nd.handle(),
2905           /*diffData=*/output_data.opaque(),
2906           /*srcDesc=*/input_nd.handle(),
2907           /*srcData=*/input_data.opaque(),
2908           /*convDesc=*/conv.handle(),
2909           /*algo=*/
2910           static_cast<miopenConvBwdWeightsAlgorithm_t>(
2911               algorithm_desc.algo_id()),
2912           /*beta=*/&beta,
2913           /*gradDesc=*/filter.handle(),
2914           /*gradData=*/filter_data.opaque(),
2915           /*workSpace=*/scratch_memory.opaque(),
2916           /*workSpaceSizeInBytes=*/scratch_memory.size());
2917       break;
2918     }
2919     default:
2920       return port::InternalError(
2921           absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
2922   }
2923 
2924   if (is_profiling) {
2925     if (!timer->Stop(AsGpuStream(stream))) {
2926       timer->Destroy();
2927       return port::Status(port::error::INTERNAL, "Failed to stop timer");
2928     }
2929     if (status == miopenStatusSuccess) {
2930       dnn::AlgorithmDesc algotype(algorithm_desc.algo_id(), false);
2931       output_profile_result->set_algorithm(algotype);
2932       output_profile_result->set_elapsed_time_in_ms(
2933           timer->GetElapsedMilliseconds());
2934       output_profile_result->set_scratch_size(scratch_memory.size());
2935     }
2936     timer->Destroy();
2937   }
2938 
2939   if (status != miopenStatusSuccess) {
2940     return port::InternalError(absl::StrCat(
2941         "Failed to euqueue convolution on stream: ", ToString(status)));
2942   }
2943 
2944   return port::Status::OK();
2945 }
2946 
GetConvolveAlgorithms(bool with_winograd_nonfused,int cc_major,int cc_minor,std::vector<dnn::AlgorithmDesc> * out_algorithms)2947 bool MIOpenSupport::GetConvolveAlgorithms(
2948     // ROCM TODO: refactor cc_major / cc_minor
2949     bool with_winograd_nonfused, int cc_major, int cc_minor,
2950     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
2951   out_algorithms->assign({
2952       // clang-format off
2953       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoGEMM, false),
2954       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoDirect, false),
2955       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoFFT, false),
2956       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoWinograd, false),
2957       // clang-format on
2958   });
2959   return true;
2960 }
2961 
GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> * out_algorithms)2962 bool MIOpenSupport::GetRnnAlgorithms(
2963     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
2964   // ROCM TODO: implement this with proper MIOpen API
2965   return true;
2966 }
2967 
GetConvolveBackwardDataAlgorithms(bool with_winograd_nonfused,int cc_major,int cc_minor,std::vector<dnn::AlgorithmDesc> * out_algorithms)2968 bool MIOpenSupport::GetConvolveBackwardDataAlgorithms(
2969     // ROCM TODO: refactor cc_major / cc_minor
2970     bool with_winograd_nonfused, int cc_major, int cc_minor,
2971     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
2972   out_algorithms->assign({
2973       // clang-format off
2974       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoGEMM, false),
2975       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoDirect, false),
2976       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoFFT, false),
2977       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoWinograd, false),
2978       // clang-format on
2979   });
2980   return true;
2981 }
2982 
GetConvolveBackwardFilterAlgorithms(bool with_winograd_nonfused,int cc_major,int cc_minor,std::vector<dnn::AlgorithmDesc> * out_algorithms)2983 bool MIOpenSupport::GetConvolveBackwardFilterAlgorithms(
2984     // ROCM TODO: refactor cc_major / cc_minor
2985     bool with_winograd_nonfused, int cc_major, int cc_minor,
2986     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
2987   out_algorithms->assign({
2988       // clang-format off
2989       dnn::AlgorithmDesc(miopenConvolutionBwdWeightsAlgoGEMM, false),
2990       dnn::AlgorithmDesc(miopenConvolutionBwdWeightsAlgoDirect, false),
2991       // clang-format on
2992   });
2993   return true;
2994 }
2995 
DoBatchNormalizationForward(Stream * stream,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,std::function<const DeviceMemory<float> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var)2996 bool MIOpenSupport::DoBatchNormalizationForward(
2997     Stream* stream, const DeviceMemory<Eigen::half>& x,
2998     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
2999     const DeviceMemory<float>& estimated_mean,
3000     const DeviceMemory<float>& estimated_variance,
3001     const dnn::BatchDescriptor& x_desc,
3002     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3003     DeviceMemory<Eigen::half>* y, DeviceMemory<float>* batch_mean,
3004     DeviceMemory<float>* batch_var, DeviceMemory<float>* saved_mean,
3005     DeviceMemory<float>* saved_inv_var, bool is_training,
3006     std::function<const DeviceMemory<float>&()> var_to_inv_var,
3007     std::function<void()> inv_var_to_var) {
3008   return DoBatchNormalizationForwardImpl<Eigen::half, float>(
3009       stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
3010       estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y,
3011       batch_mean, batch_var, saved_mean, saved_inv_var, is_training,
3012       std::move(var_to_inv_var), std::move(inv_var_to_var));
3013 }
3014 
DoBatchNormalizationForward(Stream * stream,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,std::function<const DeviceMemory<float> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var)3015 bool MIOpenSupport::DoBatchNormalizationForward(
3016     Stream* stream, const DeviceMemory<float>& x,
3017     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3018     const DeviceMemory<float>& estimated_mean,
3019     const DeviceMemory<float>& estimated_variance,
3020     const dnn::BatchDescriptor& x_desc,
3021     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3022     DeviceMemory<float>* y, DeviceMemory<float>* batch_mean,
3023     DeviceMemory<float>* batch_var, DeviceMemory<float>* saved_mean,
3024     DeviceMemory<float>* saved_inv_var, bool is_training,
3025     std::function<const DeviceMemory<float>&()> var_to_inv_var,
3026     std::function<void()> inv_var_to_var) {
3027   return DoBatchNormalizationForwardImpl<float, float>(
3028       stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, offset,
3029       estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y,
3030       batch_mean, batch_var, saved_mean, saved_inv_var, is_training,
3031       std::move(var_to_inv_var), std::move(inv_var_to_var));
3032 }
3033 
3034 template <class T, class U>
DoBatchNormalizationForwardImpl(Stream * stream,dnn::DataType input_data_type,dnn::DataType scale_data_type,const DeviceMemory<T> & x,const DeviceMemory<U> & scale,const DeviceMemory<U> & offset,const DeviceMemory<U> & estimated_mean,const DeviceMemory<U> & estimated_variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<T> * y,DeviceMemory<U> * batch_mean,DeviceMemory<U> * batch_var,DeviceMemory<U> * saved_mean,DeviceMemory<U> * saved_inv_var,bool is_training,std::function<const DeviceMemory<U> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var)3035 bool MIOpenSupport::DoBatchNormalizationForwardImpl(
3036     Stream* stream, dnn::DataType input_data_type,
3037     dnn::DataType scale_data_type, const DeviceMemory<T>& x,
3038     const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
3039     const DeviceMemory<U>& estimated_mean,
3040     const DeviceMemory<U>& estimated_variance,
3041     const dnn::BatchDescriptor& x_desc,
3042     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3043     DeviceMemory<T>* y, DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
3044     DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
3045     bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var,
3046     std::function<void()> inv_var_to_var) {
3047   auto miopen = miopen_->GetHandle(parent_, stream);
3048 
3049   ScopedTensorDescriptor x_descriptor{x_desc,
3050                                       ToMIOpenDataType(input_data_type)};
3051   ScopedTensorDescriptor scale_offset_descriptor{
3052       scale_offset_desc, ToMIOpenDataType(scale_data_type)};
3053   miopenBatchNormMode_t mode = miopenBNSpatial;
3054   float one = 1.0;
3055   float zero = 0.0;
3056 
3057   auto status = miopenStatusInvalidValue;
3058   if (is_training) {
3059     stream->ThenMemZero(batch_mean, batch_mean->size());
3060     stream->ThenMemZero(batch_var, batch_var->size());
3061     status = wrap::miopenBatchNormalizationForwardTraining(
3062         miopen.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3063         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3064         const_cast<void*>(scale.opaque()), const_cast<void*>(offset.opaque()),
3065         1.0, batch_mean->opaque(), batch_var->opaque(), epsilon,
3066         saved_mean->opaque(), saved_inv_var->opaque());
3067   } else {
3068     const void* maybe_inv_var = estimated_variance.opaque();
3069     status = wrap::miopenBatchNormalizationForwardInference(
3070         miopen.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3071         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3072         const_cast<void*>(scale.opaque()), const_cast<void*>(offset.opaque()),
3073         const_cast<void*>(estimated_mean.opaque()),
3074         const_cast<void*>(maybe_inv_var), epsilon);
3075   }
3076   if (status != miopenStatusSuccess) {
3077     LOG(ERROR) << "failed to enqueue forward batch normalization on stream: "
3078                << ToString(status);
3079     return false;
3080   }
3081   return true;
3082 }
3083 
DoBatchNormalizationBackward(Stream * stream,const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop)3084 bool MIOpenSupport::DoBatchNormalizationBackward(
3085     Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
3086     const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
3087     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3088     const dnn::BatchDescriptor& x_desc,
3089     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3090     DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
3091     DeviceMemory<float>* offset_backprop) {
3092   return DoBatchNormalizationBackwardImpl<Eigen::half, float>(
3093       stream, miopenHalf, miopenFloat, y_backprop, x, scale, mean, inv_var,
3094       x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
3095       offset_backprop);
3096 }
3097 
DoBatchNormalizationBackward(Stream * stream,const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop)3098 bool MIOpenSupport::DoBatchNormalizationBackward(
3099     Stream* stream, const DeviceMemory<float>& y_backprop,
3100     const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
3101     const DeviceMemory<float>& mean, const DeviceMemory<float>& variance,
3102     const dnn::BatchDescriptor& x_desc,
3103     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3104     DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
3105     DeviceMemory<float>* offset_backprop) {
3106   return DoBatchNormalizationBackwardImpl<float, float>(
3107       stream, miopenFloat, miopenFloat, y_backprop, x, scale, mean, variance,
3108       x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
3109       offset_backprop);
3110 }
3111 
3112 template <class T, class U>
DoBatchNormalizationBackwardImpl(Stream * stream,int miopen_input_type,int miopen_scale_type,const DeviceMemory<T> & y_backprop,const DeviceMemory<T> & x,const DeviceMemory<U> & scale,const DeviceMemory<U> & mean,const DeviceMemory<U> & variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<T> * x_backprop,DeviceMemory<U> * scale_backprop,DeviceMemory<U> * offset_backprop)3113 bool MIOpenSupport::DoBatchNormalizationBackwardImpl(
3114     Stream* stream, int miopen_input_type, int miopen_scale_type,
3115     const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
3116     const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
3117     const DeviceMemory<U>& variance, const dnn::BatchDescriptor& x_desc,
3118     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3119     DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
3120     DeviceMemory<U>* offset_backprop) {
3121   auto miopen = miopen_->GetHandle(parent_, stream);
3122   ScopedTensorDescriptor x_descriptor{
3123       x_desc, static_cast<miopenDataType_t>(miopen_input_type)};
3124   ScopedTensorDescriptor scale_offset_descriptor{
3125       scale_offset_desc, static_cast<miopenDataType_t>(miopen_scale_type)};
3126   miopenBatchNormMode_t mode = miopenBNSpatial;
3127   float one = 1.0;
3128   float zero = 0.0;
3129 
3130   auto status = wrap::miopenBatchNormalizationBackward(
3131       miopen.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
3132       x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
3133       x_descriptor.handle(), x_backprop->opaque(),
3134       scale_offset_descriptor.handle(), scale.opaque(),
3135       scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
3136       mean.opaque(), variance.opaque());
3137   if (status != miopenStatusSuccess) {
3138     LOG(ERROR) << "failed to enqueue backward batch normalization on stream: "
3139                << ToString(status);
3140     return false;
3141   }
3142   return true;
3143 }
3144 
DoFusedConvolve(Stream * stream,const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<double> & conv_input_data,double conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<double> & side_input_data,double side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<double> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> * output_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)3145 bool MIOpenSupport::DoFusedConvolve(
3146     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3147     const DeviceMemory<double>& conv_input_data, double conv_input_scale,
3148     const dnn::FilterDescriptor& filter_descriptor,
3149     const DeviceMemory<double>& filter_data,
3150     const dnn::ConvolutionDescriptor& convolution_descriptor,
3151     const DeviceMemory<double>& side_input_data, double side_input_scale,
3152     const dnn::BatchDescriptor& bias_descriptor,
3153     const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
3154     const dnn::BatchDescriptor& output_descriptor,
3155     DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
3156     const dnn::AlgorithmConfig& algorithm_config,
3157     dnn::ProfileResult* output_profile_result) {
3158   LOG(ERROR) << "fused convolve not implemented yet";
3159   return false;
3160 }
3161 
DoFusedConvolve(Stream * stream,const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)3162 bool MIOpenSupport::DoFusedConvolve(
3163     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3164     const DeviceMemory<float>& conv_input_data, float conv_input_scale,
3165     const dnn::FilterDescriptor& filter_descriptor,
3166     const DeviceMemory<float>& filter_data,
3167     const dnn::ConvolutionDescriptor& convolution_descriptor,
3168     const DeviceMemory<float>& side_input_data, float side_input_scale,
3169     const dnn::BatchDescriptor& bias_descriptor,
3170     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3171     const dnn::BatchDescriptor& output_descriptor,
3172     DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
3173     const dnn::AlgorithmConfig& algorithm_config,
3174     dnn::ProfileResult* output_profile_result) {
3175   LOG(ERROR) << "fused convolve not implemented yet";
3176   return false;
3177 }
3178 
DoFusedConvolve(Stream * stream,const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<Eigen::half> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<Eigen::half> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<Eigen::half> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)3179 bool MIOpenSupport::DoFusedConvolve(
3180     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3181     const DeviceMemory<Eigen::half>& conv_input_data, float conv_input_scale,
3182     const dnn::FilterDescriptor& filter_descriptor,
3183     const DeviceMemory<Eigen::half>& filter_data,
3184     const dnn::ConvolutionDescriptor& convolution_descriptor,
3185     const DeviceMemory<Eigen::half>& side_input_data, float side_input_scale,
3186     const dnn::BatchDescriptor& bias_descriptor,
3187     const DeviceMemory<Eigen::half>& biases,
3188     dnn::ActivationMode activation_mode,
3189     const dnn::BatchDescriptor& output_descriptor,
3190     DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
3191     const dnn::AlgorithmConfig& algorithm_config,
3192     dnn::ProfileResult* output_profile_result) {
3193   LOG(ERROR) << "fused convolve not implemented yet";
3194   return false;
3195 }
3196 
DoFusedConvolve(Stream * stream,const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<int8> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)3197 bool MIOpenSupport::DoFusedConvolve(
3198     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
3199     const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
3200     const dnn::FilterDescriptor& filter_descriptor,
3201     const DeviceMemory<int8>& filter_data,
3202     const dnn::ConvolutionDescriptor& convolution_descriptor,
3203     const DeviceMemory<int8>& side_input_data, float side_input_scale,
3204     const dnn::BatchDescriptor& bias_descriptor,
3205     const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
3206     const dnn::BatchDescriptor& output_descriptor,
3207     DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
3208     const dnn::AlgorithmConfig& algorithm_config,
3209     dnn::ProfileResult* output_profile_result) {
3210   LOG(ERROR) << "fused convolve not implemented yet";
3211   return false;
3212 }
3213 
DoTransformTensor(Stream * stream,const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)3214 bool MIOpenSupport::DoTransformTensor(Stream* stream,
3215                                       const dnn::BatchDescriptor& input_desc,
3216                                       dnn::DataType input_type,
3217                                       const DeviceMemoryBase& input_data,
3218                                       const dnn::BatchDescriptor& output_desc,
3219                                       dnn::DataType output_type, float scale,
3220                                       DeviceMemoryBase* output_data) {
3221   // ROCM TODO implement this operation
3222   LOG(ERROR) << "transform tensor not implemented yet";
3223   return false;
3224 }
3225 
3226 template <class T>
DoConvolveBackwardBiasImpl(Stream * stream,int miopen_type,const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<T> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<T> * backward_bias_data)3227 bool MIOpenSupport::DoConvolveBackwardBiasImpl(
3228     Stream* stream, int miopen_type,  // Actually miopenDataType_t.
3229     const dnn::BatchDescriptor& input_descriptor,
3230     const DeviceMemory<T>& input_data,
3231     const dnn::BatchDescriptor& bias_descriptor,
3232     DeviceMemory<T>* backward_bias_data) {
3233   auto miopen = miopen_->GetHandle(parent_, stream);
3234 
3235   ScopedTensorDescriptor input_nd{input_descriptor,
3236                                   static_cast<miopenDataType_t>(miopen_type)};
3237   ScopedTensorDescriptor bias_nd{bias_descriptor,
3238                                  static_cast<miopenDataType_t>(miopen_type)};
3239 
3240   // Alpha is the scaling factor for input.
3241   float alpha = 1.0;
3242   // Beta is the scaling factor for output.
3243   float beta = 0.0;
3244 
3245   auto status = wrap::miopenConvolutionBackwardBias(
3246       miopen.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta,
3247       bias_nd.handle(), backward_bias_data->opaque());
3248   if (status != miopenStatusSuccess) {
3249     LOG(FATAL) << "failed to enqueue backward convolution on stream: "
3250                << ToString(status);
3251     return false;
3252   }
3253   return true;
3254 }
3255 
DoConvolveBackwardBias(Stream * stream,const BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const BatchDescriptor & bias_descriptor,DeviceMemory<double> * backward_bias_data)3256 bool MIOpenSupport::DoConvolveBackwardBias(
3257     Stream* stream, const BatchDescriptor& input_descriptor,
3258     const DeviceMemory<double>& input_data,
3259     const BatchDescriptor& bias_descriptor,
3260     DeviceMemory<double>* backward_bias_data) {
3261   LOG(ERROR) << "miopen does not support double bwd bias yet";
3262   return false;
3263 }
3264 
DoConvolveBackwardBias(Stream * stream,const BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const BatchDescriptor & bias_descriptor,DeviceMemory<float> * backward_bias_data)3265 bool MIOpenSupport::DoConvolveBackwardBias(
3266     Stream* stream, const BatchDescriptor& input_descriptor,
3267     const DeviceMemory<float>& input_data,
3268     const BatchDescriptor& bias_descriptor,
3269     DeviceMemory<float>* backward_bias_data) {
3270   return DoConvolveBackwardBiasImpl(stream, miopenFloat, input_descriptor,
3271                                     input_data, bias_descriptor,
3272                                     backward_bias_data);
3273 }
3274 
DoConvolveBackwardBias(Stream * stream,const BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const BatchDescriptor & bias_descriptor,DeviceMemory<Eigen::half> * backward_bias_data)3275 bool MIOpenSupport::DoConvolveBackwardBias(
3276     Stream* stream, const BatchDescriptor& input_descriptor,
3277     const DeviceMemory<Eigen::half>& input_data,
3278     const BatchDescriptor& bias_descriptor,
3279     DeviceMemory<Eigen::half>* backward_bias_data) {
3280   return DoConvolveBackwardBiasImpl(stream, miopenHalf, input_descriptor,
3281                                     input_data, bias_descriptor,
3282                                     backward_bias_data);
3283 }
3284 
DoMatMul(Stream * stream,const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)3285 bool MIOpenSupport::DoMatMul(Stream* stream,
3286                              const DeviceMemory<float>& input_data,
3287                              const DeviceMemory<float>& weights,
3288                              const dnn::BatchDescriptor& input_dimensions,
3289                              const dnn::BatchDescriptor& output_dimensions,
3290                              DeviceMemory<float>* output_data) {
3291   if (input_dimensions.count() != output_dimensions.count()) {
3292     LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
3293     return false;
3294   }
3295 
3296   // We do not permute the input or output, instead we just
3297   // reinterpret the layout. We are working with row-major matrices
3298   // and the rows of the input and output correspond to batch, so
3299   // batch has to be outermost in both the input and output.
3300   //
3301   // By adding transposes to the BLAS gemm call we could perhaps make
3302   // the kYXDepthBatch layout work as well, but there has been no need
3303   // for that so far.
3304   if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3305       input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3306     LOG(ERROR) << "Unsupported MatMul input layout.";
3307     return false;
3308   }
3309   if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3310       output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3311     LOG(ERROR) << "Unsupported MatMul output layout.";
3312     return false;
3313   }
3314 
3315   if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
3316     // This is a fast path that also supports the kBatchYXDepth layout.
3317 
3318     // The matrices here are in row-major format while BLAS expects
3319     // column-major, i.e. our matrices are transposed as far as BLAS
3320     // is concerned. So we need to compute output^T =
3321     // input^T*weights^T. There is no parameter for transposing the
3322     // output in BLAS gemm, but instead we can transpose both sides of
3323     // the equality to see that this is equivalent to
3324     // output=weights*input. So we only need to swap the order of
3325     // weights and input in the matrix product to correct for the
3326     // row-major versus column-major difference.
3327     const float alpha = 1.0f;  // Take the matrix product without scaling it.
3328     const float beta = 0.0f;   // Ignore the original values in output_data.
3329     const int64 m = output_dimensions.NodesAcrossFeatureMaps();
3330     const int64 n = input_dimensions.count();
3331     const int64 k = input_dimensions.NodesAcrossFeatureMaps();
3332     stream->ThenBlasGemm(blas::Transpose::kNoTranspose,
3333                          blas::Transpose::kNoTranspose, m, n, k, alpha, weights,
3334                          m, input_data, k, beta, output_data, m);
3335   } else {
3336     // This is a slower and more complex path that supports output
3337     // width() * height() > 1, though it only supports the
3338     // kBatchYXDepth layout. Does support kBatchDepthYX if output
3339     // feature_map_count() == 1, as then there is no difference
3340     // between the two layouts.
3341     //
3342     // The operation here is the same as above, except that we have to
3343     // do the matrix multiplication for each (y,x) output coordinate
3344     // separately. We then interpret weights as containing K = width()
3345     // * height() different matrices, which we all multiply onto the
3346     // matrix from input_data, yielding K matrix products. We then
3347     // combine these together into one matrix by concatenating all the
3348     // first rows of these matrices, then all the seconds rows and so
3349     // on. We can do this with a batched matrix multiplication, where
3350     // the result is written to a different submatrix of the output
3351     // for each matrix multiplication.
3352     //
3353     // The reason that we only support the kBatchYXDepth output layout
3354     // is that we have to do something in the depth for each (y,x)
3355     // coordinate. The kBatchYXDepth layout has the depth information
3356     // for each point (y,x) in contiguous memory while the
3357     // kBatchDepthYX layout does not.
3358     //
3359     // TODO(broune): Consider a special case for when output depth ==
3360     // 1, as then possibly this could all be done as one matrix
3361     // multiplication instead of a batched one, which should be
3362     // faster. Another possibility would be to add a weights layout
3363     // parameter and then support kBatchDepthYX for a different
3364     // weights layout.
3365     if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3366         !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
3367           output_dimensions.feature_map_count() == 1)) {
3368       LOG(ERROR) << "Unsupported MatMul output layout.";
3369       return false;
3370     }
3371 
3372     const float alpha = 1.0f;  // Take the matrix product without scaling it.
3373     const float beta = 0.0f;   // Ignore the original values in output_data.
3374     const uint64 m = output_dimensions.feature_map_count();
3375     const uint64 n = input_dimensions.count();
3376     const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
3377     const int lda = m;
3378     const int ldb = k;
3379     const int ldc = output_dimensions.NodesAcrossFeatureMaps();
3380     const int batch_count = output_dimensions.NodesPerFeatureMap();
3381 
3382     std::vector<DeviceMemory<float>> a(batch_count);
3383     std::vector<DeviceMemory<float>> b(batch_count);
3384     std::vector<DeviceMemory<float>> c(batch_count);
3385     for (int i = 0; i < batch_count; ++i) {
3386       const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
3387                                  output_dimensions.feature_map_count();
3388       a[i] = DeviceMemory<float>::MakeFromByteSize(
3389           const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
3390               weights_offset,
3391           weights.ElementCount() - weights_offset);
3392 
3393       b[i] = input_data;
3394 
3395       const int output_offset = i * output_dimensions.feature_map_count();
3396       c[i] = DeviceMemory<float>::MakeFromByteSize(
3397           const_cast<float*>(
3398               reinterpret_cast<const float*>(output_data->opaque())) +
3399               output_offset,
3400           output_data->ElementCount() - output_offset);
3401     }
3402     const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
3403       std::vector<DeviceMemory<float>*> ptrs;
3404       ptrs.reserve(v.size());
3405       for (auto& mem : v) {
3406         ptrs.push_back(&mem);
3407       }
3408       return ptrs;
3409     };
3410 
3411     stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
3412                                 blas::Transpose::kNoTranspose, m, n, k, alpha,
3413                                 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
3414                                 ldc, batch_count);
3415   }
3416 
3417   return stream->ok();
3418 }
3419 
DoBiasAdd(Stream * stream,const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)3420 bool MIOpenSupport::DoBiasAdd(Stream* stream,
3421                               const DeviceMemory<float>& input_data,
3422                               const DeviceMemory<float>& biases,
3423                               const dnn::BatchDescriptor& dimensions,
3424                               DeviceMemory<float>* output_data) {
3425   ScopedTensorDescriptor input_descriptor{dimensions, miopenFloat};
3426 
3427   BatchDescriptor bias_dimensions;
3428   bias_dimensions.set_count(1)
3429       .set_feature_map_count(dimensions.feature_map_count())
3430       .set_height(1)
3431       .set_width(1)
3432       .set_layout(dnn::DataLayout::kBatchYXDepth);
3433   ScopedTensorDescriptor bias_descriptor{bias_dimensions, miopenFloat};
3434 
3435   if (input_data.opaque() != output_data->opaque()) {
3436     stream->ThenMemcpy(output_data, input_data,
3437                        dimensions.ElementCount() * sizeof(float));
3438     if (!stream->ok()) {
3439       LOG(ERROR)
3440           << "stream " << stream
3441           << " could not enqueue a tensor copy as part of bias addition.";
3442       return false;
3443     }
3444   }
3445 
3446   auto miopen = miopen_->GetHandle(parent_, stream);
3447 
3448   const float alpha1 = 1.0f;
3449   const float alpha2 = 0.0f;
3450   const float beta = 1.0f;
3451 
3452   auto status = wrap::miopenOpTensor(
3453       miopen.handle(), miopenTensorOpAdd, &alpha1, bias_descriptor.handle(),
3454       biases.opaque(), &alpha2, bias_descriptor.handle(), biases.opaque(),
3455       &beta, input_descriptor.handle(), output_data->opaque());
3456 
3457   if (status != miopenStatusSuccess) {
3458     LOG(ERROR) << "stream " << stream << " could not enqueue bias addition.";
3459     return false;
3460   }
3461 
3462   return true;
3463 }
3464 
DoActivate(Stream * stream,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64 options)3465 bool MIOpenSupport::DoActivate(Stream* stream,
3466                                dnn::ActivationMode activation_mode,
3467                                const dnn::BatchDescriptor& dimensions,
3468                                const DeviceMemory<float>& input_data,
3469                                DeviceMemory<float>* output_data,
3470                                uint64 options) {
3471   LOG(ERROR) << "miopen does not support activation yet";
3472   return false;
3473 }
3474 
DoPoolForward(Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<double> * output_data,ScratchAllocator * workspace_allocator)3475 bool MIOpenSupport::DoPoolForward(
3476     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3477     const dnn::BatchDescriptor& input_dimensions,
3478     const DeviceMemory<double>& input_data,
3479     const dnn::BatchDescriptor& output_dimensions,
3480     DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
3481   LOG(ERROR) << "miopen does not support pooling for dobule type yet";
3482   return false;
3483 }
3484 
DoPoolForward(Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data,ScratchAllocator * workspace_allocator)3485 bool MIOpenSupport::DoPoolForward(
3486     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3487     const dnn::BatchDescriptor& input_dimensions,
3488     const DeviceMemory<float>& input_data,
3489     const dnn::BatchDescriptor& output_dimensions,
3490     DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
3491   auto miopen = miopen_->GetHandle(parent_, stream);
3492 
3493   // Alpha is the scaling factor for input.
3494   float alpha = 1.0;
3495   // Beta is the scaling factor for output.
3496   float beta = 0.0;
3497 
3498   ScopedTensorDescriptor src_desc{input_dimensions, miopenFloat};
3499   ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat};
3500   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
3501 
3502   auto status = wrap::miopenPoolingForward(
3503       miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
3504       input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque(),
3505       false, nullptr, 0);
3506   if (status != miopenStatusSuccess) {
3507     LOG(ERROR) << "failed to enqueue forward pooling on stream: "
3508                << ToString(status);
3509     return false;
3510   }
3511   return true;
3512 }
3513 
DoPoolForward(Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<Eigen::half> * output_data,ScratchAllocator * workspace_allocator)3514 bool MIOpenSupport::DoPoolForward(
3515     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3516     const dnn::BatchDescriptor& input_dimensions,
3517     const DeviceMemory<Eigen::half>& input_data,
3518     const dnn::BatchDescriptor& output_dimensions,
3519     DeviceMemory<Eigen::half>* output_data,
3520     ScratchAllocator* workspace_allocator) {
3521   auto miopen = miopen_->GetHandle(parent_, stream);
3522 
3523   // Alpha is the scaling factor for input.
3524   float alpha = 1.0;
3525   // Beta is the scaling factor for output.
3526   float beta = 0.0;
3527 
3528   ScopedTensorDescriptor src_desc{input_dimensions, miopenHalf};
3529   ScopedTensorDescriptor dest_desc{output_dimensions, miopenHalf};
3530   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
3531 
3532   auto status = wrap::miopenPoolingForward(
3533       miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
3534       input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque(),
3535       false, nullptr, 0);
3536   if (status != miopenStatusSuccess) {
3537     LOG(ERROR) << "failed to enqueue forward pooling on stream: "
3538                << ToString(status);
3539     return false;
3540   }
3541   return true;
3542 }
3543 
DoPoolBackward(Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<double> & output_data,const DeviceMemory<double> & input_diff_data,DeviceMemory<double> * output_diff_data,ScratchAllocator * workspace_allocator)3544 bool MIOpenSupport::DoPoolBackward(
3545     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3546     const dnn::BatchDescriptor& input_dimensions,
3547     const DeviceMemory<double>& input_data,
3548     const dnn::BatchDescriptor& output_dimensions,
3549     const DeviceMemory<double>& output_data,
3550     const DeviceMemory<double>& input_diff_data,
3551     DeviceMemory<double>* output_diff_data,
3552     ScratchAllocator* workspace_allocator) {
3553   LOG(ERROR) << "miopen does not support backward pooling on double type yet";
3554   return false;
3555 }
3556 
DoPoolBackward(Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<float> & output_data,const DeviceMemory<float> & input_diff_data,DeviceMemory<float> * output_diff_data,ScratchAllocator * workspace_allocator)3557 bool MIOpenSupport::DoPoolBackward(
3558     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3559     const dnn::BatchDescriptor& input_dimensions,
3560     const DeviceMemory<float>& input_data,
3561     const dnn::BatchDescriptor& output_dimensions,
3562     const DeviceMemory<float>& output_data,
3563     const DeviceMemory<float>& input_diff_data,
3564     DeviceMemory<float>* output_diff_data,
3565     ScratchAllocator* workspace_allocator) {
3566   auto miopen = miopen_->GetHandle(parent_, stream);
3567 
3568   // Alpha is the scaling factor for input.
3569   float alpha = 1.0;
3570   // Beta is the scaling factor for output.
3571   float beta = 0.0;
3572 
3573   ScopedTensorDescriptor src_desc{input_dimensions, miopenFloat};
3574   ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat};
3575   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
3576 
3577   DeviceMemory<uint8> workspace;
3578   size_t workspace_size_in_bytes = 0;
3579   auto status = wrap::miopenPoolingGetWorkSpaceSize(dest_desc.handle(),
3580                                                     &workspace_size_in_bytes);
3581 
3582   if (status != miopenStatusSuccess) {
3583     LOG(ERROR)
3584         << "failed to obtain workspace size for backward pooling on stream: "
3585         << ToString(status);
3586     return false;
3587   }
3588 
3589   // Allocate the workspace.
3590   if (workspace_size_in_bytes > 0) {
3591     assert(workspace_allocator);
3592     auto allocated =
3593         workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
3594     if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
3595       LOG(ERROR) << "Failed to allocate backward pooling workspace";
3596       return false;
3597     }
3598   }
3599 
3600   DeviceMemory<uint8> dest2;  // duplicated dest from forward:
3601   int dest2_size = 0;
3602 
3603   // miopen requires the strides and dims to be ordered as BDYX.
3604   std::vector<int64> dims64 =
3605       output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
3606 
3607   // miopen does not use strides and must have 4D tensor.
3608   std::vector<int> dims(4);
3609 
3610   std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
3611                  &CheckedNarrowing<int64, int>);
3612 
3613   dest2_size = dims[0] * dims[1] * dims[2] * dims[3] * sizeof(float);
3614 
3615   if (dest2_size > 0) {
3616     assert(workspace_allocator);
3617     auto allocated = workspace_allocator->AllocateBytes(stream, dest2_size);
3618     if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
3619       LOG(ERROR) << "Failed to allocate backward pooling workspace";
3620       return false;
3621     }
3622   } else {
3623     LOG(ERROR) << "Failed to calcuate tensor size to chain forward and "
3624                   "backward pooling";
3625   }
3626 
3627   status = wrap::miopenPoolingForward(
3628       miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
3629       input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
3630       workspace.opaque(), workspace_size_in_bytes);
3631 
3632   if (status != miopenStatusSuccess) {
3633     LOG(ERROR)
3634         << "failed to enqueue forward pooling (before backward) on stream: "
3635         << ToString(status);
3636     return false;
3637   }
3638 
3639   status = wrap::miopenPoolingBackward(
3640       miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
3641       dest2.opaque(), dest_desc.handle(), input_diff_data.opaque(),
3642       src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
3643       output_diff_data->opaque(), workspace.opaque());
3644 
3645   if (status != miopenStatusSuccess) {
3646     LOG(ERROR) << "failed to enqueue backward pooling on stream: "
3647                << ToString(status);
3648     return false;
3649   }
3650   return true;
3651 }
3652 
DoPoolBackward(Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<Eigen::half> & output_data,const DeviceMemory<Eigen::half> & input_diff_data,DeviceMemory<Eigen::half> * output_diff_data,ScratchAllocator * workspace_allocator)3653 bool MIOpenSupport::DoPoolBackward(
3654     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3655     const dnn::BatchDescriptor& input_dimensions,
3656     const DeviceMemory<Eigen::half>& input_data,
3657     const dnn::BatchDescriptor& output_dimensions,
3658     const DeviceMemory<Eigen::half>& output_data,
3659     const DeviceMemory<Eigen::half>& input_diff_data,
3660     DeviceMemory<Eigen::half>* output_diff_data,
3661     ScratchAllocator* workspace_allocator) {
3662   auto miopen = miopen_->GetHandle(parent_, stream);
3663 
3664   // Alpha is the scaling factor for input.
3665   float alpha = 1.0;
3666   // Beta is the scaling factor for output.
3667   float beta = 0.0;
3668 
3669   ScopedTensorDescriptor src_desc{input_dimensions, miopenHalf};
3670   ScopedTensorDescriptor dest_desc{output_dimensions, miopenHalf};
3671   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
3672 
3673   DeviceMemory<uint8> workspace;
3674   size_t workspace_size_in_bytes = 0;
3675   auto status = wrap::miopenPoolingGetWorkSpaceSize(dest_desc.handle(),
3676                                                     &workspace_size_in_bytes);
3677 
3678   if (status != miopenStatusSuccess) {
3679     LOG(ERROR)
3680         << "failed to obtain workspace size for backward pooling on stream: "
3681         << ToString(status);
3682     return false;
3683   }
3684 
3685   // Allocate the workspace.
3686   if (workspace_size_in_bytes > 0) {
3687     assert(workspace_allocator);
3688     auto allocated =
3689         workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
3690     if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
3691       LOG(ERROR) << "Failed to allocate backward pooling workspace";
3692       return false;
3693     }
3694   }
3695 
3696   DeviceMemory<uint8> dest2;  // duplicated dest from forward:
3697   int dest2_size = 0;
3698 
3699   // miopen requires the strides and dims to be ordered as BDYX.
3700   std::vector<int64> dims64 =
3701       output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
3702 
3703   // miopen does not use strides and must have 4D tensor.
3704   std::vector<int> dims(4);
3705 
3706   std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
3707                  &CheckedNarrowing<int64, int>);
3708 
3709   dest2_size = dims[0] * dims[1] * dims[2] * dims[3] * sizeof(float);
3710 
3711   if (dest2_size > 0) {
3712     assert(workspace_allocator);
3713     auto allocated = workspace_allocator->AllocateBytes(stream, dest2_size);
3714     if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
3715       LOG(ERROR) << "Failed to allocate backward pooling workspace";
3716       return false;
3717     }
3718   } else {
3719     LOG(ERROR) << "Failed to calcuate tensor size to chain forward and "
3720                   "backward pooling";
3721   }
3722 
3723   status = wrap::miopenPoolingForward(
3724       miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
3725       input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
3726       workspace.opaque(), workspace_size_in_bytes);
3727 
3728   if (status != miopenStatusSuccess) {
3729     LOG(ERROR)
3730         << "failed to enqueue forward pooling (before backward) on stream: "
3731         << ToString(status);
3732     return false;
3733   }
3734 
3735   status = wrap::miopenPoolingBackward(
3736       miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
3737       dest2.opaque(), dest_desc.handle(), input_diff_data.opaque(),
3738       src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
3739       output_diff_data->opaque(), workspace.opaque());
3740 
3741   if (status != miopenStatusSuccess) {
3742     LOG(ERROR) << "failed to enqueue backward pooling on stream: "
3743                << ToString(status);
3744     return false;
3745   }
3746   return true;
3747 }
3748 
DoNormalizeWithDimensions(Stream * stream,const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)3749 bool MIOpenSupport::DoNormalizeWithDimensions(
3750     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
3751     const dnn::BatchDescriptor& dimensions,
3752     const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
3753   // Check for unsupported modes.
3754   if (normalize_descriptor.wrap_around()) {
3755     LOG(ERROR) << "MIOpen LRN does not support wrap-around mode";
3756     return false;
3757   }
3758   if (normalize_descriptor.segment_size()) {
3759     LOG(ERROR) << "MIOpen LRN does not support segmentation";
3760     return false;
3761   }
3762 
3763   auto miopen = miopen_->GetHandle(parent_, stream);
3764 
3765   // Launch the normalization.
3766   ScopedTensorDescriptor dims{dimensions, miopenFloat};
3767   ScopedNormalizeDescriptor normalize{normalize_descriptor};
3768 
3769   // Alpha is the scaling factor for input.
3770   float alpha = 1.0f;
3771   // Beta is the scaling factor for output.
3772   float beta = 0.0f;
3773 
3774   auto status = wrap::miopenLRNForward(
3775       miopen.handle(), normalize.handle(), &alpha, dims.handle(),
3776       input_data.opaque(), &beta, dims.handle(), output_data->opaque(), false,
3777       nullptr);
3778   if (status != miopenStatusSuccess) {
3779     LOG(ERROR) << "failed to run miopenLRNForward";
3780     return false;
3781   }
3782   return true;
3783 }
3784 
DoNormalizeBackwardWithDimensions(Stream * stream,const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient,ScratchAllocator * workspace_allocator)3785 bool MIOpenSupport::DoNormalizeBackwardWithDimensions(
3786     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
3787     const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
3788     const DeviceMemory<float>& normalized_data,
3789     const DeviceMemory<float>& normalized_variable_gradient,
3790     DeviceMemory<float>* raw_variable_gradient,
3791     ScratchAllocator* workspace_allocator) {
3792   // Check for unsupported modes.
3793   if (normalize_descriptor.wrap_around()) {
3794     LOG(ERROR) << "MIOpen LRN does not support wrap-around mode";
3795     return false;
3796   }
3797   if (normalize_descriptor.segment_size()) {
3798     LOG(ERROR) << "MIOpen LRN does not support segmentation";
3799     return false;
3800   }
3801 
3802   auto miopen = miopen_->GetHandle(parent_, stream);
3803 
3804   ScopedTensorDescriptor dims{dimensions, miopenFloat};
3805   ScopedNormalizeDescriptor normalize{normalize_descriptor};
3806 
3807   float alpha = 1.0f;
3808   float beta = 0.0f;
3809 
3810   DeviceMemory<uint8> workspace;
3811   size_t workspace_size_in_bytes = 0;
3812   auto status =
3813       wrap::miopenLRNGetWorkSpaceSize(dims.handle(), &workspace_size_in_bytes);
3814 
3815   if (status != miopenStatusSuccess) {
3816     LOG(ERROR) << "failed to obtain workspace size for miopenLRNBackward";
3817     return false;
3818   }
3819 
3820   // Allocate the workspace.
3821   if (workspace_size_in_bytes > 0) {
3822     assert(workspace_allocator);
3823     auto allocated =
3824         workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
3825     if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
3826       LOG(ERROR) << "Failed to allocate backward pooling workspace";
3827       return false;
3828     }
3829   }
3830 
3831   DeviceMemory<uint8> dest2;  // duplicated dest from forward:
3832   int dest2_size = 0;
3833 
3834   // miopen requires the strides and dims to be ordered as BDYX.
3835   std::vector<int64> dims64 =
3836       dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
3837 
3838   // miopen does not use strides and must have 4D tensor.
3839   std::vector<int> dimsint(4);
3840 
3841   std::transform(dims64.cbegin(), dims64.cend(), dimsint.begin(),
3842                  &CheckedNarrowing<int64, int>);
3843 
3844   dest2_size =
3845       dimsint[0] * dimsint[1] * dimsint[2] * dimsint[3] * sizeof(float);
3846 
3847   if (dest2_size > 0) {
3848     assert(workspace_allocator);
3849     auto allocated = workspace_allocator->AllocateBytes(stream, dest2_size);
3850     if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
3851       LOG(ERROR)
3852           << "Failed to allocate tensor to chain forward and backward LRN";
3853       return false;
3854     }
3855   } else {
3856     LOG(ERROR)
3857         << "Failed to calcuate tensor size to chain forward and backward LRN";
3858   }
3859 
3860   status = wrap::miopenLRNForward(miopen.handle(), normalize.handle(), &alpha,
3861                                   dims.handle(), raw_data.opaque(), &beta,
3862                                   dims.handle(), dest2.opaque(), true,
3863                                   workspace.opaque());
3864 
3865   if (status != miopenStatusSuccess) {
3866     LOG(ERROR) << "failed to run miopenLRNForward";
3867     return false;
3868   }
3869 
3870   status = wrap::miopenLRNBackward(
3871       miopen.handle(), normalize.handle(), &alpha, dims.handle(),
3872       normalized_data.opaque(), dims.handle(),
3873       normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
3874       &beta, dims.handle(), raw_variable_gradient->opaque(),
3875       workspace.opaque());
3876 
3877   if (status != miopenStatusSuccess) {
3878     LOG(ERROR) << "failed to run miopenLRNBackward";
3879     return false;
3880   }
3881   return true;
3882 }
3883 
DoDepthConcatenate(Stream * stream,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data)3884 bool MIOpenSupport::DoDepthConcatenate(
3885     Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
3886     port::ArraySlice<const DeviceMemory<float>*> input_data,
3887     DeviceMemory<float>* output_data) {
3888   CHECK_EQ(input_dimensions.size(), input_data.size());
3889 
3890   for (const auto& dimensions : input_dimensions) {
3891     if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3892       LOG(ERROR) << "MIOpenSupport::DoDepthConcatenate currently only "
3893                     "supports the kBatchDepthYX layout.";
3894       return false;
3895     }
3896   }
3897 
3898   if (input_dimensions.empty()) {
3899     return true;  // Nothing to do.
3900   }
3901 
3902   dnn::BatchDescriptor output_dimensions =
3903       dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
3904 
3905   const int64 area = output_dimensions.width() * output_dimensions.height();
3906   const auto index = [area](int64 batch, int64 depth, int64 yx,
3907                             int64 max_depth) {
3908     return (batch * max_depth + depth) * area + yx;
3909   };
3910 
3911   std::vector<float> output_host(output_dimensions.ElementCount());
3912   std::vector<float> tmp;
3913   int64 depth_sum = 0;
3914   for (size_t i = 0; i < input_data.size(); ++i) {
3915     const auto& dimensions = input_dimensions[i];
3916     tmp.resize(dimensions.ElementCount());
3917     stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
3918     port::Status block_status = stream->BlockHostUntilDone();
3919     if (!block_status.ok()) {
3920       LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
3921       return false;
3922     }
3923 
3924     for (int64 batch = 0; batch < output_dimensions.count(); ++batch) {
3925       for (int64 yx = 0; yx < area; ++yx) {
3926         for (int64 depth = 0; depth < dimensions.feature_map_count(); ++depth) {
3927           LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' '
3928                     << yx << ' ' << depth;
3929           output_host[index(batch, depth + depth_sum, yx,
3930                             output_dimensions.feature_map_count())] =
3931               tmp[index(batch, depth, yx, dimensions.feature_map_count())];
3932         }
3933       }
3934     }
3935     depth_sum += dimensions.feature_map_count();
3936   }
3937   stream->ThenMemcpyH2D<float>(output_host, output_data);
3938   return true;
3939 }
3940 
DoElementwiseOperate(Stream * stream,dnn::ElementwiseOperation operation,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)3941 bool MIOpenSupport::DoElementwiseOperate(
3942     Stream* stream, dnn::ElementwiseOperation operation,
3943     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
3944     port::ArraySlice<const DeviceMemory<float>*> input_data,
3945     const dnn::BatchDescriptor& output_dimensions,
3946     DeviceMemory<float>* output_data) {
3947   LOG(FATAL) << "not yet implemented";  // TODO(leary)
3948   return false;
3949 }
3950 
DoXYPad(Stream * stream,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_pad,int64 right_pad,int64 top_pad,int64 bottom_pad,DeviceMemory<float> * output_data)3951 bool MIOpenSupport::DoXYPad(Stream* stream,
3952                             const dnn::BatchDescriptor& dimensions,
3953                             const DeviceMemory<float>& input_data,
3954                             int64 left_pad, int64 right_pad, int64 top_pad,
3955                             int64 bottom_pad,
3956                             DeviceMemory<float>* output_data) {
3957   LOG(FATAL) << "not yet implemented";  // TODO(leary)
3958   return false;
3959 }
3960 
DoXYSlice(Stream * stream,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_trim,int64 right_trim,int64 top_trim,int64 bottom_trim,DeviceMemory<float> * output_data)3961 bool MIOpenSupport::DoXYSlice(Stream* stream,
3962                               const dnn::BatchDescriptor& dimensions,
3963                               const DeviceMemory<float>& input_data,
3964                               int64 left_trim, int64 right_trim, int64 top_trim,
3965                               int64 bottom_trim,
3966                               DeviceMemory<float>* output_data) {
3967   LOG(FATAL) << "not yet implemented";  // TODO(leary)
3968   return false;
3969 }
3970 
DoMemcpyD2HQuantized(Stream * stream,const DeviceMemory<float> & gpu_unquantized_src,dnn::QuantizedActivationMode mode,void * host_dst,int64 size)3971 bool MIOpenSupport::DoMemcpyD2HQuantized(
3972     Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
3973     dnn::QuantizedActivationMode mode, void* host_dst, int64 size) {
3974   LOG(ERROR) << "quantized memcpy not supported by MIOpen";
3975   return false;
3976 }
3977 
DoMemcpyH2DQuantized(Stream * stream,const void * host_src,int64 size,dnn::QuantizedActivationMode mode,DeviceMemory<float> * gpu_unquantized_dst)3978 bool MIOpenSupport::DoMemcpyH2DQuantized(
3979     Stream* stream, const void* host_src, int64 size,
3980     dnn::QuantizedActivationMode mode,
3981     DeviceMemory<float>* gpu_unquantized_dst) {
3982   LOG(ERROR) << "quantized memcpy not supported by MIOpen";
3983   return false;
3984 }
3985 
DeriveOutputBatchDescriptor(const BatchDescriptor & batch_descriptor,const FilterDescriptor & filter_descriptor,const dnn::ConvolutionDescriptor & convolution_descriptor,dnn::BatchDescriptor * output_batch_descriptor)3986 bool MIOpenSupport::DeriveOutputBatchDescriptor(
3987     const BatchDescriptor& batch_descriptor,
3988     const FilterDescriptor& filter_descriptor,
3989     const dnn::ConvolutionDescriptor& convolution_descriptor,
3990     dnn::BatchDescriptor* output_batch_descriptor) {
3991   ScopedTensorDescriptor input_nd{batch_descriptor, miopenFloat};
3992   ScopedFilterDescriptor filter{filter_descriptor, batch_descriptor,
3993                                 miopenFloat};
3994   ScopedConvolutionDescriptor conv{convolution_descriptor, miopenFloat};
3995 
3996   int dn = batch_descriptor.ndims() + 2;
3997   std::vector<int> dims(dn);  // in BDYX
3998   auto status = wrap::miopenGetConvolutionForwardOutputDim(
3999       conv.handle(), input_nd.handle(), filter.handle(), &dims[0], &dims[1],
4000       &dims[2], &dims[3]);
4001   if (status != miopenStatusSuccess) {
4002     LOG(ERROR) << "could not get output tensor for convolution: "
4003                << ToString(status);
4004     return false;
4005   }
4006 
4007   output_batch_descriptor->set_count(dims[0])
4008       .set_feature_map_count(dims[1])
4009       .set_layout(batch_descriptor.layout());
4010 
4011   for (int i = 0; i < batch_descriptor.ndims(); i++) {
4012     output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
4013                                              dims.rbegin()[i]);
4014   }
4015 
4016   return true;
4017 }
4018 
4019 template <typename T>
DoFusedConvolutionBiasActivationImpl(Stream * stream,int miopen_type,const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<T> & conv_input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<T> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<T> & bias_data,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<T> * output_data,dnn::ProfileResult * output_profile_result)4020 bool MIOpenSupport::DoFusedConvolutionBiasActivationImpl(
4021     Stream* stream,
4022     int miopen_type,  // Actually miopenDataType_t.
4023     const dnn::BatchDescriptor& conv_input_descriptor,
4024     const DeviceMemory<T>& conv_input_data,
4025     const dnn::FilterDescriptor& filter_descriptor,
4026     const DeviceMemory<T>& filter_data,
4027     const dnn::ConvolutionDescriptor& convolution_descriptor,
4028     const dnn::BatchDescriptor& bias_descriptor,
4029     const DeviceMemory<T>& bias_data, dnn::ActivationMode activation_mode,
4030     const dnn::BatchDescriptor& output_descriptor, DeviceMemory<T>* output_data,
4031     dnn::ProfileResult* output_profile_result) {
4032   auto miopen = miopen_->GetHandle(parent_, stream);
4033 
4034   ScopedTensorDescriptor conv_input_nd{
4035       conv_input_descriptor, static_cast<miopenDataType_t>(miopen_type)};
4036 
4037   ScopedTensorDescriptor bias_nd{bias_descriptor,
4038                                  static_cast<miopenDataType_t>(miopen_type)};
4039 
4040   ScopedTensorDescriptor output_nd{output_descriptor,
4041                                    static_cast<miopenDataType_t>(miopen_type)};
4042 
4043   ScopedConvolutionDescriptor conv{convolution_descriptor,
4044                                    static_cast<miopenDataType_t>(miopen_type)};
4045 
4046   ScopedFilterDescriptor filter{filter_descriptor, conv_input_descriptor,
4047                                 static_cast<miopenDataType_t>(miopen_type)};
4048 
4049   ScopedActivationDescriptor activation_desc{activation_mode};
4050 
4051   ScopedFusionPlanConvolutionBiasActivation fusion_plan{
4052       miopen.handle(), conv_input_nd.handle(), filter.handle(),
4053       conv.handle(),   bias_nd.handle(),       activation_desc};
4054 
4055   bool retval = false;
4056 
4057   if (fusion_plan.CompilationSucceeded()) {
4058     const bool is_profiling = output_profile_result != nullptr;
4059 
4060     std::unique_ptr<GpuTimer> timer;
4061     if (is_profiling) {
4062       timer.reset(new GpuTimer(parent_));
4063       timer->Init();
4064       timer->Start(AsGpuStream(stream));
4065     }
4066 
4067     miopenStatus_t status = miopenStatusSuccess;
4068 
4069     if (status == miopenStatusSuccess) {
4070       fusion_plan.SetConvolutionArgs(filter_data.opaque());
4071     }
4072 
4073     if (status == miopenStatusSuccess) {
4074       status = fusion_plan.SetBiasArgs(bias_data.opaque());
4075     }
4076 
4077     if (status == miopenStatusSuccess) {
4078       status = fusion_plan.SetActivationForwardArgs(activation_desc);
4079     }
4080 
4081     if (status == miopenStatusSuccess) {
4082       status =
4083           fusion_plan.Execute(conv_input_nd.handle(), conv_input_data.opaque(),
4084                               output_nd.handle(), output_data->opaque());
4085     }
4086 
4087     if (is_profiling) {
4088       timer->Stop(AsGpuStream(stream));
4089       if (status == miopenStatusSuccess) {
4090         output_profile_result->set_elapsed_time_in_ms(
4091             timer->GetElapsedMilliseconds());
4092       }
4093       timer->Destroy();
4094     }
4095 
4096     if (status != miopenStatusSuccess) {
4097       // Silently return when we are profiling.
4098       if (!is_profiling) {
4099         LOG(FATAL) << "failed to enqueue fused-convolution on stream: "
4100                    << ToString(status);
4101       }
4102     }
4103 
4104     retval = true;
4105   }
4106 
4107   return retval;
4108 }
4109 
DoFusedConvolutionBiasActivation(Stream * stream,const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & bias_data,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output_data,dnn::ProfileResult * output_profile_result)4110 bool MIOpenSupport::DoFusedConvolutionBiasActivation(
4111     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
4112     const DeviceMemory<float>& conv_input_data,
4113     const dnn::FilterDescriptor& filter_descriptor,
4114     const DeviceMemory<float>& filter_data,
4115     const dnn::ConvolutionDescriptor& convolution_descriptor,
4116     const dnn::BatchDescriptor& bias_descriptor,
4117     const DeviceMemory<float>& bias_data, dnn::ActivationMode activation_mode,
4118     const dnn::BatchDescriptor& output_descriptor,
4119     DeviceMemory<float>* output_data,
4120     dnn::ProfileResult* output_profile_result) {
4121   return DoFusedConvolutionBiasActivationImpl<float>(
4122       stream, miopenFloat, conv_input_descriptor, conv_input_data,
4123       filter_descriptor, filter_data, convolution_descriptor, bias_descriptor,
4124       bias_data, activation_mode, output_descriptor, output_data,
4125       output_profile_result);
4126 }
4127 
4128 template <typename T, typename U>
DoFusedBatchNormActivationInferenceImpl(Stream * stream,int miopen_type,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<T> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<U> & scale_data,const DeviceMemory<U> & offset_data,const DeviceMemory<U> & mean_data,const DeviceMemory<U> & variance_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<T> * y_data,dnn::ProfileResult * output_profile_result)4129 bool MIOpenSupport::DoFusedBatchNormActivationInferenceImpl(
4130     Stream* stream,
4131     int miopen_type,  // Actually miopenDataType_t.
4132     const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
4133     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4134     const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
4135     const DeviceMemory<U>& mean_data, const DeviceMemory<U>& variance_data,
4136     double epsilon, dnn::ActivationMode activation_mode,
4137     DeviceMemory<T>* y_data, dnn::ProfileResult* output_profile_result) {
4138   auto miopen = miopen_->GetHandle(parent_, stream);
4139 
4140   ScopedTensorDescriptor x_nd{x_descriptor,
4141                               static_cast<miopenDataType_t>(miopen_type)};
4142 
4143   ScopedTensorDescriptor scale_offset_mean_variance_nd{
4144       scale_offset_mean_variance_descriptor,
4145       static_cast<miopenDataType_t>(miopen_type)};
4146 
4147   ScopedActivationDescriptor activation_desc{activation_mode};
4148 
4149   ScopedFusionPlanBatchNormActivationInference fusion_plan{
4150       miopen.handle(), x_nd.handle(), scale_offset_mean_variance_nd.handle(),
4151       activation_desc};
4152 
4153   bool retval = false;
4154 
4155   if (fusion_plan.CompilationSucceeded()) {
4156     const bool is_profiling = output_profile_result != nullptr;
4157 
4158     std::unique_ptr<GpuTimer> timer;
4159     if (is_profiling) {
4160       timer.reset(new GpuTimer(parent_));
4161       timer->Init();
4162       timer->Start(AsGpuStream(stream));
4163     }
4164 
4165     miopenStatus_t status = miopenStatusSuccess;
4166 
4167     if (status == miopenStatusSuccess) {
4168       fusion_plan.SetBatchNormInferenceArgs(
4169           scale_data.opaque(), offset_data.opaque(), mean_data.opaque(),
4170           variance_data.opaque(), epsilon);
4171     }
4172 
4173     if (status == miopenStatusSuccess) {
4174       status = fusion_plan.SetActivationForwardArgs(activation_desc);
4175     }
4176 
4177     if (status == miopenStatusSuccess) {
4178       status = fusion_plan.Execute(x_nd.handle(), x_data.opaque(),
4179                                    x_nd.handle(), y_data->opaque());
4180     }
4181 
4182     if (is_profiling) {
4183       timer->Stop(AsGpuStream(stream));
4184       if (status == miopenStatusSuccess) {
4185         output_profile_result->set_elapsed_time_in_ms(
4186             timer->GetElapsedMilliseconds());
4187       }
4188       timer->Destroy();
4189     }
4190 
4191     if (status != miopenStatusSuccess) {
4192       // Silently return when we are profiling.
4193       if (!is_profiling) {
4194         LOG(FATAL) << "failed to enqueue fused-convolution on stream: "
4195                    << ToString(status);
4196       }
4197     }
4198 
4199     retval = true;
4200   }
4201 
4202   return retval;
4203 }
4204 
DoFusedBatchNormActivationInference(Stream * stream,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<float> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,const DeviceMemory<float> & mean_data,const DeviceMemory<float> & variance_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<float> * y_data,dnn::ProfileResult * output_profile_result)4205 bool MIOpenSupport::DoFusedBatchNormActivationInference(
4206     Stream* stream, const dnn::BatchDescriptor& x_descriptor,
4207     const DeviceMemory<float>& x_data,
4208     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4209     const DeviceMemory<float>& scale_data,
4210     const DeviceMemory<float>& offset_data,
4211     const DeviceMemory<float>& mean_data,
4212     const DeviceMemory<float>& variance_data, double epsilon,
4213     dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
4214     dnn::ProfileResult* output_profile_result) {
4215   return DoFusedBatchNormActivationInferenceImpl<float, float>(
4216       stream, miopenFloat, x_descriptor, x_data,
4217       scale_offset_mean_variance_descriptor, scale_data, offset_data, mean_data,
4218       variance_data, epsilon, activation_mode, y_data, output_profile_result);
4219 }
4220 
DoFusedBatchNormActivationInference(Stream * stream,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<Eigen::half> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,const DeviceMemory<float> & mean_data,const DeviceMemory<float> & variance_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y_data,dnn::ProfileResult * output_profile_result)4221 bool MIOpenSupport::DoFusedBatchNormActivationInference(
4222     Stream* stream, const dnn::BatchDescriptor& x_descriptor,
4223     const DeviceMemory<Eigen::half>& x_data,
4224     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4225     const DeviceMemory<float>& scale_data,
4226     const DeviceMemory<float>& offset_data,
4227     const DeviceMemory<float>& mean_data,
4228     const DeviceMemory<float>& variance_data, double epsilon,
4229     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
4230     dnn::ProfileResult* output_profile_result) {
4231   return DoFusedBatchNormActivationInferenceImpl<Eigen::half, float>(
4232       stream, miopenHalf, x_descriptor, x_data,
4233       scale_offset_mean_variance_descriptor, scale_data, offset_data, mean_data,
4234       variance_data, epsilon, activation_mode, y_data, output_profile_result);
4235 }
4236 
4237 template <typename T, typename U>
DoFusedBatchNormActivationForwardImpl(Stream * stream,int miopen_type,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<T> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<U> & scale_data,const DeviceMemory<U> & offset_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<T> * y_data,DeviceMemory<U> * batch_mean_data,DeviceMemory<U> * batch_var_data,DeviceMemory<U> * saved_mean_data,DeviceMemory<U> * saved_var_data,dnn::ProfileResult * output_profile_result)4238 bool MIOpenSupport::DoFusedBatchNormActivationForwardImpl(
4239     Stream* stream,
4240     int miopen_type,  // Actually miopenDataType_t.
4241     const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
4242     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4243     const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
4244     double epsilon, dnn::ActivationMode activation_mode,
4245     DeviceMemory<T>* y_data, DeviceMemory<U>* batch_mean_data,
4246     DeviceMemory<U>* batch_var_data, DeviceMemory<U>* saved_mean_data,
4247     DeviceMemory<U>* saved_var_data,
4248     dnn::ProfileResult* output_profile_result) {
4249   auto miopen = miopen_->GetHandle(parent_, stream);
4250 
4251   ScopedTensorDescriptor x_nd{x_descriptor,
4252                               static_cast<miopenDataType_t>(miopen_type)};
4253 
4254   ScopedTensorDescriptor scale_offset_mean_variance_nd{
4255       scale_offset_mean_variance_descriptor,
4256       static_cast<miopenDataType_t>(miopen_type)};
4257 
4258   ScopedActivationDescriptor activation_desc{activation_mode};
4259 
4260   ScopedFusionPlanBatchNormActivationForward fusion_plan{
4261       miopen.handle(), x_nd.handle(), scale_offset_mean_variance_nd.handle(),
4262       activation_desc};
4263 
4264   bool retval = false;
4265 
4266   if (fusion_plan.CompilationSucceeded()) {
4267     const bool is_profiling = output_profile_result != nullptr;
4268 
4269     std::unique_ptr<GpuTimer> timer;
4270     if (is_profiling) {
4271       timer.reset(new GpuTimer(parent_));
4272       timer->Init();
4273       timer->Start(AsGpuStream(stream));
4274     }
4275 
4276     miopenStatus_t status = miopenStatusSuccess;
4277 
4278     if (status == miopenStatusSuccess) {
4279       fusion_plan.SetBatchNormForwardArgs(
4280           scale_data.opaque(), offset_data.opaque(), batch_mean_data->opaque(),
4281           batch_var_data->opaque(), saved_mean_data->opaque(),
4282           saved_var_data->opaque(), epsilon);
4283     }
4284 
4285     if (status == miopenStatusSuccess) {
4286       status = fusion_plan.SetActivationForwardArgs(activation_desc);
4287     }
4288 
4289     if (status == miopenStatusSuccess) {
4290       status = fusion_plan.Execute(x_nd.handle(), x_data.opaque(),
4291                                    x_nd.handle(), y_data->opaque());
4292     }
4293 
4294     if (is_profiling) {
4295       timer->Stop(AsGpuStream(stream));
4296       if (status == miopenStatusSuccess) {
4297         output_profile_result->set_elapsed_time_in_ms(
4298             timer->GetElapsedMilliseconds());
4299       }
4300       timer->Destroy();
4301     }
4302 
4303     if (status != miopenStatusSuccess) {
4304       // Silently return when we are profiling.
4305       if (!is_profiling) {
4306         LOG(FATAL) << "failed to enqueue fused-convolution on stream: "
4307                    << ToString(status);
4308       }
4309     }
4310 
4311     retval = true;
4312   }
4313 
4314   return retval;
4315 }
4316 
DoFusedBatchNormActivationForward(Stream * stream,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<float> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<float> * y_data,DeviceMemory<float> * batch_mean_data,DeviceMemory<float> * batch_var_data,DeviceMemory<float> * saved_mean_data,DeviceMemory<float> * saved_var_data,dnn::ProfileResult * output_profile_result)4317 bool MIOpenSupport::DoFusedBatchNormActivationForward(
4318     Stream* stream, const dnn::BatchDescriptor& x_descriptor,
4319     const DeviceMemory<float>& x_data,
4320     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4321     const DeviceMemory<float>& scale_data,
4322     const DeviceMemory<float>& offset_data, double epsilon,
4323     dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
4324     DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
4325     DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
4326     dnn::ProfileResult* output_profile_result) {
4327   return DoFusedBatchNormActivationForwardImpl<float, float>(
4328       stream, miopenFloat, x_descriptor, x_data,
4329       scale_offset_mean_variance_descriptor, scale_data, offset_data, epsilon,
4330       activation_mode, y_data, batch_mean_data, batch_var_data, saved_mean_data,
4331       saved_var_data, output_profile_result);
4332 }
4333 
DoFusedBatchNormActivationForward(Stream * stream,const dnn::BatchDescriptor & x_descriptor,const DeviceMemory<Eigen::half> & x_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y_data,DeviceMemory<float> * batch_mean_data,DeviceMemory<float> * batch_var_data,DeviceMemory<float> * saved_mean_data,DeviceMemory<float> * saved_var_data,dnn::ProfileResult * output_profile_result)4334 bool MIOpenSupport::DoFusedBatchNormActivationForward(
4335     Stream* stream, const dnn::BatchDescriptor& x_descriptor,
4336     const DeviceMemory<Eigen::half>& x_data,
4337     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4338     const DeviceMemory<float>& scale_data,
4339     const DeviceMemory<float>& offset_data, double epsilon,
4340     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
4341     DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
4342     DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
4343     dnn::ProfileResult* output_profile_result) {
4344   return DoFusedBatchNormActivationForwardImpl<Eigen::half, float>(
4345       stream, miopenHalf, x_descriptor, x_data,
4346       scale_offset_mean_variance_descriptor, scale_data, offset_data, epsilon,
4347       activation_mode, y_data, batch_mean_data, batch_var_data, saved_mean_data,
4348       saved_var_data, output_profile_result);
4349 }
4350 
4351 template <typename T, typename U>
DoFusedBatchNormActivationBackwardImpl(Stream * stream,int miopen_type,const dnn::BatchDescriptor & y_act_backprop_descriptor,const DeviceMemory<T> & y_act_backprop_data,const DeviceMemory<T> & y_act_data,dnn::ActivationMode activation_mode,const DeviceMemory<T> & x_bn_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<U> & scale_data,const DeviceMemory<U> & offset_data,const DeviceMemory<U> & saved_mean_data,const DeviceMemory<U> & saved_var_data,DeviceMemory<T> * x_bn_backprop_data,DeviceMemory<U> * scale_backprop_data,DeviceMemory<U> * offset_backprop_data,dnn::ProfileResult * output_profile_result)4352 bool MIOpenSupport::DoFusedBatchNormActivationBackwardImpl(
4353     Stream* stream,
4354     int miopen_type,  // Actually miopenDataType_t.
4355     const dnn::BatchDescriptor& y_act_backprop_descriptor,
4356     const DeviceMemory<T>& y_act_backprop_data,
4357     const DeviceMemory<T>& y_act_data, dnn::ActivationMode activation_mode,
4358     const DeviceMemory<T>& x_bn_data,
4359     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4360     const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
4361     const DeviceMemory<U>& saved_mean_data,
4362     const DeviceMemory<U>& saved_var_data, DeviceMemory<T>* x_bn_backprop_data,
4363     DeviceMemory<U>* scale_backprop_data, DeviceMemory<U>* offset_backprop_data,
4364     dnn::ProfileResult* output_profile_result) {
4365   auto miopen = miopen_->GetHandle(parent_, stream);
4366 
4367   ScopedTensorDescriptor y_act_backprop_nd{
4368       y_act_backprop_descriptor, static_cast<miopenDataType_t>(miopen_type)};
4369 
4370   ScopedTensorDescriptor scale_offset_mean_variance_nd{
4371       scale_offset_mean_variance_descriptor,
4372       static_cast<miopenDataType_t>(miopen_type)};
4373 
4374   ScopedActivationDescriptor activation_desc{activation_mode};
4375 
4376   ScopedFusionPlanBatchNormActivationBackward fusion_plan{
4377       miopen.handle(), y_act_backprop_nd.handle(),
4378       scale_offset_mean_variance_nd.handle(), activation_desc};
4379 
4380   bool retval = false;
4381 
4382   if (fusion_plan.CompilationSucceeded()) {
4383     const bool is_profiling = output_profile_result != nullptr;
4384 
4385     std::unique_ptr<GpuTimer> timer;
4386     if (is_profiling) {
4387       timer.reset(new GpuTimer(parent_));
4388       timer->Init();
4389       timer->Start(AsGpuStream(stream));
4390     }
4391 
4392     miopenStatus_t status = miopenStatusSuccess;
4393 
4394     if (status == miopenStatusSuccess) {
4395       fusion_plan.SetBatchNormBackwardArgs(
4396           x_bn_data.opaque(), scale_data.opaque(), offset_data.opaque(),
4397           saved_mean_data.opaque(), saved_var_data.opaque(),
4398           scale_backprop_data->opaque(), offset_backprop_data->opaque());
4399     }
4400 
4401     if (status == miopenStatusSuccess) {
4402       status = fusion_plan.SetActivationBackwardArgs(activation_desc,
4403                                                      y_act_data.opaque());
4404     }
4405 
4406     if (status == miopenStatusSuccess) {
4407       status = fusion_plan.Execute(
4408           y_act_backprop_nd.handle(), y_act_backprop_data.opaque(),
4409           y_act_backprop_nd.handle(), x_bn_backprop_data->opaque());
4410     }
4411 
4412     if (is_profiling) {
4413       timer->Stop(AsGpuStream(stream));
4414       if (status == miopenStatusSuccess) {
4415         output_profile_result->set_elapsed_time_in_ms(
4416             timer->GetElapsedMilliseconds());
4417       }
4418       timer->Destroy();
4419     }
4420 
4421     if (status != miopenStatusSuccess) {
4422       // Silently return when we are profiling.
4423       if (!is_profiling) {
4424         LOG(FATAL) << "failed to enqueue fused-convolution on stream: "
4425                    << ToString(status);
4426       }
4427     }
4428 
4429     retval = true;
4430   }
4431 
4432   return retval;
4433 }
4434 
DoFusedBatchNormActivationBackward(Stream * stream,const dnn::BatchDescriptor & y_act_backprop_descriptor,const DeviceMemory<float> & y_act_backprop_data,const DeviceMemory<float> & y_act_data,dnn::ActivationMode activation_mode,const DeviceMemory<float> & x_bn_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,const DeviceMemory<float> & saved_mean_data,const DeviceMemory<float> & saved_var_data,DeviceMemory<float> * x_bn_backprop_data,DeviceMemory<float> * scale_backprop_data,DeviceMemory<float> * offset_backprop_data,dnn::ProfileResult * output_profile_result)4435 bool MIOpenSupport::DoFusedBatchNormActivationBackward(
4436     Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
4437     const DeviceMemory<float>& y_act_backprop_data,
4438     const DeviceMemory<float>& y_act_data, dnn::ActivationMode activation_mode,
4439     const DeviceMemory<float>& x_bn_data,
4440     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4441     const DeviceMemory<float>& scale_data,
4442     const DeviceMemory<float>& offset_data,
4443     const DeviceMemory<float>& saved_mean_data,
4444     const DeviceMemory<float>& saved_var_data,
4445     DeviceMemory<float>* x_bn_backprop_data,
4446     DeviceMemory<float>* scale_backprop_data,
4447     DeviceMemory<float>* offset_backprop_data,
4448     dnn::ProfileResult* output_profile_result) {
4449   return DoFusedBatchNormActivationBackwardImpl<float, float>(
4450       stream, miopenFloat, y_act_backprop_descriptor, y_act_backprop_data,
4451       y_act_data, activation_mode, x_bn_data,
4452       scale_offset_mean_variance_descriptor, scale_data, offset_data,
4453       saved_mean_data, saved_var_data, x_bn_backprop_data, scale_backprop_data,
4454       offset_backprop_data, output_profile_result);
4455 }
4456 
DoFusedBatchNormActivationBackward(Stream * stream,const dnn::BatchDescriptor & y_act_backprop_descriptor,const DeviceMemory<Eigen::half> & y_act_backprop_data,const DeviceMemory<Eigen::half> & y_act_data,dnn::ActivationMode activation_mode,const DeviceMemory<Eigen::half> & x_bn_data,const dnn::BatchDescriptor & scale_offset_mean_variance_descriptor,const DeviceMemory<float> & scale_data,const DeviceMemory<float> & offset_data,const DeviceMemory<float> & saved_mean_data,const DeviceMemory<float> & saved_var_data,DeviceMemory<Eigen::half> * x_bn_backprop_data,DeviceMemory<float> * scale_backprop_data,DeviceMemory<float> * offset_backprop_data,dnn::ProfileResult * output_profile_result)4457 bool MIOpenSupport::DoFusedBatchNormActivationBackward(
4458     Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
4459     const DeviceMemory<Eigen::half>& y_act_backprop_data,
4460     const DeviceMemory<Eigen::half>& y_act_data,
4461     dnn::ActivationMode activation_mode,
4462     const DeviceMemory<Eigen::half>& x_bn_data,
4463     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4464     const DeviceMemory<float>& scale_data,
4465     const DeviceMemory<float>& offset_data,
4466     const DeviceMemory<float>& saved_mean_data,
4467     const DeviceMemory<float>& saved_var_data,
4468     DeviceMemory<Eigen::half>* x_bn_backprop_data,
4469     DeviceMemory<float>* scale_backprop_data,
4470     DeviceMemory<float>* offset_backprop_data,
4471     dnn::ProfileResult* output_profile_result) {
4472   return DoFusedBatchNormActivationBackwardImpl<Eigen::half, float>(
4473       stream, miopenHalf, y_act_backprop_descriptor, y_act_backprop_data,
4474       y_act_data, activation_mode, x_bn_data,
4475       scale_offset_mean_variance_descriptor, scale_data, offset_data,
4476       saved_mean_data, saved_var_data, x_bn_backprop_data, scale_backprop_data,
4477       offset_backprop_data, output_profile_result);
4478 }
4479 
4480 }  // namespace gpu
4481 
initialize_miopen()4482 void initialize_miopen() {
4483   auto miopenAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
4484       rocm::kROCmPlatformId, PluginKind::kDnn, gpu::kMIOpenPlugin);
4485 
4486   if (!miopenAlreadyRegistered) {
4487     port::Status status =
4488         PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
4489             rocm::kROCmPlatformId, gpu::kMIOpenPlugin, "MIOpen",
4490             [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
4491               gpu::GpuExecutor* rocm_executor =
4492                   dynamic_cast<gpu::GpuExecutor*>(parent);
4493               if (rocm_executor == nullptr) {
4494                 LOG(ERROR)
4495                     << "Attempting to initialize an instance of the MIOpen "
4496                     << "support library with a non-ROCM StreamExecutor";
4497                 return nullptr;
4498               }
4499 
4500               gpu::MIOpenSupport* dnn = new gpu::MIOpenSupport(rocm_executor);
4501               if (!dnn->Init().ok()) {
4502                 // Note: Init() will log a more specific error.
4503                 delete dnn;
4504                 return nullptr;
4505               }
4506               return dnn;
4507             });
4508 
4509     if (!status.ok()) {
4510       LOG(ERROR) << "Unable to register MIOpen factory: "
4511                  << status.error_message();
4512     }
4513 
4514     PluginRegistry::Instance()->SetDefaultFactory(
4515         rocm::kROCmPlatformId, PluginKind::kDnn, gpu::kMIOpenPlugin);
4516   }
4517 }
4518 
4519 }  // namespace stream_executor
4520 
// Registers a module initializer named `register_miopen` whose body calls
// stream_executor::initialize_miopen(), so the MIOpen DNN plugin is registered
// when module initializers run (presumably at static-initialization time --
// see REGISTER_MODULE_INITIALIZER in stream_executor/lib/initialize.h).
REGISTER_MODULE_INITIALIZER(register_miopen,
                            { stream_executor::initialize_miopen(); });
4523