1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3    Licensed under the Apache License, Version 2.0 (the "License");
4    you may not use this file except in compliance with the License.
5    You may obtain a copy of the License at
6 
7    http://www.apache.org/licenses/LICENSE-2.0
8 
9    Unless required by applicable law or agreed to in writing, software
10    distributed under the License is distributed on an "AS IS" BASIS,
11    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12    See the License for the specific language governing permissions and
13    limitations under the License.
14    ==============================================================================
15 */
16 #if TENSORFLOW_USE_ROCM
17 #include "tensorflow/core/util/rocm_solvers.h"
18 
19 #include <complex>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "rocm/include/rocblas.h"
24 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/core/blocking_counter.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/core/stringpiece.h"
30 #include "tensorflow/core/lib/gtl/inlined_vector.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/stream_executor.h"
33 #include "tensorflow/core/platform/types.h"
34 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
35 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
36 #include "tensorflow/stream_executor/lib/env.h"
37 #include "tensorflow/stream_executor/platform/default/dso_loader.h"
38 #include "tensorflow/stream_executor/platform/port.h"
39 
40 namespace tensorflow {
41 namespace {
42 
43 using stream_executor::gpu::GpuExecutor;
44 using stream_executor::gpu::ScopedActivateExecutorContext;
45 using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle;
46 
namespace wrap {
#ifdef PLATFORM_GOOGLE
// Statically-linked build: each wrapper activates the executor's GPU context
// and forwards directly to the real rocBLAS symbol. (No comments may appear
// inside the macro body: '//' would interact with the line-splicing
// backslashes.)
#define ROCBLAS_WRAP(__name)                                       \
  struct WrapperShim__##__name {                                   \
    static const char* kName;                                      \
    template <typename... Args>                                    \
    rocblas_status operator()(GpuExecutor* parent, Args... args) { \
      ScopedActivateExecutorContext sac{parent};                   \
      return ::__name(args...);                                    \
    }                                                              \
  } __name;                                                        \
  const char* WrapperShim__##__name::kName = #__name;

#else

// Dynamically-loaded build: the wrapper lazily resolves the symbol from the
// rocBLAS DSO on first call (thread-safe via the function-local static in
// DynLoad) and CHECK-fails if the library or symbol cannot be found.
#define ROCBLAS_WRAP(__name)                                                \
  struct DynLoadShim__##__name {                                            \
    static const char* kName;                                               \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;            \
    static void* GetDsoHandle() {                                           \
      auto s = GetRocblasDsoHandle();                                       \
      return s.ValueOrDie();                                                \
    }                                                                       \
    static FuncPtrT LoadOrDie() {                                           \
      void* f;                                                              \
      auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary( \
          GetDsoHandle(), kName, &f);                                       \
      CHECK(s.ok()) << "could not find " << kName                           \
                    << " in rocblas DSO; dlerror: " << s.error_message();   \
      return reinterpret_cast<FuncPtrT>(f);                                 \
    }                                                                       \
    static FuncPtrT DynLoad() {                                             \
      static FuncPtrT f = LoadOrDie();                                      \
      return f;                                                             \
    }                                                                       \
    template <typename... Args>                                             \
    rocblas_status operator()(GpuExecutor* parent, Args... args) {          \
      ScopedActivateExecutorContext sac{parent};                            \
      return DynLoad()(args...);                                            \
    }                                                                       \
  } __name;                                                                 \
  const char* DynLoadShim__##__name::kName = #__name;

#endif

// The subset of rocBLAS entry points needed by this file (handle management
// plus the single/double-precision triangular solves).
ROCBLAS_WRAP(rocblas_create_handle)
ROCBLAS_WRAP(rocblas_destroy_handle)
ROCBLAS_WRAP(rocblas_set_stream)
ROCBLAS_WRAP(rocblas_dtrsm)
ROCBLAS_WRAP(rocblas_strsm)

}  // namespace wrap
99 
100 struct ROCmSolverHandles {
ROCmSolverHandlestensorflow::__anon781cbc6c0111::ROCmSolverHandles101   explicit ROCmSolverHandles(GpuExecutor* parent, hipStream_t stream) {
102     parent_ = parent;
103     CHECK(wrap::rocblas_create_handle(parent_, &rocm_blas_handle) ==
104           rocblas_status_success)
105         << "Failed to create rocBlas instance.";
106     CHECK(wrap::rocblas_set_stream(parent_, rocm_blas_handle, stream) ==
107           rocblas_status_success)
108         << "Failed to set rocBlas stream.";
109   }
110 
~ROCmSolverHandlestensorflow::__anon781cbc6c0111::ROCmSolverHandles111   ~ROCmSolverHandles() {
112     CHECK(wrap::rocblas_destroy_handle(parent_, rocm_blas_handle) ==
113           rocblas_status_success)
114         << "Failed to destroy cuBlas instance.";
115   }
116   GpuExecutor* parent_;
117   rocblas_handle rocm_blas_handle;
118 };
119 
120 using HandleMap =
121     std::unordered_map<hipStream_t, std::unique_ptr<ROCmSolverHandles>>;
122 
123 // Returns a singleton map used for storing initialized handles for each unique
124 // gpu stream.
GetHandleMapSingleton()125 HandleMap* GetHandleMapSingleton() {
126   static HandleMap* cm = new HandleMap;
127   return cm;
128 }
129 
130 static mutex handle_map_mutex(LINKER_INITIALIZED);
131 
132 }  // namespace
133 
ROCmSolver(OpKernelContext * context)134 ROCmSolver::ROCmSolver(OpKernelContext* context) : context_(context) {
135   mutex_lock lock(handle_map_mutex);
136   GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(
137       context->op_device_context()->stream()->parent()->implementation());
138   const hipStream_t* hip_stream_ptr = CHECK_NOTNULL(
139       reinterpret_cast<const hipStream_t*>(context->op_device_context()
140                                                ->stream()
141                                                ->implementation()
142                                                ->GpuStreamMemberHack()));
143 
144   hip_stream_ = *hip_stream_ptr;
145   HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
146   auto it = handle_map->find(hip_stream_);
147   if (it == handle_map->end()) {
148     LOG(INFO) << "Creating ROCmSolver handles for stream " << hip_stream_;
149     // Previously unseen Gpu stream. Initialize a set of Gpu solver library
150     // handles for it.
151     std::unique_ptr<ROCmSolverHandles> new_handles(
152         new ROCmSolverHandles(gpu_executor, hip_stream_));
153     it = handle_map->insert(std::make_pair(hip_stream_, std::move(new_handles)))
154              .first;
155   }
156   rocm_blas_handle_ = it->second->rocm_blas_handle;
157 }
158 
~ROCmSolver()159 ROCmSolver::~ROCmSolver() {
160   for (auto tensor_ref : scratch_tensor_refs_) {
161     tensor_ref.Unref();
162   }
163 }
164 
// Evaluates `expr` (a rocBLAS call) and returns an Internal error from the
// enclosing function if the call did not succeed.
#define TF_RETURN_IF_ROCBLAS_ERROR(expr)                                  \
  do {                                                                    \
    auto status = (expr);                                                 \
    if (TF_PREDICT_FALSE(status != rocblas_status_success)) {             \
      return errors::Internal(__FILE__, ":", __LINE__,                    \
                              ": rocBlas call failed status = ", status); \
    }                                                                     \
  } while (0)

