1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================
15 */
16 #ifdef GOOGLE_CUDA
17 #include "tensorflow/core/kernels/cuda_solvers.h"
18
19 #include <chrono>
20 #include <complex>
21 #include <unordered_map>
22 #include <vector>
23
24 #include "cuda/include/cublas_v2.h"
25 #include "cuda/include/cusolverDn.h"
26 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/blocking_counter.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/platform/cuda.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/stream_executor.h"
36 #include "tensorflow/core/platform/types.h"
37
38 // The CUDA cublas_api.h API contains const-correctness errors. Instead of
39 // casting away constness on our data, we instead reinterpret the CuBLAS
40 // functions as what they were clearly meant to be, and thus we can call
41 // the functions naturally.
42 //
43 // (The error is that input-only arrays are bound to parameter types
44 // "const T**" instead of the correct "const T* const*".)
45 extern "C" {
46 using getrs_S = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
47 const float* const*, int, const int*, float**,
48 int, int*, int);
49 using getrs_D = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
50 const double* const*, int, const int*, double**,
51 int, int*, int);
52 using getrs_C = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
53 const float2* const*, int, const int*, float2**,
54 int, int*, int);
55 using getrs_Z = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
56 const double2* const*, int, const int*,
57 double2**, int, int*, int);
58
59 using getri_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
60 const int*, float**, int, int*, int);
61 using getri_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
62 const int*, double**, int, int*, int);
63 using getri_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
64 const int*, float2**, int, int*, int);
65 using getri_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
66 const int*, double2**, int, int*, int);
67
68 using matinv_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
69 float**, int, int*, int);
70 using matinv_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
71 double**, int, int*, int);
72 using matinv_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
73 float2**, int, int*, int);
74 using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
75 double2**, int, int*, int);
76 }
77
78 namespace tensorflow {
79 namespace {
80
81 using se::cuda::ScopedActivateExecutorContext;
82
CopyHostToDevice(OpKernelContext * context,void * dst,const void * src,uint64 bytes)83 inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
84 const void* src, uint64 bytes) {
85 auto stream = context->op_device_context()->stream();
86 se::DeviceMemoryBase wrapped_dst(dst);
87 return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
88 }
89
90 // A set of initialized handles to the underlying Cuda libraries used by
91 // CudaSolver. We maintain one such set of handles per unique stream.
92 struct CudaSolverHandles {
CudaSolverHandlestensorflow::__anona853a54e0111::CudaSolverHandles93 explicit CudaSolverHandles(cudaStream_t stream) {
94 CHECK(cusolverDnCreate(&cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
95 << "Failed to create cuSolverDN instance.";
96 CHECK(cusolverDnSetStream(cusolver_dn_handle, stream) ==
97 CUSOLVER_STATUS_SUCCESS)
98 << "Failed to set cuSolverDN stream.";
99 CHECK(cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS)
100 << "Failed to create cuBlas instance.";
101 CHECK(cublasSetStream(cublas_handle, stream) == CUBLAS_STATUS_SUCCESS)
102 << "Failed to set cuBlas stream.";
103 }
104
~CudaSolverHandlestensorflow::__anona853a54e0111::CudaSolverHandles105 ~CudaSolverHandles() {
106 CHECK(cublasDestroy(cublas_handle) == CUBLAS_STATUS_SUCCESS)
107 << "Failed to destroy cuBlas instance.";
108 CHECK(cusolverDnDestroy(cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
109 << "Failed to destroy cuSolverDN instance.";
110 }
111 cublasHandle_t cublas_handle;
112 cusolverDnHandle_t cusolver_dn_handle;
113 };
114
115 static mutex handle_map_mutex(LINKER_INITIALIZED);
116
117 using HandleMap =
118 std::unordered_map<cudaStream_t, std::unique_ptr<CudaSolverHandles>>;
119
120 // Returns a singleton map used for storing initialized handles for each unique
121 // cuda stream.
GetHandleMapSingleton()122 HandleMap* GetHandleMapSingleton() {
123 static HandleMap* cm = new HandleMap;
124 return cm;
125 }
126
127 } // namespace
128
// Evaluates `expr`, which must yield a cusolverStatus_t. If the result is not
// CUSOLVER_STATUS_SUCCESS, returns an Internal error tagged with the file and
// line of the call site. The temporary has a macro-specific name so that
// `expr` can safely refer to a caller variable named `status`.
#define TF_RETURN_IF_CUSOLVER_ERROR(expr)                                 \
  do {                                                                    \
    auto cusolver_macro_status = (expr);                                  \
    if (TF_PREDICT_FALSE(cusolver_macro_status !=                         \
                         CUSOLVER_STATUS_SUCCESS)) {                      \
      return errors::Internal(__FILE__, ":", __LINE__,                    \
                              ": cuSolverDN call failed with status = ",  \
                              cusolver_macro_status);                     \
    }                                                                     \
  } while (0)
138
// Evaluates `expr`, which must yield a cublasStatus_t. If the result is not
// CUBLAS_STATUS_SUCCESS, returns an Internal error tagged with the file and
// line of the call site. The temporary has a macro-specific name so that
// `expr` can safely refer to a caller variable named `status`.
#define TF_RETURN_IF_CUBLAS_ERROR(expr)                                 \
  do {                                                                  \
    auto cublas_macro_status = (expr);                                  \
    if (TF_PREDICT_FALSE(cublas_macro_status != CUBLAS_STATUS_SUCCESS)) { \
      return errors::Internal(__FILE__, ":", __LINE__,                  \
                              ": cuBlas call failed status = ",         \
                              cublas_macro_status);                     \
    }                                                                   \
  } while (0)
147
CudaSolver(OpKernelContext * context)148 CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) {
149 mutex_lock lock(handle_map_mutex);
150 const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL(
151 reinterpret_cast<const cudaStream_t*>(context->op_device_context()
152 ->stream()
153 ->implementation()
154 ->GpuStreamMemberHack()));
155 cuda_stream_ = *cu_stream_ptr;
156 HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
157 auto it = handle_map->find(cuda_stream_);
158 if (it == handle_map->end()) {
159 LOG(INFO) << "Creating CudaSolver handles for stream " << cuda_stream_;
160 // Previously unseen Cuda stream. Initialize a set of Cuda solver library
161 // handles for it.
162 std::unique_ptr<CudaSolverHandles> new_handles(
163 new CudaSolverHandles(cuda_stream_));
164 it =
165 handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
166 .first;
167 }
168 cusolver_dn_handle_ = it->second->cusolver_dn_handle;
169 cublas_handle_ = it->second->cublas_handle;
170 }
171
~CudaSolver()172 CudaSolver::~CudaSolver() {
173 for (auto tensor_ref : scratch_tensor_refs_) {
174 tensor_ref.Unref();
175 }
176 }
177
178 // static
CheckLapackInfoAndDeleteSolverAsync(std::unique_ptr<CudaSolver> solver,const std::vector<DeviceLapackInfo> & dev_lapack_infos,std::function<void (const Status &,const std::vector<HostLapackInfo> &)> info_checker_callback)179 void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
180 std::unique_ptr<CudaSolver> solver,
181 const std::vector<DeviceLapackInfo>& dev_lapack_infos,
182 std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
183 info_checker_callback) {
184 CHECK(info_checker_callback != nullptr);
185 std::vector<HostLapackInfo> host_lapack_infos;
186 if (dev_lapack_infos.empty()) {
187 info_checker_callback(Status::OK(), host_lapack_infos);
188 return;
189 }
190
191 // Launch memcpys to copy info back from the device to the host.
192 for (const auto& dev_lapack_info : dev_lapack_infos) {
193 bool success = true;
194 auto host_copy = dev_lapack_info.CopyToHost(&success);
195 OP_REQUIRES(
196 solver->context(), success,
197 errors::Internal(
198 "Failed to launch copy of dev_lapack_info to host, debug_info = ",
199 dev_lapack_info.debug_info()));
200 host_lapack_infos.push_back(std::move(host_copy));
201 }
202
203 // This callback checks that all batch items in all calls were processed
204 // successfully and passes status to the info_checker_callback accordingly.
205 auto* stream = solver->context()->op_device_context()->stream();
206 auto wrapped_info_checker_callback =
207 [stream](
208 CudaSolver* solver,
209 std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
210 info_checker_callback,
211 std::vector<HostLapackInfo> host_lapack_infos) {
212 ScopedActivateExecutorContext scoped_activation{stream->parent()};
213 Status status;
214 for (const auto& host_lapack_info : host_lapack_infos) {
215 for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
216 const int info_value = host_lapack_info(i);
217 if (info_value != 0) {
218 status = errors::InvalidArgument(
219 "Got info = ", info_value, " for batch index ", i,
220 ", expected info = 0. Debug_info = ",
221 host_lapack_info.debug_info());
222 }
223 }
224 if (!status.ok()) {
225 break;
226 }
227 }
228 // Delete solver to release temp tensor refs.
229 delete solver;
230
231 // Delegate further error checking to provided functor.
232 info_checker_callback(status, host_lapack_infos);
233 };
234 // Note: An std::function cannot have unique_ptr arguments (it must be copy
235 // constructible and therefore so must its arguments). Therefore, we release
236 // solver into a raw pointer to be deleted at the end of
237 // wrapped_info_checker_callback.
238 // Release ownership of solver. It will be deleted in the cb callback.
239 auto solver_raw_ptr = solver.release();
240 auto cb =
241 std::bind(wrapped_info_checker_callback, solver_raw_ptr,
242 std::move(info_checker_callback), std::move(host_lapack_infos));
243
244 solver_raw_ptr->context()
245 ->device()
246 ->tensorflow_gpu_device_info()
247 ->event_mgr->ThenExecute(stream, std::move(cb));
248 }
249
250 // static
CheckLapackInfoAndDeleteSolverAsync(std::unique_ptr<CudaSolver> solver,const std::vector<DeviceLapackInfo> & dev_lapack_info,AsyncOpKernel::DoneCallback done)251 void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
252 std::unique_ptr<CudaSolver> solver,
253 const std::vector<DeviceLapackInfo>& dev_lapack_info,
254 AsyncOpKernel::DoneCallback done) {
255 OpKernelContext* context = solver->context();
256 auto wrapped_done = [context, done](
257 const Status& status,
258 const std::vector<HostLapackInfo>& /* unused */) {
259 if (done != nullptr) {
260 OP_REQUIRES_OK_ASYNC(context, status, done);
261 done();
262 } else {
263 OP_REQUIRES_OK(context, status);
264 }
265 };
266 CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
267 wrapped_done);
268 }
269
270 // Allocates a temporary tensor. The CudaSolver object maintains a
271 // TensorReference to the underlying Tensor to prevent it from being deallocated
272 // prematurely.
allocate_scoped_tensor(DataType type,const TensorShape & shape,Tensor * out_temp)273 Status CudaSolver::allocate_scoped_tensor(DataType type,
274 const TensorShape& shape,
275 Tensor* out_temp) {
276 const Status status = context_->allocate_temp(type, shape, out_temp);
277 if (status.ok()) {
278 scratch_tensor_refs_.emplace_back(*out_temp);
279 }
280 return status;
281 }
282
forward_input_or_allocate_scoped_tensor(gtl::ArraySlice<int> candidate_input_indices,DataType type,const TensorShape & shape,Tensor * out_temp)283 Status CudaSolver::forward_input_or_allocate_scoped_tensor(
284 gtl::ArraySlice<int> candidate_input_indices, DataType type,
285 const TensorShape& shape, Tensor* out_temp) {
286 const Status status = context_->forward_input_or_allocate_temp(
287 candidate_input_indices, type, shape, out_temp);
288 if (status.ok()) {
289 scratch_tensor_refs_.emplace_back(*out_temp);
290 }
291 return status;
292 }
293
// Macros that stamp out a solver-method specialization per supported numeric
// type. The argument `m` is a macro invoked as m(C++ scalar type, BLAS type
// letter); the letter (S, D, C, Z) is pasted into the library function names.
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, S) m(double, D)
#define TF_CALL_LAPACK_TYPES(m)   \
  TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) \
  m(std::complex<float>, C) m(std::complex<double>, Z)

