/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_

// This header declares the class GpuSparse, which contains wrappers of the
// cuSparse library (hipSPARSE when building for ROCm) for use in TensorFlow
// kernels.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <functional>
#include <vector>

#if GOOGLE_CUDA

#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cusparse.h"

using gpusparseStatus_t = cusparseStatus_t;
using gpusparseOperation_t = cusparseOperation_t;
using gpusparseMatDescr_t = cusparseMatDescr_t;
using gpusparseAction_t = cusparseAction_t;
using gpusparseHandle_t = cusparseHandle_t;
using gpuStream_t = cudaStream_t;
#if CUDA_VERSION >= 10020
using gpusparseDnMatDescr_t = cusparseDnMatDescr_t;
using gpusparseSpMatDescr_t = cusparseSpMatDescr_t;
using gpusparseSpMMAlg_t = cusparseSpMMAlg_t;
#endif

#define GPUSPARSE(postfix) CUSPARSE_##postfix
#define gpusparse(postfix) cusparse##postfix

#elif TENSORFLOW_USE_ROCM

#include "tensorflow/stream_executor/rocm/hipsparse_wrapper.h"

using gpusparseStatus_t = hipsparseStatus_t;
using gpusparseOperation_t = hipsparseOperation_t;
using gpusparseMatDescr_t = hipsparseMatDescr_t;
using gpusparseAction_t = hipsparseAction_t;
using gpusparseHandle_t = hipsparseHandle_t;
using gpuStream_t = hipStream_t;

#define GPUSPARSE(postfix) HIPSPARSE_##postfix
#define gpusparse(postfix) hipsparse##postfix

#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/public/version.h"

// Macro that specializes a sparse method for all 4 standard
// numeric types.
// TODO: reuse with cuda_solvers
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
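
// For example (illustrative sketch; INSTANTIATE_GTSV2 is a hypothetical macro
// name), given
//   #define INSTANTIATE_GTSV2(Scalar, type_prefix) /* per-type body */
// the invocation TF_CALL_LAPACK_TYPES(INSTANTIATE_GTSV2) expands the body once
// per (Scalar, prefix) pair: (float, S), (double, D), (std::complex<float>, C)
// and (std::complex<double>, Z), matching cuSparse's S/D/C/Z naming scheme.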

namespace tensorflow {

inline std::string ConvertGPUSparseErrorToString(
    const gpusparseStatus_t status) {
  switch (status) {
#define STRINGIZE(q) #q
#define RETURN_IF_STATUS(err) \
  case err:                   \
    return STRINGIZE(err);

#if GOOGLE_CUDA

    RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS)
    RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_INVALID_VALUE)
    RETURN_IF_STATUS(CUSPARSE_STATUS_ARCH_MISMATCH)
    RETURN_IF_STATUS(CUSPARSE_STATUS_MAPPING_ERROR)
    RETURN_IF_STATUS(CUSPARSE_STATUS_EXECUTION_FAILED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR)
    RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)

    default:
      return strings::StrCat("Unknown CUSPARSE error: ",
                             static_cast<int>(status));
#elif TENSORFLOW_USE_ROCM

    RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT)

    default:
      return strings::StrCat("Unknown hipSPARSE error: ",
                             static_cast<int>(status));
#endif

#undef RETURN_IF_STATUS
#undef STRINGIZE
  }
}

#if GOOGLE_CUDA

#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                 \
  do {                                                                     \
    auto status = (expr);                                                  \
    if (TF_PREDICT_FALSE(status != CUSPARSE_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
                              "): cuSparse call failed with status ",      \
                              ConvertGPUSparseErrorToString(status));      \
    }                                                                      \
  } while (0)

#elif TENSORFLOW_USE_ROCM

#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                 \
  do {                                                                     \
    auto status = (expr);                                                  \
    if (TF_PREDICT_FALSE(status != HIPSPARSE_STATUS_SUCCESS)) {            \
      return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
                              "): hipSPARSE call failed with status ",     \
                              ConvertGPUSparseErrorToString(status));      \
    }                                                                      \
  } while (0)

#endif
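
// Example (illustrative sketch): inside a Status-returning function, a failing
// library call turns into an early return, e.g.
//   TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
// continues on CUSPARSE_STATUS_SUCCESS and otherwise returns an
// errors::Internal status naming the file, line, expression, and error string.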

inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
                                                               bool conjugate,
                                                               Status* status) {
#if GOOGLE_CUDA
  if (transpose) {
    return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
                     : CUSPARSE_OPERATION_TRANSPOSE;
  } else {
    if (conjugate) {
      DCHECK(status != nullptr);
      *status = errors::InvalidArgument(
          "Conjugate == True and transpose == False is not supported.");
    }
    return CUSPARSE_OPERATION_NON_TRANSPOSE;
  }
#elif TENSORFLOW_USE_ROCM
  if (transpose) {
    return conjugate ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
                     : HIPSPARSE_OPERATION_TRANSPOSE;
  } else {
    if (conjugate) {
      DCHECK(status != nullptr);
      *status = errors::InvalidArgument(
          "Conjugate == True and transpose == False is not supported.");
    }
    return HIPSPARSE_OPERATION_NON_TRANSPOSE;
  }
#endif
}
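
// Example (illustrative sketch; transpose_a / adjoint_a are hypothetical
// kernel attributes):
//   Status status;
//   const gpusparseOperation_t transA =
//       TransposeAndConjugateToGpuSparseOp(transpose_a, adjoint_a, &status);
//   if (!status.ok()) { /* surface the InvalidArgument error */ }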

// The GpuSparse class provides a simplified templated API for cuSparse
// (http://docs.nvidia.com/cuda/cusparse/index.html).
// An object of this class wraps static cuSparse instances,
// and will launch Cuda kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Notice: All the computational member functions are asynchronous and simply
// launch one or more Cuda kernels on the Cuda stream wrapped by the GpuSparse
// object.

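// Example (illustrative sketch of typical kernel-side usage; variable names
// are hypothetical):
//
//   GpuSparse cuda_sparse(context);
//   OP_REQUIRES_OK(context, cuda_sparse.Initialize());
//   OP_REQUIRES_OK(context,
//                  cuda_sparse.Csr2coo(csr_row_ptr, nnz, m, coo_row_ind));
//
// The calls above only enqueue work on the device stream associated with
// `context`; they do not synchronize with the host.
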
class GpuSparse {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit GpuSparse(OpKernelContext* context);
  virtual ~GpuSparse() {}

  // This initializes the GpuSparse class if it hasn't been initialized yet.
  // All following public methods require that the class has been initialized.
  // Can be run multiple times; all subsequent calls after the first have no
  // effect.
  Status Initialize();  // Move to constructor?

  // ====================================================================
  // Wrappers for cuSparse start here.
  //

  // Solves tridiagonal system of equations.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2
  template <typename Scalar>
  Status Gtsv2(int m, int n, const Scalar* dl, const Scalar* d,
               const Scalar* du, Scalar* B, int ldb, void* pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_bufferSize
  template <typename Scalar>
  Status Gtsv2BufferSizeExt(int m, int n, const Scalar* dl, const Scalar* d,
                            const Scalar* du, const Scalar* B, int ldb,
                            size_t* bufferSizeInBytes) const;
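
  // Example (illustrative sketch of the query/allocate/solve pattern; the
  // scratch allocation is kernel-specific and elided here):
  //   size_t buffer_size;
  //   TF_RETURN_IF_ERROR(cuda_sparse.Gtsv2BufferSizeExt(m, n, dl, d, du, B,
  //                                                     ldb, &buffer_size));
  //   // ... allocate at least buffer_size bytes of device scratch space ...
  //   TF_RETURN_IF_ERROR(cuda_sparse.Gtsv2(m, n, dl, d, du, B, ldb, buffer));
  // The same pattern applies to the NoPivot and StridedBatch variants below.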

  // Solves tridiagonal system of equations without partial pivoting.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot
  template <typename Scalar>
  Status Gtsv2NoPivot(int m, int n, const Scalar* dl, const Scalar* d,
                      const Scalar* du, Scalar* B, int ldb,
                      void* pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2NoPivot.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot_bufferSize
  template <typename Scalar>
  Status Gtsv2NoPivotBufferSizeExt(int m, int n, const Scalar* dl,
                                   const Scalar* d, const Scalar* du,
                                   const Scalar* B, int ldb,
                                   size_t* bufferSizeInBytes) const;

  // Solves a batch of tridiagonal systems of equations. Doesn't support
  // multiple right-hand sides per system. Doesn't do pivoting.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch
  template <typename Scalar>
  Status Gtsv2StridedBatch(int m, const Scalar* dl, const Scalar* d,
                           const Scalar* du, Scalar* x, int batchCount,
                           int batchStride, void* pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2StridedBatch.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch_bufferSize
  template <typename Scalar>
  Status Gtsv2StridedBatchBufferSizeExt(int m, const Scalar* dl,
                                        const Scalar* d, const Scalar* du,
                                        const Scalar* x, int batchCount,
                                        int batchStride,
                                        size_t* bufferSizeInBytes) const;

  // Uncompresses the row pointers of a CSR matrix into per-nonzero row
  // indices; i.e., converts from CSR to COO sparse storage format. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csr2coo.
  Status Csr2coo(const int* CsrRowPtr, int nnz, int m, int* cooRowInd) const;

  // Compresses per-nonzero row indices into CSR row pointers; i.e., converts
  // from COO to CSR sparse storage format. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
  Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;

#if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM
  // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
  // where A is a sparse matrix in CSR format, B and C are dense tall
  // matrices. This routine allows transposition of matrix B, which
  // may improve performance. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmm2
  //
  // **NOTE** Matrices B and C are expected to be in column-major
  // order; to make them consistent with TensorFlow they
  // must be transposed (or the matmul op's pre/post-processing must take this
  // into account).
  //
  // **NOTE** This is an in-place operation for data in C.
  template <typename Scalar>
  Status Csrmm(gpusparseOperation_t transA, gpusparseOperation_t transB, int m,
               int n, int k, int nnz, const Scalar* alpha_host,
               const gpusparseMatDescr_t descrA, const Scalar* csrSortedValA,
               const int* csrSortedRowPtrA, const int* csrSortedColIndA,
               const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
               int ldc) const;
#else
  // Workspace size query for sparse-dense matrix multiplication. Helper
  // function for SpMM, which computes C = alpha * op(A) * op(B) + beta * C,
  // where A is a sparse matrix in CSR format and B and C are dense matrices
  // in column-major format. Returns the needed workspace size in bytes.
  template <typename Scalar>
  Status SpMMBufferSize(gpusparseOperation_t transA,
                        gpusparseOperation_t transB, const Scalar* alpha,
                        const gpusparseSpMatDescr_t matA,
                        const gpusparseDnMatDescr_t matB, const Scalar* beta,
                        gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
                        size_t* bufferSize) const;

  // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
  // where A is a sparse matrix in CSR format and B and C are dense matrices
  // in column-major format. The buffer is assumed to be at least as large as
  // the workspace size returned by SpMMBufferSize().
  //
  // **NOTE** This is an in-place operation for data in C.
  template <typename Scalar>
  Status SpMM(gpusparseOperation_t transA, gpusparseOperation_t transB,
              const Scalar* alpha, const gpusparseSpMatDescr_t matA,
              const gpusparseDnMatDescr_t matB, const Scalar* beta,
              gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
              int8* buffer) const;
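
  // Example (illustrative sketch): query the workspace first, then run SpMM
  // with a buffer of at least that size:
  //   size_t buffer_size;
  //   TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(transA, transB, &alpha,
  //                                                 matA, matB, &beta, matC,
  //                                                 alg, &buffer_size));
  //   TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
  //                                       &beta, matC, alg, buffer));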
#endif

  // Sparse matrix-dense vector multiplication
  //   y = alpha * op(A) * x + beta * y,
  // where A is a sparse matrix in CSR format and x and y are dense vectors.
  // See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
  //
  // **NOTE** This is an in-place operation for data in y.
#if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM
  template <typename Scalar>
  Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
               const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
               const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
               const int* csrSortedColIndA, const Scalar* x,
               const Scalar* beta_host, Scalar* y) const;
#else
  template <typename Scalar>
  Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
               const Scalar* alpha_host, const Scalar* csrSortedValA,
               const int* csrSortedRowPtrA, const int* csrSortedColIndA,
               const Scalar* x, const Scalar* beta_host, Scalar* y) const;
#endif  // CUDA_VERSION < 10020

  // Computes the workspace size for the sparse-sparse matrix addition of
  // matrices stored in CSR format.
  template <typename Scalar>
  Status CsrgeamBufferSizeExt(
      int m, int n, const Scalar* alpha, const gpusparseMatDescr_t descrA,
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
      const int* csrSortedColIndA, const Scalar* beta,
      const gpusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
      const int* csrSortedRowPtrB, const int* csrSortedColIndB,
      const gpusparseMatDescr_t descrC, Scalar* csrSortedValC,
      int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize);

  // Computes sparse-sparse matrix addition of matrices
  // stored in CSR format. This is part one: calculate nnz of the
  // output. csrSortedRowPtrC must be preallocated on device with
  // m + 1 entries. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
  Status CsrgeamNnz(int m, int n, const gpusparseMatDescr_t descrA, int nnzA,
                    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr, void* workspace);

  // Computes sparse-sparse matrix addition of matrices
  // stored in CSR format. This is part two: perform sparse-sparse
  // addition. csrValC and csrColIndC must be allocated on the device
  // with nnzTotalDevHostPtr entries (as calculated by CsrgeamNnz). See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
  template <typename Scalar>
  Status Csrgeam(int m, int n, const Scalar* alpha,
                 const gpusparseMatDescr_t descrA, int nnzA,
                 const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
                 const int* csrSortedColIndA, const Scalar* beta,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC, void* workspace);
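
  // Illustrative sketch of the csrgeam workflow implied by the declarations
  // above: call CsrgeamBufferSizeExt() to size the scratch workspace, then
  // CsrgeamNnz() to compute nnz(C) and fill csrSortedRowPtrC, allocate
  // csrSortedValC / csrSortedColIndC with nnzTotalDevHostPtr entries, and
  // finally Csrgeam() to write the output values and column indices.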

#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format. This is part zero: calculate required workspace
  // size.
  template <typename Scalar>
  Status CsrgemmBufferSize(
      int m, int n, int k, const gpusparseMatDescr_t descrA, int nnzA,
      const int* csrSortedRowPtrA, const int* csrSortedColIndA,
      const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
      const int* csrSortedColIndB, csrgemm2Info_t info, size_t* workspaceBytes);
#endif

  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format. This is part one: calculate nnz of the
  // output. csrSortedRowPtrC must be preallocated on device with
  // m + 1 entries. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
  Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
                    int m, int k, int n, const gpusparseMatDescr_t descrA,
                    int nnzA, const int* csrSortedRowPtrA,
                    const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr);
#else
  Status CsrgemmNnz(int m, int n, int k, const gpusparseMatDescr_t descrA,
                    int nnzA, const int* csrSortedRowPtrA,
                    const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr, csrgemm2Info_t info,
                    void* workspace);
#endif

  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format. This is part two: perform the sparse-sparse
  // multiplication. csrValC and csrColIndC must be allocated on the device
  // with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
  template <typename Scalar>
  Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
                 int m, int k, int n, const gpusparseMatDescr_t descrA,
                 int nnzA, const Scalar* csrSortedValA,
                 const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC);
#else
  template <typename Scalar>
  Status Csrgemm(int m, int n, int k, const gpusparseMatDescr_t descrA,
                 int nnzA, const Scalar* csrSortedValA,
                 const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC, const csrgemm2Info_t info,
                 void* workspace);
#endif
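
  // Illustrative sketch (CUDA >= 10.0 path): CsrgemmBufferSize() sizes the
  // workspace for a given csrgemm2Info_t, CsrgemmNnz() computes nnz(C) and
  // csrSortedRowPtrC, and Csrgemm() then fills csrSortedValC /
  // csrSortedColIndC, with all three calls sharing the same info and
  // workspace.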

  // In-place reordering of unsorted CSR to sorted CSR.
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr
  template <typename Scalar>
  Status Csru2csr(int m, int n, int nnz, const gpusparseMatDescr_t descrA,
                  Scalar* csrVal, const int* csrRowPtr, int* csrColInd);

  // Converts from CSR to CSC format (equivalently, transpose).
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-csr2cscEx
  template <typename Scalar>
  Status Csr2csc(int m, int n, int nnz, const Scalar* csrVal,
                 const int* csrRowPtr, const int* csrColInd, Scalar* cscVal,
                 int* cscRowInd, int* cscColPtr,
                 const gpusparseAction_t copyValues);

 private:
  bool initialized_;
  OpKernelContext* context_;             // not owned.
  gpuStream_t gpu_stream_;
  gpusparseHandle_t* gpusparse_handle_;  // not owned.

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparse);
};

// A wrapper class to ensure that a CUDA sparse matrix descriptor is
// initialized only once. For more details on the descriptor
// (gpusparseMatDescr_t), see:
// https://docs.nvidia.com/cuda/cusparse/index.html#cusparsematdescrt
class GpuSparseMatrixDescriptor {
 public:
  explicit GpuSparseMatrixDescriptor() : initialized_(false) {}

  GpuSparseMatrixDescriptor(GpuSparseMatrixDescriptor&& rhs)
      : initialized_(rhs.initialized_), descr_(std::move(rhs.descr_)) {
    rhs.initialized_ = false;
  }

  GpuSparseMatrixDescriptor& operator=(GpuSparseMatrixDescriptor&& rhs) {
    if (this == &rhs) return *this;
    Release();
    initialized_ = rhs.initialized_;
    descr_ = std::move(rhs.descr_);
    rhs.initialized_ = false;
    return *this;
  }

  ~GpuSparseMatrixDescriptor() { Release(); }

  // Initializes the underlying descriptor. Must be called at most once; a
  // second call will fail.
  Status Initialize() {
    DCHECK(!initialized_);
#if GOOGLE_CUDA
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
#elif TENSORFLOW_USE_ROCM
    TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descr_));
#endif
    initialized_ = true;
    return Status::OK();
  }

  gpusparseMatDescr_t& descr() {
    DCHECK(initialized_);
    return descr_;
  }

  const gpusparseMatDescr_t& descr() const {
    DCHECK(initialized_);
    return descr_;
  }

 private:
  void Release() {
    if (initialized_) {
#if GOOGLE_CUDA
      cusparseDestroyMatDescr(descr_);
#elif TENSORFLOW_USE_ROCM
      wrap::hipsparseDestroyMatDescr(descr_);
#endif
      initialized_ = false;
    }
  }

  bool initialized_;
  gpusparseMatDescr_t descr_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor);
};
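
// Example (illustrative sketch):
//   GpuSparseMatrixDescriptor descrA;
//   TF_RETURN_IF_ERROR(descrA.Initialize());
//   // descrA.descr() can now be passed to the Csr* wrappers above.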

#if GOOGLE_CUDA

// A wrapper class to ensure that an unsorted/sorted CSR conversion information
// struct (csru2csrInfo_t) is initialized only once. See:
// https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
class GpuSparseCsrSortingConversionInfo {
 public:
  explicit GpuSparseCsrSortingConversionInfo() : initialized_(false) {}

  GpuSparseCsrSortingConversionInfo(GpuSparseCsrSortingConversionInfo&& rhs)
      : initialized_(rhs.initialized_), info_(std::move(rhs.info_)) {
    rhs.initialized_ = false;
  }

  GpuSparseCsrSortingConversionInfo& operator=(
      GpuSparseCsrSortingConversionInfo&& rhs) {
    if (this == &rhs) return *this;
    Release();
    initialized_ = rhs.initialized_;
    info_ = std::move(rhs.info_);
    rhs.initialized_ = false;
    return *this;
  }

  ~GpuSparseCsrSortingConversionInfo() { Release(); }

  // Initializes the underlying info struct. Must be called at most once; a
  // second call will fail.
  Status Initialize() {
    DCHECK(!initialized_);
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_));
    initialized_ = true;
    return Status::OK();
  }

  csru2csrInfo_t& info() {
    DCHECK(initialized_);
    return info_;
  }

  const csru2csrInfo_t& info() const {
    DCHECK(initialized_);
    return info_;
  }

 private:
  void Release() {
    if (initialized_) {
      cusparseDestroyCsru2csrInfo(info_);
      initialized_ = false;
    }
  }

  bool initialized_;
  csru2csrInfo_t info_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo);
};

#endif  // GOOGLE_CUDA

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_