1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================
15 */
16 #if TENSORFLOW_USE_ROCM
17 #include "tensorflow/core/util/rocm_solvers.h"
18
19 #include <complex>
20 #include <unordered_map>
21 #include <vector>
22
23 #include "rocm/include/rocblas.h"
24 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/core/blocking_counter.h"
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/lib/core/stringpiece.h"
30 #include "tensorflow/core/lib/gtl/inlined_vector.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/stream_executor.h"
33 #include "tensorflow/core/platform/types.h"
34 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
35 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
36 #include "tensorflow/stream_executor/lib/env.h"
37 #include "tensorflow/stream_executor/platform/default/dso_loader.h"
38 #include "tensorflow/stream_executor/platform/port.h"
39
40 namespace tensorflow {
41 namespace {
42
43 using stream_executor::gpu::GpuExecutor;
44 using stream_executor::gpu::ScopedActivateExecutorContext;
45 using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle;
46
// Thin wrappers around the rocBLAS C entry points. Each wrapper activates the
// owning GpuExecutor's context before forwarding, so the library call runs
// against the correct device.
namespace wrap {
#ifdef PLATFORM_GOOGLE
// Statically linked build: forward directly to the rocBLAS symbol.
#define ROCBLAS_WRAP(__name)                                       \
  struct WrapperShim__##__name {                                   \
    static const char* kName;                                      \
    template <typename... Args>                                    \
    rocblas_status operator()(GpuExecutor* parent, Args... args) { \
      ScopedActivateExecutorContext sac{parent};                   \
      return ::__name(args...);                                    \
    }                                                              \
  } __name;                                                        \
  const char* WrapperShim__##__name::kName = #__name;

#else

// Dynamically loaded build: resolve the symbol from the rocBLAS DSO on first
// use (CHECK-fails if it cannot be found) and cache the function pointer in a
// function-local static for subsequent calls.
#define ROCBLAS_WRAP(__name)                                                \
  struct DynLoadShim__##__name {                                            \
    static const char* kName;                                               \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;            \
    static void* GetDsoHandle() {                                           \
      auto s = GetRocblasDsoHandle();                                       \
      return s.ValueOrDie();                                                \
    }                                                                       \
    static FuncPtrT LoadOrDie() {                                           \
      void* f;                                                              \
      auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary( \
          GetDsoHandle(), kName, &f);                                       \
      CHECK(s.ok()) << "could not find " << kName                           \
                    << " in rocblas DSO; dlerror: " << s.error_message();   \
      return reinterpret_cast<FuncPtrT>(f);                                 \
    }                                                                       \
    static FuncPtrT DynLoad() {                                             \
      static FuncPtrT f = LoadOrDie();                                      \
      return f;                                                             \
    }                                                                       \
    template <typename... Args>                                             \
    rocblas_status operator()(GpuExecutor* parent, Args... args) {          \
      ScopedActivateExecutorContext sac{parent};                            \
      return DynLoad()(args...);                                            \
    }                                                                       \
  } __name;                                                                 \
  const char* DynLoadShim__##__name::kName = #__name;

#endif

// The subset of rocBLAS entry points used by ROCmSolver.
ROCBLAS_WRAP(rocblas_create_handle)
ROCBLAS_WRAP(rocblas_destroy_handle)
ROCBLAS_WRAP(rocblas_set_stream)
ROCBLAS_WRAP(rocblas_dtrsm)
ROCBLAS_WRAP(rocblas_strsm)

}  // namespace wrap
99
100 struct ROCmSolverHandles {
ROCmSolverHandlestensorflow::__anon781cbc6c0111::ROCmSolverHandles101 explicit ROCmSolverHandles(GpuExecutor* parent, hipStream_t stream) {
102 parent_ = parent;
103 CHECK(wrap::rocblas_create_handle(parent_, &rocm_blas_handle) ==
104 rocblas_status_success)
105 << "Failed to create rocBlas instance.";
106 CHECK(wrap::rocblas_set_stream(parent_, rocm_blas_handle, stream) ==
107 rocblas_status_success)
108 << "Failed to set rocBlas stream.";
109 }
110
~ROCmSolverHandlestensorflow::__anon781cbc6c0111::ROCmSolverHandles111 ~ROCmSolverHandles() {
112 CHECK(wrap::rocblas_destroy_handle(parent_, rocm_blas_handle) ==
113 rocblas_status_success)
114 << "Failed to destroy cuBlas instance.";
115 }
116 GpuExecutor* parent_;
117 rocblas_handle rocm_blas_handle;
118 };
119
120 using HandleMap =
121 std::unordered_map<hipStream_t, std::unique_ptr<ROCmSolverHandles>>;
122
123 // Returns a singleton map used for storing initialized handles for each unique
124 // gpu stream.
GetHandleMapSingleton()125 HandleMap* GetHandleMapSingleton() {
126 static HandleMap* cm = new HandleMap;
127 return cm;
128 }
129
130 static mutex handle_map_mutex(LINKER_INITIALIZED);
131
132 } // namespace
133
// Looks up (or lazily creates) the rocBLAS handle associated with this
// kernel's GPU stream. The whole lookup runs under handle_map_mutex so
// concurrent kernels on the same new stream create the handles only once.
ROCmSolver::ROCmSolver(OpKernelContext* context) : context_(context) {
  mutex_lock lock(handle_map_mutex);
  GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(
      context->op_device_context()->stream()->parent()->implementation());
  // GpuStreamMemberHack() exposes the underlying hipStream_t of the
  // stream-executor stream; it must never be null here.
  const hipStream_t* hip_stream_ptr = CHECK_NOTNULL(
      reinterpret_cast<const hipStream_t*>(context->op_device_context()
                                               ->stream()
                                               ->implementation()
                                               ->GpuStreamMemberHack()));

  hip_stream_ = *hip_stream_ptr;
  HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
  auto it = handle_map->find(hip_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating ROCmSolver handles for stream " << hip_stream_;
    // Previously unseen Gpu stream. Initialize a set of Gpu solver library
    // handles for it.
    std::unique_ptr<ROCmSolverHandles> new_handles(
        new ROCmSolverHandles(gpu_executor, hip_stream_));
    it = handle_map->insert(std::make_pair(hip_stream_, std::move(new_handles)))
             .first;
  }
  rocm_blas_handle_ = it->second->rocm_blas_handle;
}
158
~ROCmSolver()159 ROCmSolver::~ROCmSolver() {
160 for (auto tensor_ref : scratch_tensor_refs_) {
161 tensor_ref.Unref();
162 }
163 }
164
// Converts a failing rocBLAS status into a tensorflow::Status and returns
// from the enclosing function; no-op on rocblas_status_success.
#define TF_RETURN_IF_ROCBLAS_ERROR(expr)                                  \
  do {                                                                    \
    auto status = (expr);                                                 \
    if (TF_PREDICT_FALSE(status != rocblas_status_success)) {             \
      return errors::Internal(__FILE__, ":", __LINE__,                    \
                              ": rocBlas call failed status = ", status); \
    }                                                                     \
  } while (0)

