1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================
15 */
16 #ifdef GOOGLE_CUDA
17 #include "tensorflow/core/util/cuda_solvers.h"
18
19 #include <chrono>
20 #include <complex>
21 #include <unordered_map>
22 #include <vector>
23
24 #include "third_party/gpus/cuda/include/cublas_v2.h"
25 #include "third_party/gpus/cuda/include/cusolverDn.h"
26 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/blocking_counter.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/platform/mutex.h"
34 #include "tensorflow/core/platform/stream_executor.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
37
38 // The CUDA cublas_api.h API contains const-correctness errors. Instead of
39 // casting away constness on our data, we instead reinterpret the CuBLAS
40 // functions as what they were clearly meant to be, and thus we can call
41 // the functions naturally.
42 //
43 // (The error is that input-only arrays are bound to parameter types
44 // "const T**" instead of the correct "const T* const*".)
45 extern "C" {
46 using getrs_S = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
47 const float* const*, int, const int*, float**,
48 int, int*, int);
49 using getrs_D = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
50 const double* const*, int, const int*, double**,
51 int, int*, int);
52 using getrs_C = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
53 const float2* const*, int, const int*, float2**,
54 int, int*, int);
55 using getrs_Z = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
56 const double2* const*, int, const int*,
57 double2**, int, int*, int);
58
59 using getri_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
60 const int*, float**, int, int*, int);
61 using getri_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
62 const int*, double**, int, int*, int);
63 using getri_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
64 const int*, float2**, int, int*, int);
65 using getri_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
66 const int*, double2**, int, int*, int);
67
68 using matinv_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
69 float**, int, int*, int);
70 using matinv_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
71 double**, int, int*, int);
72 using matinv_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
73 float2**, int, int*, int);
74 using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
75 double2**, int, int*, int);
76
77 using trsm_S = cublasStatus_t(cublasContext*, cublasSideMode_t,
78 cublasFillMode_t, cublasOperation_t,
79 cublasDiagType_t, int, int, const float*,
80 const float* const*, int, float* const*, int,
81 int);
82 using trsm_D = cublasStatus_t(cublasContext*, cublasSideMode_t,
83 cublasFillMode_t, cublasOperation_t,
84 cublasDiagType_t, int, int, const double*,
85 const double* const*, int, double* const*, int,
86 int);
87 using trsm_C = cublasStatus_t(cublasContext*, cublasSideMode_t,
88 cublasFillMode_t, cublasOperation_t,
89 cublasDiagType_t, int, int, const float2*,
90 const float2* const*, int, float2* const*, int,
91 int);
92 using trsm_Z = cublasStatus_t(cublasContext*, cublasSideMode_t,
93 cublasFillMode_t, cublasOperation_t,
94 cublasDiagType_t, int, int, const double2*,
95 const double2* const*, int, double2* const*, int,
96 int);
97 }
98
99 namespace tensorflow {
100 namespace {
101
102 using se::cuda::ScopedActivateExecutorContext;
103
CopyHostToDevice(OpKernelContext * context,void * dst,const void * src,uint64 bytes)104 inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
105 const void* src, uint64 bytes) {
106 auto stream = context->op_device_context()->stream();
107 se::DeviceMemoryBase wrapped_dst(dst);
108 return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
109 }
110
111 // A set of initialized handles to the underlying Cuda libraries used by
112 // CudaSolver. We maintain one such set of handles per unique stream.
113 struct CudaSolverHandles {
CudaSolverHandlestensorflow::__anon4e6550d80111::CudaSolverHandles114 explicit CudaSolverHandles(cudaStream_t stream) {
115 CHECK(cusolverDnCreate(&cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
116 << "Failed to create cuSolverDN instance.";
117 CHECK(cusolverDnSetStream(cusolver_dn_handle, stream) ==
118 CUSOLVER_STATUS_SUCCESS)
119 << "Failed to set cuSolverDN stream.";
120 CHECK(cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS)
121 << "Failed to create cuBlas instance.";
122 CHECK(cublasSetStream(cublas_handle, stream) == CUBLAS_STATUS_SUCCESS)
123 << "Failed to set cuBlas stream.";
124 }
125
~CudaSolverHandlestensorflow::__anon4e6550d80111::CudaSolverHandles126 ~CudaSolverHandles() {
127 CHECK(cublasDestroy(cublas_handle) == CUBLAS_STATUS_SUCCESS)
128 << "Failed to destroy cuBlas instance.";
129 CHECK(cusolverDnDestroy(cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
130 << "Failed to destroy cuSolverDN instance.";
131 }
132 cublasHandle_t cublas_handle;
133 cusolverDnHandle_t cusolver_dn_handle;
134 };
135
136 static mutex handle_map_mutex(LINKER_INITIALIZED);
137
138 using HandleMap =
139 std::unordered_map<cudaStream_t, std::unique_ptr<CudaSolverHandles>>;
140
141 // Returns a singleton map used for storing initialized handles for each unique
142 // cuda stream.
GetHandleMapSingleton()143 HandleMap* GetHandleMapSingleton() {
144 static HandleMap* cm = new HandleMap;
145 return cm;
146 }
147
148 } // namespace
149
// Evaluates `expr` (a cuSolverDN call) once. If the result is not
// CUSOLVER_STATUS_SUCCESS, returns an Internal error from the enclosing
// function that records the file, line and numeric status.
//
// The macro-local variable uses a distinctive name so that an `expr` which
// mentions a caller variable called `status` does not collide with it.
#define TF_RETURN_IF_CUSOLVER_ERROR(expr)                                   \
  do {                                                                      \
    auto tf_cusolver_status = (expr);                                       \
    if (TF_PREDICT_FALSE(tf_cusolver_status != CUSOLVER_STATUS_SUCCESS)) {  \
      return errors::Internal(__FILE__, ":", __LINE__,                      \
                              ": cuSolverDN call failed with status = ",    \
                              tf_cusolver_status);                          \
    }                                                                       \
  } while (0)

// Evaluates `expr` (a cuBLAS call) once. If the result is not
// CUBLAS_STATUS_SUCCESS, returns an Internal error from the enclosing function
// that records the file, line and numeric status.
#define TF_RETURN_IF_CUBLAS_ERROR(expr)                                 \
  do {                                                                  \
    auto tf_cublas_status = (expr);                                     \
    if (TF_PREDICT_FALSE(tf_cublas_status != CUBLAS_STATUS_SUCCESS)) {  \
      return errors::Internal(__FILE__, ":", __LINE__,                  \
                              ": cuBlas call failed status = ",         \
                              tf_cublas_status);                        \
    }                                                                   \
  } while (0)
168
CudaSolver(OpKernelContext * context)169 CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) {
170 mutex_lock lock(handle_map_mutex);
171 const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL(
172 reinterpret_cast<const cudaStream_t*>(context->op_device_context()
173 ->stream()
174 ->implementation()
175 ->GpuStreamMemberHack()));
176 cuda_stream_ = *cu_stream_ptr;
177 HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
178 auto it = handle_map->find(cuda_stream_);
179 if (it == handle_map->end()) {
180 LOG(INFO) << "Creating CudaSolver handles for stream " << cuda_stream_;
181 // Previously unseen Cuda stream. Initialize a set of Cuda solver library
182 // handles for it.
183 std::unique_ptr<CudaSolverHandles> new_handles(
184 new CudaSolverHandles(cuda_stream_));
185 it =
186 handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
187 .first;
188 }
189 cusolver_dn_handle_ = it->second->cusolver_dn_handle;
190 cublas_handle_ = it->second->cublas_handle;
191 }
192
~CudaSolver()193 CudaSolver::~CudaSolver() {
194 for (const auto& tensor_ref : scratch_tensor_refs_) {
195 tensor_ref.Unref();
196 }
197 }
198
199 // static
CheckLapackInfoAndDeleteSolverAsync(std::unique_ptr<CudaSolver> solver,const std::vector<DeviceLapackInfo> & dev_lapack_infos,std::function<void (const Status &,const std::vector<HostLapackInfo> &)> info_checker_callback)200 void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
201 std::unique_ptr<CudaSolver> solver,
202 const std::vector<DeviceLapackInfo>& dev_lapack_infos,
203 std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
204 info_checker_callback) {
205 CHECK(info_checker_callback != nullptr);
206 std::vector<HostLapackInfo> host_lapack_infos;
207 if (dev_lapack_infos.empty()) {
208 info_checker_callback(Status::OK(), host_lapack_infos);
209 return;
210 }
211
212 // Launch memcpys to copy info back from the device to the host.
213 for (const auto& dev_lapack_info : dev_lapack_infos) {
214 bool success = true;
215 auto host_copy = dev_lapack_info.CopyToHost(&success);
216 OP_REQUIRES(
217 solver->context(), success,
218 errors::Internal(
219 "Failed to launch copy of dev_lapack_info to host, debug_info = ",
220 dev_lapack_info.debug_info()));
221 host_lapack_infos.push_back(std::move(host_copy));
222 }
223
224 // This callback checks that all batch items in all calls were processed
225 // successfully and passes status to the info_checker_callback accordingly.
226 auto* stream = solver->context()->op_device_context()->stream();
227 auto wrapped_info_checker_callback =
228 [stream](
229 CudaSolver* solver,
230 std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
231 info_checker_callback,
232 std::vector<HostLapackInfo> host_lapack_infos) {
233 ScopedActivateExecutorContext scoped_activation{stream->parent()};
234 Status status;
235 for (const auto& host_lapack_info : host_lapack_infos) {
236 for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
237 const int info_value = host_lapack_info(i);
238 if (info_value != 0) {
239 status = errors::InvalidArgument(
240 "Got info = ", info_value, " for batch index ", i,
241 ", expected info = 0. Debug_info = ",
242 host_lapack_info.debug_info());
243 }
244 }
245 if (!status.ok()) {
246 break;
247 }
248 }
249 // Delete solver to release temp tensor refs.
250 delete solver;
251
252 // Delegate further error checking to provided functor.
253 info_checker_callback(status, host_lapack_infos);
254 };
255 // Note: An std::function cannot have unique_ptr arguments (it must be copy
256 // constructible and therefore so must its arguments). Therefore, we release
257 // solver into a raw pointer to be deleted at the end of
258 // wrapped_info_checker_callback.
259 // Release ownership of solver. It will be deleted in the cb callback.
260 auto solver_raw_ptr = solver.release();
261 auto cb =
262 std::bind(wrapped_info_checker_callback, solver_raw_ptr,
263 std::move(info_checker_callback), std::move(host_lapack_infos));
264
265 solver_raw_ptr->context()
266 ->device()
267 ->tensorflow_gpu_device_info()
268 ->event_mgr->ThenExecute(stream, std::move(cb));
269 }
270
271 // static
CheckLapackInfoAndDeleteSolverAsync(std::unique_ptr<CudaSolver> solver,const std::vector<DeviceLapackInfo> & dev_lapack_info,AsyncOpKernel::DoneCallback done)272 void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
273 std::unique_ptr<CudaSolver> solver,
274 const std::vector<DeviceLapackInfo>& dev_lapack_info,
275 AsyncOpKernel::DoneCallback done) {
276 OpKernelContext* context = solver->context();
277 auto wrapped_done = [context, done](
278 const Status& status,
279 const std::vector<HostLapackInfo>& /* unused */) {
280 if (done != nullptr) {
281 OP_REQUIRES_OK_ASYNC(context, status, done);
282 done();
283 } else {
284 OP_REQUIRES_OK(context, status);
285 }
286 };
287 CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
288 wrapped_done);
289 }
290
291 // Allocates a temporary tensor. The CudaSolver object maintains a
292 // TensorReference to the underlying Tensor to prevent it from being deallocated
293 // prematurely.
allocate_scoped_tensor(DataType type,const TensorShape & shape,Tensor * out_temp)294 Status CudaSolver::allocate_scoped_tensor(DataType type,
295 const TensorShape& shape,
296 Tensor* out_temp) {
297 const Status status = context_->allocate_temp(type, shape, out_temp);
298 if (status.ok()) {
299 scratch_tensor_refs_.emplace_back(*out_temp);
300 }
301 return status;
302 }
303
forward_input_or_allocate_scoped_tensor(gtl::ArraySlice<int> candidate_input_indices,DataType type,const TensorShape & shape,Tensor * out_temp)304 Status CudaSolver::forward_input_or_allocate_scoped_tensor(
305 gtl::ArraySlice<int> candidate_input_indices, DataType type,
306 const TensorShape& shape, Tensor* out_temp) {
307 const Status status = context_->forward_input_or_allocate_temp(
308 candidate_input_indices, type, shape, out_temp);
309 if (status.ok()) {
310 scratch_tensor_refs_.emplace_back(*out_temp);
311 }
312 return status;
313 }
314
// Invokes `m(Scalar, type_prefix)` once per standard LAPACK element type:
//   S = float, D = double, C = std::complex<float>, Z = std::complex<double>.
// Used to stamp out solver method specializations for every type.
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S)                   \
  m(double, D)                  \
  m(std::complex<float>, C)     \
  m(std::complex<double>, Z)