// Macros to construct cusolverDn method names, e.g.
//   DN_SOLVER_FN(potrf, S)   -> cusolverDnSpotrf
//   DN_BUFSIZE_FN(potrf, S)  -> cusolverDnSpotrf_bufferSize
//   DN_SOLVER_NAME(potrf, S) -> "cusolverDnSpotrf"
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method
#define DN_SOLVER_NAME(method, type_prefix) "cusolverDn" #type_prefix #method
#define DN_BUFSIZE_FN(method, type_prefix) \
  cusolverDn##type_prefix##method##_bufferSize

// Macros to construct cublas method names, e.g.
//   BLAS_SOLVER_FN(geam, S)   -> cublasSgeam
//   BLAS_SOLVER_NAME(geam, S) -> "cublasSgeam"
#define BLAS_SOLVER_FN(method, type_prefix) cublas##type_prefix##method
#define BLAS_SOLVER_NAME(method, type_prefix) "cublas" #type_prefix #method
309
310 //=============================================================================
311 // Wrappers of cuSolverDN computational methods begin here.
312 //
313 // WARNING to implementers: The function signatures listed in the online docs
314 // are sometimes inaccurate, e.g., are missing 'const' on pointers
315 // to immutable arguments, while the actual headers have them as expected.
316 // Check the actual declarations in the cusolver_api.h header file.
317 //
// NOTE: The cuSolver functions called below appear not to be thread-safe,
// so we put a global lock around the calls. Since these functions only enqueue
// a kernel on the shared stream, the lock is not a big performance hit.
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
322 //=============================================================================
323
324 template <typename Scalar, typename SolverFnT>
GeamImpl(SolverFnT solver,cublasHandle_t cublas_handle,cublasOperation_t transa,cublasOperation_t transb,int m,int n,const Scalar * alpha,const Scalar * A,int lda,const Scalar * beta,const Scalar * B,int ldb,Scalar * C,int ldc)325 static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
326 cublasOperation_t transa,
327 cublasOperation_t transb, int m, int n,
328 const Scalar* alpha, /* host or device pointer */
329 const Scalar* A, int lda,
330 const Scalar* beta, /* host or device pointer */
331 const Scalar* B, int ldb, Scalar* C, int ldc) {
332 mutex_lock lock(handle_map_mutex);
333 using CudaScalar = typename CUDAComplexT<Scalar>::type;
334 TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
335 reinterpret_cast<const CudaScalar*>(alpha),
336 reinterpret_cast<const CudaScalar*>(A), lda,
337 reinterpret_cast<const CudaScalar*>(beta),
338 reinterpret_cast<const CudaScalar*>(B), ldb,
339 reinterpret_cast<CudaScalar*>(C), ldc));
340 return Status::OK();
341 }
342
343 #define GEAM_INSTANCE(Scalar, type_prefix) \
344 template <> \
345 Status CudaSolver::Geam<Scalar>( \
346 cublasOperation_t transa, cublasOperation_t transb, int m, int n, \
347 const Scalar* alpha, /* host or device pointer */ \
348 const Scalar* A, int lda, \
349 const Scalar* beta, /* host or device pointer */ \
350 const Scalar* B, int ldb, Scalar* C, int ldc) const { \
351 return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), cublas_handle_, transa, \
352 transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); \
353 }
354
355 TF_CALL_LAPACK_TYPES(GEAM_INSTANCE);
356
357 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
PotrfImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cublasFillMode_t uplo,int n,Scalar * A,int lda,int * dev_lapack_info)358 static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
359 CudaSolver* cuda_solver,
360 OpKernelContext* context,
361 cusolverDnHandle_t cusolver_dn_handle,
362 cublasFillMode_t uplo, int n, Scalar* A, int lda,
363 int* dev_lapack_info) {
364 mutex_lock lock(handle_map_mutex);
365 /* Get amount of workspace memory required. */
366 int lwork;
367 TF_RETURN_IF_CUSOLVER_ERROR(
368 bufsize(cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, &lwork));
369 /* Allocate device memory for workspace. */
370 auto dev_workspace =
371 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
372 /* Launch the solver kernel. */
373 TF_RETURN_IF_CUSOLVER_ERROR(solver(
374 cusolver_dn_handle, uplo, n, CUDAComplex(A), lda,
375 CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
376 return Status::OK();
377 }
378
379 #define POTRF_INSTANCE(Scalar, type_prefix) \
380 template <> \
381 Status CudaSolver::Potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A, \
382 int lda, int* dev_lapack_info) { \
383 return PotrfImpl(DN_BUFSIZE_FN(potrf, type_prefix), \
384 DN_SOLVER_FN(potrf, type_prefix), this, context_, \
385 cusolver_dn_handle_, uplo, n, A, lda, dev_lapack_info); \
386 }
387
388 TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);
389
390 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
GetrfImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,int m,int n,Scalar * A,int lda,int * dev_pivots,int * dev_lapack_info)391 static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
392 CudaSolver* cuda_solver,
393 OpKernelContext* context,
394 cusolverDnHandle_t cusolver_dn_handle, int m,
395 int n, Scalar* A, int lda, int* dev_pivots,
396 int* dev_lapack_info) {
397 mutex_lock lock(handle_map_mutex);
398 /* Get amount of workspace memory required. */
399 int lwork;
400 TF_RETURN_IF_CUSOLVER_ERROR(
401 bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
402 /* Allocate device memory for workspace. */
403 auto dev_workspace =
404 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
405 /* Launch the solver kernel. */
406 TF_RETURN_IF_CUSOLVER_ERROR(solver(
407 cusolver_dn_handle, m, n, CUDAComplex(A), lda,
408 CUDAComplex(dev_workspace.mutable_data()), dev_pivots, dev_lapack_info));
409 return Status::OK();
410 }
411
412 #define GETRF_INSTANCE(Scalar, type_prefix) \
413 template <> \
414 Status CudaSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda, \
415 int* dev_pivots, int* dev_lapack_info) { \
416 return GetrfImpl(DN_BUFSIZE_FN(getrf, type_prefix), \
417 DN_SOLVER_FN(getrf, type_prefix), this, context_, \
418 cusolver_dn_handle_, m, n, A, lda, dev_pivots, \
419 dev_lapack_info); \
420 }
421
422 TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);
423
424 template <typename Scalar, typename SolverFnT>
GetrsImpl(SolverFnT solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cublasOperation_t trans,int n,int nrhs,const Scalar * A,int lda,const int * pivots,Scalar * B,int ldb,int * dev_lapack_info)425 static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
426 cusolverDnHandle_t cusolver_dn_handle,
427 cublasOperation_t trans, int n, int nrhs,
428 const Scalar* A, int lda, const int* pivots,
429 Scalar* B, int ldb, int* dev_lapack_info) {
430 mutex_lock lock(handle_map_mutex);
431 /* Launch the solver kernel. */
432 TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
433 CUDAComplex(A), lda, pivots,
434 CUDAComplex(B), ldb, dev_lapack_info));
435 return Status::OK();
436 }
437
438 #define GETRS_INSTANCE(Scalar, type_prefix) \
439 template <> \
440 Status CudaSolver::Getrs<Scalar>( \
441 cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda, \
442 const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const { \
443 return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_, \
444 cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
445 ldb, dev_lapack_info); \
446 }
447
448 TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);
449
450 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
GeqrfImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,int m,int n,Scalar * A,int lda,Scalar * tau,int * dev_lapack_info)451 static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
452 CudaSolver* cuda_solver,
453 OpKernelContext* context,
454 cusolverDnHandle_t cusolver_dn_handle, int m,
455 int n, Scalar* A, int lda, Scalar* tau,
456 int* dev_lapack_info) {
457 mutex_lock lock(handle_map_mutex);
458 /* Get amount of workspace memory required. */
459 int lwork;
460 TF_RETURN_IF_CUSOLVER_ERROR(
461 bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
462 /* Allocate device memory for workspace. */
463 auto dev_workspace =
464 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
465 /* Launch the solver kernel. */
466 TF_RETURN_IF_CUSOLVER_ERROR(solver(
467 cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
468 CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
469 return Status::OK();
470 }
471
472 #define GEQRF_INSTANCE(Scalar, type_prefix) \
473 template <> \
474 Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda, \
475 Scalar* tau, int* dev_lapack_info) { \
476 return GeqrfImpl(DN_BUFSIZE_FN(geqrf, type_prefix), \
477 DN_SOLVER_FN(geqrf, type_prefix), this, context_, \
478 cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
479 }
480
481 TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);
482
483 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
UnmqrImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cublasSideMode_t side,cublasOperation_t trans,int m,int n,int k,const Scalar * dev_a,int lda,const Scalar * dev_tau,Scalar * dev_c,int ldc,int * dev_lapack_info)484 static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
485 CudaSolver* cuda_solver,
486 OpKernelContext* context,
487 cusolverDnHandle_t cusolver_dn_handle,
488 cublasSideMode_t side, cublasOperation_t trans,
489 int m, int n, int k, const Scalar* dev_a,
490 int lda, const Scalar* dev_tau, Scalar* dev_c,
491 int ldc, int* dev_lapack_info) {
492 mutex_lock lock(handle_map_mutex);
493 /* Get amount of workspace memory required. */
494 int lwork;
495 TF_RETURN_IF_CUSOLVER_ERROR(
496 bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
497 CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
498 /* Allocate device memory for workspace. */
499 auto dev_workspace =
500 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
501 /* Launch the solver kernel. */
502 TF_RETURN_IF_CUSOLVER_ERROR(solver(
503 cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
504 CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
505 CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
506 return Status::OK();
507 }
508
509 // Unfortunately the LAPACK function name differs for the real and complex case
510 // (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
511 // one separately.
512 #define UNMQR_INSTANCE(Scalar, function_prefix, type_prefix) \
513 template <> \
514 Status CudaSolver::Unmqr(cublasSideMode_t side, cublasOperation_t trans, \
515 int m, int n, int k, const Scalar* dev_a, int lda, \
516 const Scalar* dev_tau, Scalar* dev_c, int ldc, \
517 int* dev_lapack_info) { \
518 return UnmqrImpl(DN_BUFSIZE_FN(function_prefix##mqr, type_prefix), \
519 DN_SOLVER_FN(function_prefix##mqr, type_prefix), this, \
520 context_, cusolver_dn_handle_, side, trans, m, n, k, \
521 dev_a, lda, dev_tau, dev_c, ldc, dev_lapack_info); \
522 }
523
524 UNMQR_INSTANCE(float, or, S);
525 UNMQR_INSTANCE(double, or, D);
526 UNMQR_INSTANCE(complex64, un, C);
527 UNMQR_INSTANCE(complex128, un, Z);
528
529 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
UngqrImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,int m,int n,int k,Scalar * dev_a,int lda,const Scalar * dev_tau,int * dev_lapack_info)530 static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
531 CudaSolver* cuda_solver,
532 OpKernelContext* context,
533 cusolverDnHandle_t cusolver_dn_handle, int m,
534 int n, int k, Scalar* dev_a, int lda,
535 const Scalar* dev_tau, int* dev_lapack_info) {
536 mutex_lock lock(handle_map_mutex);
537 /* Get amount of workspace memory required. */
538 int lwork;
539 TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
540 CUDAComplex(dev_a), lda,
541 CUDAComplex(dev_tau), &lwork));
542 /* Allocate device memory for workspace. */
543 auto dev_workspace =
544 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
545 /* Launch the solver kernel. */
546 TF_RETURN_IF_CUSOLVER_ERROR(
547 solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
548 CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
549 lwork, dev_lapack_info));
550 return Status::OK();
551 }
552
553 #define UNGQR_INSTANCE(Scalar, function_prefix, type_prefix) \
554 template <> \
555 Status CudaSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda, \
556 const Scalar* dev_tau, int* dev_lapack_info) { \
557 return UngqrImpl(DN_BUFSIZE_FN(function_prefix##gqr, type_prefix), \
558 DN_SOLVER_FN(function_prefix##gqr, type_prefix), this, \
559 context_, cusolver_dn_handle_, m, n, k, dev_a, lda, \
560 dev_tau, dev_lapack_info); \
561 }
562
563 UNGQR_INSTANCE(float, or, S);
564 UNGQR_INSTANCE(double, or, D);
565 UNGQR_INSTANCE(complex64, un, C);
566 UNGQR_INSTANCE(complex128, un, Z);
567
568 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
HeevdImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,Scalar * dev_A,int lda,typename Eigen::NumTraits<Scalar>::Real * dev_W,int * dev_lapack_info)569 static inline Status HeevdImpl(BufSizeFnT bufsize, SolverFnT solver,
570 CudaSolver* cuda_solver,
571 OpKernelContext* context,
572 cusolverDnHandle_t cusolver_dn_handle,
573 cusolverEigMode_t jobz, cublasFillMode_t uplo,
574 int n, Scalar* dev_A, int lda,
575 typename Eigen::NumTraits<Scalar>::Real* dev_W,
576 int* dev_lapack_info) {
577 mutex_lock lock(handle_map_mutex);
578 /* Get amount of workspace memory required. */
579 int lwork;
580 TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, jobz, uplo, n,
581 CUDAComplex(dev_A), lda,
582 CUDAComplex(dev_W), &lwork));
583 /* Allocate device memory for workspace. */
584 auto dev_workspace =
585 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
586 /* Launch the solver kernel. */
587 TF_RETURN_IF_CUSOLVER_ERROR(
588 solver(cusolver_dn_handle, jobz, uplo, n, CUDAComplex(dev_A), lda,
589 CUDAComplex(dev_W), CUDAComplex(dev_workspace.mutable_data()),
590 lwork, dev_lapack_info));
591 return Status::OK();
592 }
593
594 #define HEEVD_INSTANCE(Scalar, function_prefix, type_prefix) \
595 template <> \
596 Status CudaSolver::Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, \
597 int n, Scalar* dev_A, int lda, \
598 typename Eigen::NumTraits<Scalar>::Real* dev_W, \
599 int* dev_lapack_info) { \
600 return HeevdImpl(DN_BUFSIZE_FN(function_prefix##evd, type_prefix), \
601 DN_SOLVER_FN(function_prefix##evd, type_prefix), this, \
602 context_, cusolver_dn_handle_, jobz, uplo, n, dev_A, lda, \
603 dev_W, dev_lapack_info); \
604 }
605
606 HEEVD_INSTANCE(float, sy, S);
607 HEEVD_INSTANCE(double, sy, D);
608 HEEVD_INSTANCE(complex64, he, C);
609 HEEVD_INSTANCE(complex128, he, Z);
610
611 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
GesvdImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,signed char jobu,signed char jobvt,int m,int n,Scalar * A,int lda,Scalar * S,Scalar * U,int ldu,Scalar * VT,int ldvt,int * dev_lapack_info)612 static inline Status GesvdImpl(
613 BufSizeFnT bufsize, SolverFnT solver, CudaSolver* cuda_solver,
614 OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
615 signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
616 Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
617 mutex_lock lock(handle_map_mutex);
618 /* Get amount of workspace memory required. */
619 int lwork;
620 TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
621 /* Allocate device memory for workspace. */
622 auto dev_workspace =
623 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
624 TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
625 CUDAComplex(A), lda, S, CUDAComplex(U),
626 ldu, CUDAComplex(VT), ldvt,
627 CUDAComplex(dev_workspace.mutable_data()),
628 lwork, nullptr, dev_lapack_info));
629 return Status::OK();
630 }
631
632 #define GESVD_INSTANCE(Scalar, type_prefix) \
633 template <> \
634 Status CudaSolver::Gesvd<Scalar>( \
635 signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A, \
636 int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT, \
637 int ldvt, int* dev_lapack_info) { \
638 return GesvdImpl(DN_BUFSIZE_FN(gesvd, type_prefix), \
639 DN_SOLVER_FN(gesvd, type_prefix), this, context_, \
640 cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda, \
641 dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info); \
642 }
643
644 TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);
645
646 template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
GesvdjBatchedImpl(BufSizeFnT bufsize,SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cusolverDnHandle_t cusolver_dn_handle,cusolverEigMode_t jobz,int m,int n,Scalar * A,int lda,Scalar * S,Scalar * U,int ldu,Scalar * V,int ldv,int * dev_lapack_info,int batch_size)647 static inline Status GesvdjBatchedImpl(BufSizeFnT bufsize, SolverFnT solver,
648 CudaSolver* cuda_solver,
649 OpKernelContext* context,
650 cusolverDnHandle_t cusolver_dn_handle,
651 cusolverEigMode_t jobz, int m, int n,
652 Scalar* A, int lda, Scalar* S, Scalar* U,
653 int ldu, Scalar* V, int ldv,
654 int* dev_lapack_info, int batch_size) {
655 mutex_lock lock(handle_map_mutex);
656 /* Get amount of workspace memory required. */
657 int lwork;
658 /* Default parameters for gesvdj and gesvdjBatched. */
659 gesvdjInfo_t svdj_info;
660 TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnCreateGesvdjInfo(&svdj_info));
661 TF_RETURN_IF_CUSOLVER_ERROR(bufsize(
662 cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
663 ldu, CUDAComplex(V), ldv, &lwork, svdj_info, batch_size));
664 /* Allocate device memory for workspace. */
665 auto dev_workspace =
666 cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
667 TF_RETURN_IF_CUSOLVER_ERROR(solver(
668 cusolver_dn_handle, jobz, m, n, CUDAComplex(A), lda, S, CUDAComplex(U),
669 ldu, CUDAComplex(V), ldv, CUDAComplex(dev_workspace.mutable_data()),
670 lwork, dev_lapack_info, svdj_info, batch_size));
671 TF_RETURN_IF_CUSOLVER_ERROR(cusolverDnDestroyGesvdjInfo(svdj_info));
672 return Status::OK();
673 }
674
675 #define GESVDJBATCHED_INSTANCE(Scalar, type_prefix) \
676 template <> \
677 Status CudaSolver::GesvdjBatched<Scalar>( \
678 cusolverEigMode_t jobz, int m, int n, Scalar* dev_A, int lda, \
679 Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_V, int ldv, \
680 int* dev_lapack_info, int batch_size) { \
681 return GesvdjBatchedImpl(DN_BUFSIZE_FN(gesvdjBatched, type_prefix), \
682 DN_SOLVER_FN(gesvdjBatched, type_prefix), this, \
683 context_, cusolver_dn_handle_, jobz, m, n, dev_A, \
684 lda, dev_S, dev_U, ldu, dev_V, ldv, \
685 dev_lapack_info, batch_size); \
686 }
687
688 TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVDJBATCHED_INSTANCE);
689
690 //=============================================================================
691 // Wrappers of cuBlas computational methods begin here.
692 //
693 // WARNING to implementers: The function signatures listed in the online docs
694 // are sometimes inaccurate, e.g., are missing 'const' on pointers
695 // to immutable arguments, while the actual headers have them as expected.
696 // Check the actual declarations in the cublas_api.h header file.
697 //=============================================================================
698 template <typename Scalar, typename SolverFnT>
GetrfBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,int n,const Scalar * const host_a_dev_ptrs[],int lda,int * dev_pivots,DeviceLapackInfo * dev_lapack_info,int batch_size)699 static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
700 OpKernelContext* context,
701 cublasHandle_t cublas_handle, int n,
702 const Scalar* const host_a_dev_ptrs[],
703 int lda, int* dev_pivots,
704 DeviceLapackInfo* dev_lapack_info,
705 int batch_size) {
706 mutex_lock lock(handle_map_mutex);
707 using CudaScalar = typename CUDAComplexT<Scalar>::type;
708 ScratchSpace<uint8> dev_a_dev_ptrs =
709 cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
710 /* on_host */ false);
711 if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
712 host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
713 return errors::Internal("GetrfBatched: failed to copy pointers to device");
714 }
715 TF_RETURN_IF_CUBLAS_ERROR(
716 solver(cublas_handle, n,
717 reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
718 dev_pivots, dev_lapack_info->mutable_data(), batch_size));
719 return Status::OK();
720 }
721
// Defines the CudaSolver::GetrfBatched specialization for one scalar type.
// It forwards to GetrfBatchedImpl with:
//   - the type-specific cuBLAS getrfBatched function (BLAS_SOLVER_FN,
//     chosen by `type_prefix`);
//   - this solver's `context_` and `cublas_handle_`.
// GetrfBatchedImpl passes the matrices to cuBLAS as mutable pointers
// (factorization is in place), so the function is passed as-is. This differs
// from the getrs/getri/matinv wrappers, which apply a const-correcting
// reinterpret_cast.
// Do not put `//` comments inside the macro body: line splicing would pull
// the continuation lines into the comment.
#define GETRF_BATCHED_INSTANCE(Scalar, type_prefix) \
template <> \
Status CudaSolver::GetrfBatched( \
int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots, \
DeviceLapackInfo* dev_lapack_info, int batch_size) { \
return GetrfBatchedImpl(BLAS_SOLVER_FN(getrfBatched, type_prefix), this, \
context_, cublas_handle_, n, host_a_dev_ptrs, lda, \
dev_pivots, dev_lapack_info, batch_size); \
}