// Macro that specializes a solver method for all 4 standard
// numeric types.
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, s) m(double, d) m(std::complex<float>, c) m(std::complex<double>, z)
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, s) m(double, d)

// Expands to the wrapped rocBLAS function for a BLAS type prefix, e.g.
// BLAS_SOLVER_FN(trsm, s) -> wrap::rocblas_strsm.
#define BLAS_SOLVER_FN(method, type_prefix) \
  wrap::rocblas##_##type_prefix##method
182
183 // Allocates a temporary tensor. The ROCmSolver object maintains a
184 // TensorReference to the underlying Tensor to prevent it from being deallocated
185 // prematurely.
allocate_scoped_tensor(DataType type,const TensorShape & shape,Tensor * out_temp)186 Status ROCmSolver::allocate_scoped_tensor(DataType type,
187 const TensorShape& shape,
188 Tensor* out_temp) {
189 const Status status = context_->allocate_temp(type, shape, out_temp);
190 if (status.ok()) {
191 scratch_tensor_refs_.emplace_back(*out_temp);
192 }
193 return status;
194 }
195
forward_input_or_allocate_scoped_tensor(gtl::ArraySlice<int> candidate_input_indices,DataType type,const TensorShape & shape,Tensor * out_temp)196 Status ROCmSolver::forward_input_or_allocate_scoped_tensor(
197 gtl::ArraySlice<int> candidate_input_indices, DataType type,
198 const TensorShape& shape, Tensor* out_temp) {
199 const Status status = context_->forward_input_or_allocate_temp(
200 candidate_input_indices, type, shape, out_temp);
201 if (status.ok()) {
202 scratch_tensor_refs_.emplace_back(*out_temp);
203 }
204 return status;
205 }
206
// Shared implementation behind the ROCmSolver::Trsm specializations below:
// reinterprets the Scalar pointers as the matching rocBLAS scalar type and
// dispatches to the wrapped trsm entry point `solver`.
template <typename Scalar, typename SolverFnT>
static inline Status TrsmImpl(GpuExecutor* gpu_executor, SolverFnT solver,
                              rocblas_handle rocm_blas_handle,
                              rocblas_side side, rocblas_fill uplo,
                              rocblas_operation trans, rocblas_diagonal diag,
                              int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda, Scalar* B, int ldb) {
  // Serializes access to the shared, per-stream rocBLAS handle while the
  // call is enqueued.
  mutex_lock lock(handle_map_mutex);
  using ROCmScalar = typename ROCmComplexT<Scalar>::type;

  TF_RETURN_IF_ROCBLAS_ERROR(solver(gpu_executor, rocm_blas_handle, side, uplo,
                                    trans, diag, m, n,
                                    reinterpret_cast<const ROCmScalar*>(alpha),
                                    reinterpret_cast<const ROCmScalar*>(A), lda,
                                    reinterpret_cast<ROCmScalar*>(B), ldb));

  return Status::OK();
}
226
// Defines the ROCmSolver::Trsm<Scalar> specialization for one scalar type,
// binding it to the matching rocBLAS trsm entry point via BLAS_SOLVER_FN.
#define TRSM_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                 \
  Status ROCmSolver::Trsm<Scalar>(                                            \
      rocblas_side side, rocblas_fill uplo, rocblas_operation trans,          \
      rocblas_diagonal diag, int m, int n,                                    \
      const Scalar* alpha, /* host or device pointer */                       \
      const Scalar* A, int lda, Scalar* B, int ldb) {                         \
    GpuExecutor* gpu_executor = static_cast<GpuExecutor*>(                    \
        context_->op_device_context()->stream()->parent()->implementation()); \
    return TrsmImpl(gpu_executor, BLAS_SOLVER_FN(trsm, type_prefix),          \
                    rocm_blas_handle_, side, uplo, trans, diag, m, n, alpha,  \
                    A, lda, B, ldb);                                          \
  }

// Only the real-valued types are instantiated here (rocblas_strsm and
// rocblas_dtrsm are the only trsm wrappers declared above).
TF_CALL_LAPACK_TYPES_NO_COMPLEX(TRSM_INSTANCE);
242
243 } // namespace tensorflow
244
245 #endif // TENSORFLOW_USE_ROCM
246