// Same as TF_CALL_LAPACK_TYPES, restricted to the two real types.
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) \
  m(float, S)                              \
  m(double, D)

// cuSolverDN name builders, e.g. for (potrf, S):
//   DN_SOLVER_FN   -> cusolverDnSpotrf
//   DN_SOLVER_NAME -> "cusolverDnSpotrf"
//   DN_BUFSIZE_FN  -> cusolverDnSpotrf_bufferSize
#define DN_SOLVER_FN(op, prefix) cusolverDn##prefix##op
#define DN_SOLVER_NAME(op, prefix) "cusolverDn" #prefix #op
#define DN_BUFSIZE_FN(op, prefix) cusolverDn##prefix##op##_bufferSize

// cuBLAS name builders, e.g. for (geam, D):
//   BLAS_SOLVER_FN   -> cublasDgeam
//   BLAS_SOLVER_NAME -> "cublasDgeam"
#define BLAS_SOLVER_FN(op, prefix) cublas##prefix##op
#define BLAS_SOLVER_NAME(op, prefix) "cublas" #prefix #op
330
331 //=============================================================================
332 // Wrappers of cuSolverDN computational methods begin here.
333 //
334 // WARNING to implementers: The function signatures listed in the online docs
335 // are sometimes inaccurate, e.g., are missing 'const' on pointers
336 // to immutable arguments, while the actual headers have them as expected.
337 // Check the actual declarations in the cusolver_api.h header file.
338 //
// NOTE: The cuSolver functions called below appear not to be threadsafe,
// so we put a global lock around the calls. Since these functions only put a
// kernel on the shared stream, it is not a big performance hit.
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
343 //=============================================================================
344
345 template <typename Scalar, typename SolverFnT>
GeamImpl(SolverFnT solver,cublasHandle_t cublas_handle,cublasOperation_t transa,cublasOperation_t transb,int m,int n,const Scalar * alpha,const Scalar * A,int lda,const Scalar * beta,const Scalar * B,int ldb,Scalar * C,int ldc)346 static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
347 cublasOperation_t transa,
348 cublasOperation_t transb, int m, int n,
349 const Scalar* alpha, /* host or device pointer */
350 const Scalar* A, int lda,
351 const Scalar* beta, /* host or device pointer */
352 const Scalar* B, int ldb, Scalar* C, int ldc) {
353 mutex_lock lock(handle_map_mutex);
354 using CudaScalar = typename CUDAComplexT<Scalar>::type;
355 TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
356 reinterpret_cast<const CudaScalar*>(alpha),
357 reinterpret_cast<const CudaScalar*>(A), lda,
358 reinterpret_cast<const CudaScalar*>(beta),
359 reinterpret_cast<const CudaScalar*>(B), ldb,
360 reinterpret_cast<CudaScalar*>(C), ldc));
361 return Status::OK();
362 }
363
364 #define GEAM_INSTANCE(Scalar, type_prefix) \
365 template <> \
366 Status CudaSolver::Geam<Scalar>( \
367 cublasOperation_t transa, cublasOperation_t transb, int m, int n, \
368 const Scalar* alpha, /* host or device pointer */ \
369 const Scalar* A, int lda, \
370 const Scalar* beta, /* host or device pointer */ \
371 const Scalar* B, int ldb, Scalar* C, int ldc) const { \
372 return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), cublas_handle_, transa, \
373 transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); \
374 }
375
376 TF_CALL_LAPACK_TYPES(GEAM_INSTANCE);
377
378 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
PotrfImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cublasFillMode_t uplo,int n,Scalar * A,int lda,int * dev_lapack_info)379 static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
380 CudaSolver* cuda_solver,
381 OpKernelContext* context,
382 cusolverDnHandle_t cusolver_dn_handle,
383 cublasFillMode_t uplo, int n, Scalar* A, int lda,
384 int* dev_lapack_info) {
385 mutex_lock lock(handle_map_mutex);
386 /* Get amount of workspace memory required. */
387 int lwork;
388 TF_RETURN_IF_CUSOLVER_ERROR(
389 bufsize(cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, &lwork));
390 /* Allocate device memory for workspace. */
391 auto dev_workspace =
392 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
393 /* Launch the solver kernel. */
394 TF_RETURN_IF_CUSOLVER_ERROR(solver(
395 cusolver_dn_handle, uplo, n, CUDAComplex(A), lda,
396 CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
397 return Status::OK();
398 }
399
400 #define POTRF_INSTANCE(Scalar, type_prefix) \
401 template <> \
402 Status CudaSolver::Potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A, \
403 int lda, int* dev_lapack_info) { \
404 return PotrfImpl(DN_BUFSIZE_FN(potrf, type_prefix), \
405 DN_SOLVER_FN(potrf, type_prefix), this, context_, \
406 cusolver_dn_handle_, uplo, n, A, lda, dev_lapack_info); \
407 }
408
409 TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);
410
#if CUDA_VERSION >= 9020
// Runs a cuSolverDN <t>potrfBatched routine over `batch_size` matrices.
// `host_a_dev_ptrs` is a host-side array of device pointers; it is copied into
// device scratch first because the solver is handed a device pointer array.
template <typename Scalar, typename SolverFnT>
static inline Status PotrfBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cusolverDnHandle_t cusolver_dn_handle, cublasFillMode_t uplo, int n,
    const Scalar* const host_a_dev_ptrs[], int lda,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  const auto ptr_array_bytes = sizeof(CudaScalar*) * batch_size;
  ScratchSpace<uint8> dev_ptr_array = cuda_solver->GetScratchSpace<uint8>(
      ptr_array_bytes, "", /* on_host */ false);
  const bool copied =
      CopyHostToDevice(context, dev_ptr_array.mutable_data() /* dest */,
                       host_a_dev_ptrs /* source */, dev_ptr_array.bytes());
  if (!copied) {
    return errors::Internal("PotrfBatched: failed to copy pointers to device");
  }
  CudaScalar** const dev_a_ptrs =
      reinterpret_cast<CudaScalar**>(dev_ptr_array.mutable_data());
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, uplo, n, dev_a_ptrs,
                                     lda, dev_lapack_info->mutable_data(),
                                     batch_size));
  return Status::OK();
}