// Emit the specialization for every LAPACK scalar type (presumably the real
// and complex float/double types; see the TF_CALL_* macro definitions).
TF_CALL_LAPACK_TYPES(GETRF_BATCHED_INSTANCE);
733
734 template <typename Scalar, typename SolverFnT>
GetrsBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,cublasOperation_t trans,int n,int nrhs,const Scalar * const host_a_dev_ptrs[],int lda,const int * dev_pivots,const Scalar * const host_b_dev_ptrs[],int ldb,int * host_lapack_info,int batch_size)735 static inline Status GetrsBatchedImpl(
736 SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
737 cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
738 const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
739 const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info,
740 int batch_size) {
741 mutex_lock lock(handle_map_mutex);
742 using CudaScalar = typename CUDAComplexT<Scalar>::type;
743 ScratchSpace<uint8> dev_a_dev_ptrs =
744 cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
745 /* on_host */ false);
746 ScratchSpace<uint8> dev_b_dev_ptrs =
747 cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
748 /* on_host */ false);
749 if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
750 host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
751 return errors::Internal("GetrsBatched: failed to copy pointers to device");
752 }
753 if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
754 host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
755 return errors::Internal("GetrsBatched: failed to copy pointers to device");
756 }
757 TF_RETURN_IF_CUBLAS_ERROR(solver(
758 cublas_handle, trans, n, nrhs,
759 reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
760 dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
761 ldb, host_lapack_info, batch_size));
762 return Status::OK();
763 }
764
// Defines the CudaSolver::GetrsBatched specialization for one scalar type.
// The type-specific cuBLAS getrsBatched function (BLAS_SOLVER_FN, chosen by
// `type_prefix`) is reinterpret_cast to the const-correct `getrs_<prefix>`
// function type from the extern "C" block at the top of this file. That block
// works around cublas_api.h declaring input-only pointer arrays as
// "const T**" instead of "const T* const*". The cast function is then
// forwarded to GetrsBatchedImpl together with this solver's `context_` and
// `cublas_handle_`.
// Do not put `//` comments inside the macro body: line splicing would pull
// the continuation lines into the comment.
#define GETRS_BATCHED_INSTANCE(Scalar, type_prefix) \
template <> \
Status CudaSolver::GetrsBatched( \
cublasOperation_t trans, int n, int nrhs, \
const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots, \
const Scalar* const host_b_dev_ptrs[], int ldb, int* host_lapack_info, \
int batch_size) { \
return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>( \
BLAS_SOLVER_FN(getrsBatched, type_prefix)), \
this, context_, cublas_handle_, trans, n, nrhs, \
host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
ldb, host_lapack_info, batch_size); \
}