// Macro that specializes a solver method for all 4 standard
// numeric types.
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, s) m(double, d) m(std::complex<float>, c) m(std::complex<double>, z)
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, s) m(double, d)

// Expands to the wrap:: shim for the rocBLAS routine with the given
// single-letter type prefix, e.g. BLAS_SOLVER_FN(trsm, d) -> rocblas_dtrsm.
#define BLAS_SOLVER_FN(method, type_prefix) \
  wrap::rocblas##_##type_prefix##method
182 
183 // Allocates a temporary tensor. The ROCmSolver object maintains a
184 // TensorReference to the underlying Tensor to prevent it from being deallocated
185 // prematurely.
allocate_scoped_tensor(DataType type,const TensorShape & shape,Tensor * out_temp)186 Status ROCmSolver::allocate_scoped_tensor(DataType type,
187                                           const TensorShape& shape,
188                                           Tensor* out_temp) {
189   const Status status = context_->allocate_temp(type, shape, out_temp);
190   if (status.ok()) {
191     scratch_tensor_refs_.emplace_back(*out_temp);
192   }
193   return status;
194 }
195 
forward_input_or_allocate_scoped_tensor(gtl::ArraySlice<int> candidate_input_indices,DataType type,const TensorShape & shape,Tensor * out_temp)196 Status ROCmSolver::forward_input_or_allocate_scoped_tensor(
197     gtl::ArraySlice<int> candidate_input_indices, DataType type,
198     const TensorShape& shape, Tensor* out_temp) {
199   const Status status = context_->forward_input_or_allocate_temp(
200       candidate_input_indices, type, shape, out_temp);
201   if (status.ok()) {
202     scratch_tensor_refs_.emplace_back(*out_temp);
203   }
204   return status;
205 }
206 
207 template <typename Scalar, typename SolverFnT>
TrsmImpl(GpuExecutor * gpu_executor,SolverFnT solver,rocblas_handle rocm_blas_handle,rocblas_side side,rocblas_fill uplo,rocblas_operation trans,rocblas_diagonal diag,int m,int n,const Scalar * alpha,const Scalar * A,int lda,Scalar * B,int ldb)208 static inline Status TrsmImpl(GpuExecutor* gpu_executor, SolverFnT solver,
209                               rocblas_handle rocm_blas_handle,
210                               rocblas_side side, rocblas_fill uplo,
211                               rocblas_operation trans, rocblas_diagonal diag,
212                               int m, int n,
213                               const Scalar* alpha, /* host or device pointer */
214                               const Scalar* A, int lda, Scalar* B, int ldb) {
215   mutex_lock lock(handle_map_mutex);
216   using ROCmScalar = typename ROCmComplexT<Scalar>::type;
217 
218   TF_RETURN_IF_ROCBLAS_ERROR(solver(gpu_executor, rocm_blas_handle, side, uplo,
219                                     trans, diag, m, n,
220                                     reinterpret_cast<const ROCmScalar*>(alpha),
221                                     reinterpret_cast<const ROCmScalar*>(A), lda,
222                                     reinterpret_cast<ROCmScalar*>(B), ldb));
223 
224   return Status::OK();
225 }
226 
// Defines the explicit specialization ROCmSolver::Trsm<Scalar> that forwards
// to TrsmImpl with the matching rocBLAS routine (rocblas_<prefix>trsm).
#define TRSM_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                 \
  Status ROCmSolver::Trsm<Scalar>(                                            \
      rocblas_side side, rocblas_fill uplo, rocblas_operation trans,          \
      rocblas_diagonal diag, int m, int n,                                    \
      const Scalar* alpha, /* host or device pointer */                       \
      const Scalar* A, int lda, Scalar* B, int ldb) {                         \
    GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(                    \
        context_->op_device_context()->stream()->parent()->implementation()); \
    return TrsmImpl(gpu_executor, BLAS_SOLVER_FN(trsm, type_prefix),          \
                    rocm_blas_handle_, side, uplo, trans, diag, m, n, alpha,  \
                    A, lda, B, ldb);                                          \
  }

// Only float/double are instantiated: complex trsm wrappers (ctrsm/ztrsm)
// are not wrapped above.
TF_CALL_LAPACK_TYPES_NO_COMPLEX(TRSM_INSTANCE);
242 
243 }  // namespace tensorflow
244 
245 #endif  // TENSORFLOW_USE_ROCM
246