// Specializes CudaSolver::PotrfBatched for element type T using
// cusolverDn<prefix>potrfBatched.
#define POTRF_BATCHED_INSTANCE(T, prefix)                              \
  template <>                                                          \
  Status CudaSolver::PotrfBatched(                                     \
      cublasFillMode_t uplo, int n, const T* const host_a_dev_ptrs[],  \
      int lda, DeviceLapackInfo* dev_lapack_info, int batch_size) {    \
    return PotrfBatchedImpl(DN_SOLVER_FN(potrfBatched, prefix), this,  \
                            context_, cusolver_dn_handle_, uplo, n,    \
                            host_a_dev_ptrs, lda, dev_lapack_info,     \
                            batch_size);                               \
  }

TF_CALL_LAPACK_TYPES(POTRF_BATCHED_INSTANCE);
#endif  // CUDA_VERSION >= 9020
447
448 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
GetrfImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,int m,int n,Scalar * A,int lda,int * dev_pivots,int * dev_lapack_info)449 static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
450 CudaSolver* cuda_solver,
451 OpKernelContext* context,
452 cusolverDnHandle_t cusolver_dn_handle, int m,
453 int n, Scalar* A, int lda, int* dev_pivots,
454 int* dev_lapack_info) {
455 mutex_lock lock(handle_map_mutex);
456 /* Get amount of workspace memory required. */
457 int lwork;
458 TF_RETURN_IF_CUSOLVER_ERROR(
459 bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
460 /* Allocate device memory for workspace. */
461 auto dev_workspace =
462 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
463 /* Launch the solver kernel. */
464 TF_RETURN_IF_CUSOLVER_ERROR(solver(
465 cusolver_dn_handle, m, n, CUDAComplex(A), lda,
466 CUDAComplex(dev_workspace.mutable_data()), dev_pivots, dev_lapack_info));
467 return Status::OK();
468 }
469
470 #define GETRF_INSTANCE(Scalar, type_prefix) \
471 template <> \
472 Status CudaSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda, \
473 int* dev_pivots, int* dev_lapack_info) { \
474 return GetrfImpl(DN_BUFSIZE_FN(getrf, type_prefix), \
475 DN_SOLVER_FN(getrf, type_prefix), this, context_, \
476 cusolver_dn_handle_, m, n, A, lda, dev_pivots, \
477 dev_lapack_info); \
478 }
479
480 TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);
481
482 template <typename Scalar, typename SolverFnT>
GetrsImpl(SolverFnT solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cublasOperation_t trans,int n,int nrhs,const Scalar * A,int lda,const int * pivots,Scalar * B,int ldb,int * dev_lapack_info)483 static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
484 cusolverDnHandle_t cusolver_dn_handle,
485 cublasOperation_t trans, int n, int nrhs,
486 const Scalar* A, int lda, const int* pivots,
487 Scalar* B, int ldb, int* dev_lapack_info) {
488 mutex_lock lock(handle_map_mutex);
489 /* Launch the solver kernel. */
490 TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
491 CUDAComplex(A), lda, pivots,
492 CUDAComplex(B), ldb, dev_lapack_info));
493 return Status::OK();
494 }
495
496 #define GETRS_INSTANCE(Scalar, type_prefix) \
497 template <> \
498 Status CudaSolver::Getrs<Scalar>( \
499 cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda, \
500 const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const { \
501 return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_, \
502 cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
503 ldb, dev_lapack_info); \
504 }
505
506 TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);
507
508 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
GeqrfImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,int m,int n,Scalar * A,int lda,Scalar * tau,int * dev_lapack_info)509 static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
510 CudaSolver* cuda_solver,
511 OpKernelContext* context,
512 cusolverDnHandle_t cusolver_dn_handle, int m,
513 int n, Scalar* A, int lda, Scalar* tau,
514 int* dev_lapack_info) {
515 mutex_lock lock(handle_map_mutex);
516 /* Get amount of workspace memory required. */
517 int lwork;
518 TF_RETURN_IF_CUSOLVER_ERROR(
519 bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
520 /* Allocate device memory for workspace. */
521 auto dev_workspace =
522 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
523 /* Launch the solver kernel. */
524 TF_RETURN_IF_CUSOLVER_ERROR(solver(
525 cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
526 CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
527 return Status::OK();
528 }
529
530 #define GEQRF_INSTANCE(Scalar, type_prefix) \
531 template <> \
532 Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda, \
533 Scalar* tau, int* dev_lapack_info) { \
534 return GeqrfImpl(DN_BUFSIZE_FN(geqrf, type_prefix), \
535 DN_SOLVER_FN(geqrf, type_prefix), this, context_, \
536 cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
537 }
538
539 TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);
540
541 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
UnmqrImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cublasSideMode_t side,cublasOperation_t trans,int m,int n,int k,const Scalar * dev_a,int lda,const Scalar * dev_tau,Scalar * dev_c,int ldc,int * dev_lapack_info)542 static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
543 CudaSolver* cuda_solver,
544 OpKernelContext* context,
545 cusolverDnHandle_t cusolver_dn_handle,
546 cublasSideMode_t side, cublasOperation_t trans,
547 int m, int n, int k, const Scalar* dev_a,
548 int lda, const Scalar* dev_tau, Scalar* dev_c,
549 int ldc, int* dev_lapack_info) {
550 mutex_lock lock(handle_map_mutex);
551 /* Get amount of workspace memory required. */
552 int lwork;
553 TF_RETURN_IF_CUSOLVER_ERROR(
554 bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
555 CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
556 /* Allocate device memory for workspace. */
557 auto dev_workspace =
558 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
559 /* Launch the solver kernel. */
560 TF_RETURN_IF_CUSOLVER_ERROR(solver(
561 cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
562 CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
563 CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
564 return Status::OK();
565 }
566
567 // Unfortunately the LAPACK function name differs for the real and complex case
568 // (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
569 // one separately.
570 #define UNMQR_INSTANCE(Scalar, function_prefix, type_prefix) \
571 template <> \
572 Status CudaSolver::Unmqr(cublasSideMode_t side, cublasOperation_t trans, \
573 int m, int n, int k, const Scalar* dev_a, int lda, \
574 const Scalar* dev_tau, Scalar* dev_c, int ldc, \
575 int* dev_lapack_info) { \
576 return UnmqrImpl(DN_BUFSIZE_FN(function_prefix##mqr, type_prefix), \
577 DN_SOLVER_FN(function_prefix##mqr, type_prefix), this, \
578 context_, cusolver_dn_handle_, side, trans, m, n, k, \
579 dev_a, lda, dev_tau, dev_c, ldc, dev_lapack_info); \
580 }
581
582 UNMQR_INSTANCE(float, or, S);
583 UNMQR_INSTANCE(double, or, D);
584 UNMQR_INSTANCE(complex64, un, C);
585 UNMQR_INSTANCE(complex128, un, Z);
586
587 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
UngqrImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,int m,int n,int k,Scalar * dev_a,int lda,const Scalar * dev_tau,int * dev_lapack_info)588 static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
589 CudaSolver* cuda_solver,
590 OpKernelContext* context,
591 cusolverDnHandle_t cusolver_dn_handle, int m,
592 int n, int k, Scalar* dev_a, int lda,
593 const Scalar* dev_tau, int* dev_lapack_info) {
594 mutex_lock lock(handle_map_mutex);
595 /* Get amount of workspace memory required. */
596 int lwork;
597 TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
598 CUDAComplex(dev_a), lda,
599 CUDAComplex(dev_tau), &lwork));
600 /* Allocate device memory for workspace. */
601 auto dev_workspace =
602 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
603 /* Launch the solver kernel. */
604 TF_RETURN_IF_CUSOLVER_ERROR(
605 solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
606 CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
607 lwork, dev_lapack_info));
608 return Status::OK();
609 }
610
611 #define UNGQR_INSTANCE(Scalar, function_prefix, type_prefix) \
612 template <> \
613 Status CudaSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda, \
614 const Scalar* dev_tau, int* dev_lapack_info) { \
615 return UngqrImpl(DN_BUFSIZE_FN(function_prefix##gqr, type_prefix), \
616 DN_SOLVER_FN(function_prefix##gqr, type_prefix), this, \
617 context_, cusolver_dn_handle_, m, n, k, dev_a, lda, \
618 dev_tau, dev_lapack_info); \
619 }
620
621 UNGQR_INSTANCE(float, or, S);
622 UNGQR_INSTANCE(double, or, D);
623 UNGQR_INSTANCE(complex64, un, C);
624 UNGQR_INSTANCE(complex128, un, Z);
625
626 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
HeevdImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,Scalar * dev_A,int lda,typename Eigen::NumTraits<Scalar>::Real * dev_W,int * dev_lapack_info)627 static inline Status HeevdImpl(BufSizeFnT bufsize, SolverFnT solver,
628 CudaSolver* cuda_solver,
629 OpKernelContext* context,
630 cusolverDnHandle_t cusolver_dn_handle,
631 cusolverEigMode_t jobz, cublasFillMode_t uplo,
632 int n, Scalar* dev_A, int lda,
633 typename Eigen::NumTraits<Scalar>::Real* dev_W,
634 int* dev_lapack_info) {
635 mutex_lock lock(handle_map_mutex);
636 /* Get amount of workspace memory required. */
637 int lwork;
638 TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, jobz, uplo, n,
639 CUDAComplex(dev_A), lda,
640 CUDAComplex(dev_W), &lwork));
641 /* Allocate device memory for workspace. */
642 auto dev_workspace =
643 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
644 /* Launch the solver kernel. */
645 TF_RETURN_IF_CUSOLVER_ERROR(
646 solver(cusolver_dn_handle, jobz, uplo, n, CUDAComplex(dev_A), lda,
647 CUDAComplex(dev_W), CUDAComplex(dev_workspace.mutable_data()),
648 lwork, dev_lapack_info));
649 return Status::OK();
650 }
651
652 #define HEEVD_INSTANCE(Scalar, function_prefix, type_prefix) \
653 template <> \
654 Status CudaSolver::Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, \
655 int n, Scalar* dev_A, int lda, \
656 typename Eigen::NumTraits<Scalar>::Real* dev_W, \
657 int* dev_lapack_info) { \
658 return HeevdImpl(DN_BUFSIZE_FN(function_prefix##evd, type_prefix), \
659 DN_SOLVER_FN(function_prefix##evd, type_prefix), this, \
660 context_, cusolver_dn_handle_, jobz, uplo, n, dev_A, lda, \
661 dev_W, dev_lapack_info); \
662 }
663
664 HEEVD_INSTANCE(float, sy, S);
665 HEEVD_INSTANCE(double, sy, D);
666 HEEVD_INSTANCE(complex64, he, C);
667 HEEVD_INSTANCE(complex128, he, Z);
668
669 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
GesvdImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,signed char jobu,signed char jobvt,int m,int n,Scalar * A,int lda,Scalar * S,Scalar * U,int ldu,Scalar * VT,int ldvt,int * dev_lapack_info)670 static inline Status GesvdImpl(
671 BufSizeFnT bufsize, SolverFnT solver, CudaSolver* cuda_solver,
672 OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
673 signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
674 Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
675 mutex_lock lock(handle_map_mutex);
676 /* Get amount of workspace memory required. */
677 int lwork;
678 TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
679 /* Allocate device memory for workspace. */
680 auto dev_workspace =
681 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
682 TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
683 CUDAComplex(A), lda, S, CUDAComplex(U),
684 ldu, CUDAComplex(VT), ldvt,
685 CUDAComplex(dev_workspace.mutable_data()),
686 lwork, nullptr, dev_lapack_info));
687 return Status::OK();
688 }
689
690 #define GESVD_INSTANCE(Scalar, type_prefix) \
691 template <> \
692 Status CudaSolver::Gesvd<Scalar>( \
693 signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A, \
694 int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT, \
695 int ldvt, int* dev_lapack_info) { \
696 return GesvdImpl(DN_BUFSIZE_FN(gesvd, type_prefix), \
697 DN_SOLVER_FN(gesvd, type_prefix), this, context_, \
698 cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda, \
699 dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info); \
700 }
701
702 TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);
703
704 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
GesvdjBatchedImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cusolverEigMode_t jobz,int m,int n,Scalar * A,int lda,Scalar * S,Scalar * U,int ldu,Scalar * V,int ldv,int * dev_lapack_info,int batch_size)705 static inline Status GesvdjBatchedImpl(BufSizeFnT bufsize, SolverFnT solver,
706 CudaSolver* cuda_solver,
707 OpKernelContext* context,
708 cusolverDnHandle_t cusolver_dn_handle,
709 cusolverEigMode_t jobz, int m, int n,
710 Scalar* A, int lda, Scalar* S, Scalar* U,
711 int ldu, Scalar* V, int ldv,
712 int* dev_lapack_info, int batch_size) {
713 mutex_lock lock(handle_map_mutex);
714 /* Get amount of workspace memory required. */
715 int lwork;
716 /* Default parameters for gesvdj and gesvdjBatched. */
717 gesvdjInfo_t svdj_info;
718 TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnCreateGesvdjInfo(&svdj_info));
719 TF_RETURN_IF_CUSOLVER_ERROR(bufsize(
720 cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
721 ldu, CUDAComplex(V), ldv, &lwork, svdj_info, batch_size));
722 /* Allocate device memory for workspace. */
723 auto dev_workspace =
724 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
725 TF_RETURN_IF_CUSOLVER_ERROR(solver(
726 cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
727 ldu, CUDAComplex(V), ldv, CUDAComplex(dev_workspace.mutable_data()),
728 lwork, dev_lapack_info, svdj_info, batch_size));
729 TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnDestroyGesvdjInfo(svdj_info));
730 return Status::OK();
731 }
732
733 #define GESVDJBATCHED_INSTANCE(Scalar, type_prefix) \
734 template <> \
735 Status CudaSolver::GesvdjBatched<Scalar>( \
736 cusolverEigMode_t jobz, int m, int n, Scalar* dev_A, int lda, \
737 Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_V, int ldv, \
738 int* dev_lapack_info, int batch_size) { \
739 return GesvdjBatchedImpl(DN_BUFSIZE_FN(gesvdjBatched, type_prefix), \
740 DN_SOLVER_FN(gesvdjBatched, type_prefix), this, \
741 context_, cusolver_dn_handle_, jobz, m, n, dev_A, \
742 lda, dev_S, dev_U, ldu, dev_V, ldv, \
743 dev_lapack_info, batch_size); \
744 }
745
746 TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVDJBATCHED_INSTANCE);
747
748 //=============================================================================
749 // Wrappers of cuBlas computational methods begin here.
750 //
751 // WARNING to implementers: The function signatures listed in the online docs
752 // are sometimes inaccurate, e.g., are missing 'const' on pointers
753 // to immutable arguments, while the actual headers have them as expected.
754 // Check the actual declarations in the cublas_api.h header file.
755 //=============================================================================
756 template <typename Scalar, typename SolverFnT>
GetrfBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,int n,const Scalar * const host_a_dev_ptrs[],int lda,int * dev_pivots,DeviceLapackInfo * dev_lapack_info,int batch_size)757 static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
758 OpKernelContext* context,
759 cublasHandle_t cublas_handle, int n,
760 const Scalar* const host_a_dev_ptrs[],
761 int lda, int* dev_pivots,
762 DeviceLapackInfo* dev_lapack_info,
763 int batch_size) {
764 mutex_lock lock(handle_map_mutex);
765 using CudaScalar = typename CUDAComplexT<Scalar>::type;
766 ScratchSpace<uint8> dev_a_dev_ptrs =
767 cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
768 /* on_host */ false);
769 if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
770 host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
771 return errors::Internal("GetrfBatched: failed to copy pointers to device");
772 }
773 TF_RETURN_IF_CUBLAS_ERROR(
774 solver(cublas_handle, n,
775 reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
776 dev_pivots, dev_lapack_info->mutable_data(), batch_size));
777 return Status::OK();
778 }
779
780 #define GETRF_BATCHED_INSTANCE(Scalar, type_prefix) \
781 template <> \
782 Status CudaSolver::GetrfBatched( \
783 int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots, \
784 DeviceLapackInfo* dev_lapack_info, int batch_size) { \
785 return GetrfBatchedImpl(BLAS_SOLVER_FN(getrfBatched, type_prefix), this, \
786 context_, cublas_handle_, n, host_a_dev_ptrs, lda, \
787 dev_pivots, dev_lapack_info, batch_size); \
788 }
789
790 TF_CALL_LAPACK_TYPES(GETRF_BATCHED_INSTANCE);
791
792 template <typename Scalar, typename SolverFnT>
GetrsBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,cublasOperation_t trans,int n,int nrhs,const Scalar * const host_a_dev_ptrs[],int lda,const int * dev_pivots,const Scalar * const host_b_dev_ptrs[],int ldb,int * host_lapack_info,int batch_size)793 static inline Status GetrsBatchedImpl(
794 SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
795 cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
796 const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
797 const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
798 int batch_size) {
799 mutex_lock lock(handle_map_mutex);
800 using CudaScalar = typename CUDAComplexT<Scalar>::type;
801 ScratchSpace<uint8> dev_a_dev_ptrs =
802 cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
803 /* on_host */ false);
804 ScratchSpace<uint8> dev_b_dev_ptrs =
805 cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
806 /* on_host */ false);
807 if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
808 host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
809 return errors::Internal("GetrsBatched: failed to copy pointers to device");
810 }
811 if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
812 host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
813 return errors::Internal("GetrsBatched: failed to copy pointers to device");
814 }
815 TF_RETURN_IF_CUBLAS_ERROR(solver(
816 cublas_handle, trans, n, nrhs,
817 reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
818 dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
819 ldb, host_lapack_info, batch_size));
820 return Status::OK();
821 }
822
823 #define GETRS_BATCHED_INSTANCE(Scalar, type_prefix) \
824 template <> \
825 Status CudaSolver::GetrsBatched( \
826 cublasOperation_t trans, int n, int nrhs, \
827 const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots, \
828 const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info, \
829 int batch_size) { \
830 return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>( \
831 BLAS_SOLVER_FN(getrsBatched, type_prefix)), \
832 this, context_, cublas_handle_, trans, n, nrhs, \
833 host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
834 ldb, host_lapack_info, batch_size); \
835 }
836
837 TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);
838
839 template <typename Scalar, typename SolverFnT>
GetriBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,int n,const Scalar * const host_a_dev_ptrs[],int lda,const int * dev_pivots,const Scalar * const host_a_inv_dev_ptrs[],int ldainv,DeviceLapackInfo * dev_lapack_info,int batch_size)840 static inline Status GetriBatchedImpl(
841 SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
842 cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
843 int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
844 int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
845 mutex_lock lock(handle_map_mutex);
846 using CudaScalar = typename CUDAComplexT<Scalar>::type;
847 ScratchSpace<uint8> dev_a_dev_ptrs =
848 cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
849 /* on_host */ false);
850 ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
851 sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
852 if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
853 host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
854 !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
855 host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
856 return errors::Internal("GetriBatched: failed to copy pointers to device");
857 }
858 TF_RETURN_IF_CUBLAS_ERROR(
859 solver(cublas_handle, n,
860 reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
861 lda, dev_pivots,
862 reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()),
863 ldainv, dev_lapack_info->mutable_data(), batch_size));
864 return Status::OK();
865 }
866
867 #define GETRI_BATCHED_INSTANCE(Scalar, type_prefix) \
868 template <> \
869 Status CudaSolver::GetriBatched( \
870 int n, const Scalar* const host_a_dev_ptrs[], int lda, \
871 const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[], \
872 int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) { \
873 return GetriBatchedImpl( \
874 reinterpret_cast<getri_##type_prefix*>( \
875 BLAS_SOLVER_FN(getriBatched, type_prefix)), \
876 this, context_, cublas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \
877 host_a_inv_dev_ptrs, ldainv, dev_lapack_info, batch_size); \
878 }
879
880 TF_CALL_LAPACK_TYPES(GETRI_BATCHED_INSTANCE);
881
882 template <typename Scalar, typename SolverFnT>
MatInvBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,int n,const Scalar * const host_a_dev_ptrs[],int lda,const Scalar * const host_a_inv_dev_ptrs[],int ldainv,DeviceLapackInfo * dev_lapack_info,int batch_size)883 static inline Status MatInvBatchedImpl(
884 SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
885 cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
886 int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
887 DeviceLapackInfo* dev_lapack_info, int batch_size) {
888 mutex_lock lock(handle_map_mutex);
889 using CudaScalar = typename CUDAComplexT<Scalar>::type;
890 ScratchSpace<uint8> dev_a_dev_ptrs =
891 cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
892 /* on_host */ false);
893 ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
894 sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
895 if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
896 host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
897 !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
898 host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
899 return errors::Internal("MatInvBatched: failed to copy pointers to device");
900 }
901 TF_RETURN_IF_CUBLAS_ERROR(solver(
902 cublas_handle, n,
903 reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
904 reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()), ldainv,
905 dev_lapack_info->mutable_data(), batch_size));
906 return Status::OK();
907 }
908
909 #define MATINV_BATCHED_INSTANCE(Scalar, type_prefix) \
910 template <> \
911 Status CudaSolver::MatInvBatched( \
912 int n, const Scalar* const host_a_dev_ptrs[], int lda, \
913 const Scalar* const host_a_inv_dev_ptrs[], int ldainv, \
914 DeviceLapackInfo* dev_lapack_info, int batch_size) { \
915 return MatInvBatchedImpl(reinterpret_cast<matinv_##type_prefix*>( \
916 BLAS_SOLVER_FN(matinvBatched, type_prefix)), \
917 this, context_, cublas_handle_, n, \
918 host_a_dev_ptrs, lda, host_a_inv_dev_ptrs, \
919 ldainv, dev_lapack_info, batch_size); \
920 }
921
922 TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);
923
924 template <typename Scalar, typename SolverFnT>
TrsmImpl(SolverFnT solver,cublasHandle_t cublas_handle,cublasSideMode_t side,cublasFillMode_t uplo,cublasOperation_t trans,cublasDiagType_t diag,int m,int n,const Scalar * alpha,const Scalar * A,int lda,Scalar * B,int ldb)925 static inline Status TrsmImpl(SolverFnT solver, cublasHandle_t cublas_handle,
926 cublasSideMode_t side, cublasFillMode_t uplo,
927 cublasOperation_t trans, cublasDiagType_t diag,
928 int m, int n,
929 const Scalar* alpha, /* host or device pointer */
930 const Scalar* A, int lda, Scalar* B, int ldb) {
931 mutex_lock lock(handle_map_mutex);
932 using CudaScalar = typename CUDAComplexT<Scalar>::type;
933 TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, side, uplo, trans, diag, m, n,
934 reinterpret_cast<const CudaScalar*>(alpha),
935 reinterpret_cast<const CudaScalar*>(A), lda,
936 reinterpret_cast<CudaScalar*>(B), ldb));
937 return Status::OK();
938 }
939
940 #define TRSM_INSTANCE(Scalar, type_prefix) \
941 template <> \
942 Status CudaSolver::Trsm<Scalar>( \
943 cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
944 cublasDiagType_t diag, int m, int n, \
945 const Scalar* alpha, /* host or device pointer */ \
946 const Scalar* A, int lda, Scalar* B, int ldb) { \
947 return TrsmImpl(BLAS_SOLVER_FN(trsm, type_prefix), cublas_handle_, side, \
948 uplo, trans, diag, m, n, alpha, A, lda, B, ldb); \
949 }
950
951 TF_CALL_LAPACK_TYPES(TRSM_INSTANCE);
952
953 template <typename Scalar, typename SolverFnT>
TrsvImpl(SolverFnT solver,cublasHandle_t cublas_handle,cublasFillMode_t uplo,cublasOperation_t trans,cublasDiagType_t diag,int n,const Scalar * A,int lda,Scalar * x,int incx)954 static inline Status TrsvImpl(SolverFnT solver, cublasHandle_t cublas_handle,
955 cublasFillMode_t uplo, cublasOperation_t trans,
956 cublasDiagType_t diag, int n, const Scalar* A,
957 int lda, Scalar* x, int incx) {
958 mutex_lock lock(handle_map_mutex);
959 using CudaScalar = typename CUDAComplexT<Scalar>::type;
960 TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, uplo, trans, diag, n,
961 reinterpret_cast<const CudaScalar*>(A), lda,
962 reinterpret_cast<CudaScalar*>(x), incx));
963 return Status::OK();
964 }
965
966 #define TRSV_INSTANCE(Scalar, type_prefix) \
967 template <> \
968 Status CudaSolver::Trsv<Scalar>( \
969 cublasFillMode_t uplo, cublasOperation_t trans, cublasDiagType_t diag, \
970 int n, const Scalar* A, int lda, Scalar* x, int incx) { \
971 return TrsvImpl(BLAS_SOLVER_FN(trsv, type_prefix), cublas_handle_, uplo, \
972 trans, diag, n, A, lda, x, incx); \
973 }
974
975 TF_CALL_LAPACK_TYPES(TRSV_INSTANCE);
976
977 template <typename Scalar, typename SolverFnT>
TrsmBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,cublasSideMode_t side,cublasFillMode_t uplo,cublasOperation_t trans,cublasDiagType_t diag,int m,int n,const Scalar * alpha,const Scalar * const host_a_dev_ptrs[],int lda,Scalar * host_b_dev_ptrs[],int ldb,int batch_size)978 static inline Status TrsmBatchedImpl(
979 SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
980 cublasHandle_t cublas_handle, cublasSideMode_t side, cublasFillMode_t uplo,
981 cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
982 const Scalar* alpha, const Scalar* const host_a_dev_ptrs[], int lda,
983 Scalar* host_b_dev_ptrs[], int ldb, int batch_size) {
984 mutex_lock lock(handle_map_mutex);
985 using CudaScalar = typename CUDAComplexT<Scalar>::type;
986 ScratchSpace<uint8> dev_a_dev_ptrs =
987 cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
988 /* on_host */ false);
989 ScratchSpace<uint8> dev_b_dev_ptrs =
990 cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
991 /* on_host */ false);
992 if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
993 host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
994 return errors::Internal("TrsmBatched: failed to copy pointers to device");
995 }
996 if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
997 host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
998 return errors::Internal("TrsmBatched: failed to copy pointers to device");
999 }
1000 TF_RETURN_IF_CUBLAS_ERROR(
1001 solver(cublas_handle, side, uplo, trans, diag, m, n,
1002 reinterpret_cast<const CudaScalar*>(alpha),
1003 reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
1004 lda, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
1005 ldb, batch_size));
1006 return Status::OK();
1007 }
1008
1009 #define TRSM_BATCHED_INSTANCE(Scalar, type_prefix) \
1010 template <> \
1011 Status CudaSolver::TrsmBatched( \
1012 cublasSideMode_t side, cublasFillMode_t uplo, cublasOperation_t trans, \
1013 cublasDiagType_t diag, int m, int n, const Scalar* alpha, \
1014 const Scalar* const dev_Aarray[], int lda, Scalar* dev_Barray[], \
1015 int ldb, int batch_size) { \
1016 return TrsmBatchedImpl(reinterpret_cast<trsm_##type_prefix*>( \
1017 BLAS_SOLVER_FN(trsmBatched, type_prefix)), \
1018 this, context_, cublas_handle_, side, uplo, trans, \
1019 diag, m, n, alpha, dev_Aarray, lda, dev_Barray, \
1020 ldb, batch_size); \
1021 }
1022
1023 TF_CALL_LAPACK_TYPES(TRSM_BATCHED_INSTANCE);
1024
1025 } // namespace tensorflow
1026
1027 #endif // GOOGLE_CUDA
1028