// Emit the specialization for every LAPACK scalar type (presumably the real
// and complex float/double types; see the TF_CALL_* macro definitions).
TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);
780
781 template <typename Scalar, typename SolverFnT>
GetriBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,int n,const Scalar * const host_a_dev_ptrs[],int lda,const int * dev_pivots,const Scalar * const host_a_inv_dev_ptrs[],int ldainv,DeviceLapackInfo * dev_lapack_info,int batch_size)782 static inline Status GetriBatchedImpl(
783 SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
784 cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
785 int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
786 int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
787 mutex_lock lock(handle_map_mutex);
788 using CudaScalar = typename CUDAComplexT<Scalar>::type;
789 ScratchSpace<uint8> dev_a_dev_ptrs =
790 cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
791 /* on_host */ false);
792 ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
793 sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
794 if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
795 host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
796 !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
797 host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
798 return errors::Internal("GetriBatched: failed to copy pointers to device");
799 }
800 TF_RETURN_IF_CUBLAS_ERROR(
801 solver(cublas_handle, n,
802 reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
803 lda, dev_pivots,
804 reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()),
805 ldainv, dev_lapack_info->mutable_data(), batch_size));
806 return Status::OK();
807 }
808
// Defines the CudaSolver::GetriBatched specialization for one scalar type.
// The type-specific cuBLAS getriBatched function (BLAS_SOLVER_FN, chosen by
// `type_prefix`) is reinterpret_cast to the const-correct `getri_<prefix>`
// function type. That type presumably sits in the extern "C" block at the
// top of this file, next to the getrs_ aliases (verify). The cast function is
// then forwarded to GetriBatchedImpl together with this solver's `context_`
// and `cublas_handle_`.
// Do not put `//` comments inside the macro body: line splicing would pull
// the continuation lines into the comment.
#define GETRI_BATCHED_INSTANCE(Scalar, type_prefix) \
template <> \
Status CudaSolver::GetriBatched( \
int n, const Scalar* const host_a_dev_ptrs[], int lda, \
const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[], \
int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) { \
return GetriBatchedImpl( \
reinterpret_cast<getri_##type_prefix*>( \
BLAS_SOLVER_FN(getriBatched, type_prefix)), \
this, context_, cublas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \
host_a_inv_dev_ptrs, ldainv, dev_lapack_info, batch_size); \
}

// Emit the specialization for every LAPACK scalar type (presumably the real
// and complex float/double types; see the TF_CALL_* macro definitions).
TF_CALL_LAPACK_TYPES(GETRI_BATCHED_INSTANCE);
823
824 template <typename Scalar, typename SolverFnT>
MatInvBatchedImpl(SolverFnT solver,CudaSolver * cuda_solver,OpKernelContext * context,cublasHandle_t cublas_handle,int n,const Scalar * const host_a_dev_ptrs[],int lda,const Scalar * const host_a_inv_dev_ptrs[],int ldainv,DeviceLapackInfo * dev_lapack_info,int batch_size)825 static inline Status MatInvBatchedImpl(
826 SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
827 cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
828 int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
829 DeviceLapackInfo* dev_lapack_info, int batch_size) {
830 mutex_lock lock(handle_map_mutex);
831 using CudaScalar = typename CUDAComplexT<Scalar>::type;
832 ScratchSpace<uint8> dev_a_dev_ptrs =
833 cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
834 /* on_host */ false);
835 ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
836 sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
837 if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
838 host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
839 !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
840 host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
841 return errors::Internal("MatInvBatched: failed to copy pointers to device");
842 }
843 TF_RETURN_IF_CUBLAS_ERROR(solver(
844 cublas_handle, n,
845 reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
846 reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()), ldainv,
847 dev_lapack_info->mutable_data(), batch_size));
848 return Status::OK();
849 }
850
// Defines the CudaSolver::MatInvBatched specialization for one scalar type.
// The type-specific cuBLAS matinvBatched function (BLAS_SOLVER_FN, chosen by
// `type_prefix`) is reinterpret_cast to the const-correct `matinv_<prefix>`
// function type. That type presumably sits in the extern "C" block at the
// top of this file, next to the getrs_ aliases (verify). The cast function is
// then forwarded to MatInvBatchedImpl together with this solver's `context_`
// and `cublas_handle_`.
// Do not put `//` comments inside the macro body: line splicing would pull
// the continuation lines into the comment.
#define MATINV_BATCHED_INSTANCE(Scalar, type_prefix) \
template <> \
Status CudaSolver::MatInvBatched( \
int n, const Scalar* const host_a_dev_ptrs[], int lda, \
const Scalar* const host_a_inv_dev_ptrs[], int ldainv, \
DeviceLapackInfo* dev_lapack_info, int batch_size) { \
return MatInvBatchedImpl(reinterpret_cast<matinv_##type_prefix*>( \
BLAS_SOLVER_FN(matinvBatched, type_prefix)), \
this, context_, cublas_handle_, n, \
host_a_dev_ptrs, lda, host_a_inv_dev_ptrs, \
ldainv, dev_lapack_info, batch_size); \
}

// Emit the specialization for every LAPACK scalar type (presumably the real
// and complex float/double types; see the TF_CALL_* macro definitions).
TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);
865
866 } // namespace tensorflow
867
868 #endif // GOOGLE_CUDA
869