/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_

// This header declares the class GpuSparse, which provides wrappers around
// the cuSPARSE (CUDA) and hipSPARSE (ROCm) library routines for use in
// TensorFlow kernels.

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <functional>
#include <vector>

#if GOOGLE_CUDA

#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cusparse.h"

using gpusparseStatus_t = cusparseStatus_t;
using gpusparseOperation_t = cusparseOperation_t;
using gpusparseMatDescr_t = cusparseMatDescr_t;
using gpusparseAction_t = cusparseAction_t;
using gpusparseHandle_t = cusparseHandle_t;
using gpuStream_t = cudaStream_t;
#if CUDA_VERSION >= 10020
using gpusparseDnMatDescr_t = cusparseDnMatDescr_t;
using gpusparseSpMatDescr_t = cusparseSpMatDescr_t;
using gpusparseSpMMAlg_t = cusparseSpMMAlg_t;
#endif

#define GPUSPARSE(postfix) CUSPARSE_##postfix
#define gpusparse(postfix) cusparse##postfix

#elif TENSORFLOW_USE_ROCM

#include "tensorflow/stream_executor/rocm/hipsparse_wrapper.h"

using gpusparseStatus_t = hipsparseStatus_t;
using gpusparseOperation_t = hipsparseOperation_t;
using gpusparseMatDescr_t = hipsparseMatDescr_t;
using gpusparseAction_t = hipsparseAction_t;
using gpusparseHandle_t = hipsparseHandle_t;
using gpuStream_t = hipStream_t;

#define GPUSPARSE(postfix) HIPSPARSE_##postfix
#define gpusparse(postfix) hipsparse##postfix

#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/public/version.h"

// Macro that specializes a sparse method for all 4 standard
// numeric types.
// TODO: reuse with cuda_solvers
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)

namespace tensorflow {

inline std::string ConvertGPUSparseErrorToString(
    const gpusparseStatus_t status) {
  switch (status) {
#define STRINGIZE(q) #q
#define RETURN_IF_STATUS(err) \
  case err:                   \
    return STRINGIZE(err);

#if GOOGLE_CUDA

    RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS)
    RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_INVALID_VALUE)
    RETURN_IF_STATUS(CUSPARSE_STATUS_ARCH_MISMATCH)
    RETURN_IF_STATUS(CUSPARSE_STATUS_MAPPING_ERROR)
    RETURN_IF_STATUS(CUSPARSE_STATUS_EXECUTION_FAILED)
    RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR)
    RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)

    default:
      return strings::StrCat("Unknown CUSPARSE error: ",
                             static_cast<int>(status));
#elif TENSORFLOW_USE_ROCM

    RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED)
    RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT)

    default:
      return strings::StrCat("Unknown hipSPARSE error: ",
                             static_cast<int>(status));
#endif

#undef RETURN_IF_STATUS
#undef STRINGIZE
  }
}

#if GOOGLE_CUDA

#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                 \
  do {                                                                     \
    auto status = (expr);                                                  \
    if (TF_PREDICT_FALSE(status != CUSPARSE_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
                              "): cuSparse call failed with status ",      \
                              ConvertGPUSparseErrorToString(status));      \
    }                                                                      \
  } while (0)

#elif TENSORFLOW_USE_ROCM

#define TF_RETURN_IF_GPUSPARSE_ERROR(expr)                                 \
  do {                                                                     \
    auto status = (expr);                                                  \
    if (TF_PREDICT_FALSE(status != HIPSPARSE_STATUS_SUCCESS)) {            \
      return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \
                              "): hipSPARSE call failed with status ",     \
                              ConvertGPUSparseErrorToString(status));      \
    }                                                                      \
  } while (0)

#endif
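// Example (illustrative sketch, CUDA path): TF_RETURN_IF_GPUSPARSE_ERROR is
// meant to wrap raw cuSPARSE/hipSPARSE calls inside functions that return
// Status. The descriptor-creating helper below is hypothetical, not part of
// this header.
//
//   Status CreateMatDescr(gpusparseMatDescr_t* descr) {
//     TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(descr));
//     return Status::OK();
//   }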

inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose,
                                                               bool conjugate,
                                                               Status* status) {
#if GOOGLE_CUDA
  if (transpose) {
    return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
                     : CUSPARSE_OPERATION_TRANSPOSE;
  } else {
    if (conjugate) {
      DCHECK(status != nullptr);
      *status = errors::InvalidArgument(
          "Conjugate == True and transpose == False is not supported.");
    }
    return CUSPARSE_OPERATION_NON_TRANSPOSE;
  }
#elif TENSORFLOW_USE_ROCM
  if (transpose) {
    return conjugate ? HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE
                     : HIPSPARSE_OPERATION_TRANSPOSE;
  } else {
    if (conjugate) {
      DCHECK(status != nullptr);
      *status = errors::InvalidArgument(
          "Conjugate == True and transpose == False is not supported.");
    }
    return HIPSPARSE_OPERATION_NON_TRANSPOSE;
  }
#endif
}
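// Example (illustrative sketch): mapping TensorFlow-style transpose/adjoint
// attributes to a gpusparseOperation_t before calling one of the GpuSparse
// wrappers below. `transpose_a` and `adjoint_a` are hypothetical op
// attributes.
//
//   Status op_status;
//   const gpusparseOperation_t transA =
//       TransposeAndConjugateToGpuSparseOp(transpose_a, adjoint_a, &op_status);
//   TF_RETURN_IF_ERROR(op_status);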

// The GpuSparse class provides a simplified templated API for cuSPARSE
// (http://docs.nvidia.com/cuda/cusparse/index.html).
// An object of this class wraps static cuSPARSE instances,
// and will launch CUDA kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Note: All the computational member functions are asynchronous and simply
// launch one or more CUDA kernels on the CUDA stream wrapped by the GpuSparse
// object.

class GpuSparse {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit GpuSparse(OpKernelContext* context);
  virtual ~GpuSparse() {}

  // Initializes the GpuSparse class if it hasn't been initialized yet.  All
  // subsequent public methods require that the class has been initialized.
  // Can be called multiple times; all calls after the first have no effect.
  Status Initialize();  // Move to constructor?
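  // Example (illustrative sketch): typical setup in an OpKernel's Compute()
  // before using any of the wrappers declared below.
  //
  //   GpuSparse cuda_sparse(context);
  //   OP_REQUIRES_OK(context, cuda_sparse.Initialize());
  //   // ... cuda_sparse.Gtsv2(...), cuda_sparse.Csrgeam(...), etc. ...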

  // ====================================================================
  // Wrappers for cuSPARSE start here.
  //

  // Solves a tridiagonal system of equations.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2
  template <typename Scalar>
  Status Gtsv2(int m, int n, const Scalar* dl, const Scalar* d,
               const Scalar* du, Scalar* B, int ldb, void* pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_bufferSize
  template <typename Scalar>
  Status Gtsv2BufferSizeExt(int m, int n, const Scalar* dl, const Scalar* d,
                            const Scalar* du, const Scalar* B, int ldb,
                            size_t* bufferSizeInBytes) const;
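  // Example (illustrative sketch): the usual two-step Gtsv2 pattern -- query
  // the workspace size, allocate a scratch buffer, then solve. The device
  // pointers dl/d/du/B are hypothetical, and the scratch buffer is allocated
  // here as a temporary DT_INT8 tensor.
  //
  //   size_t buffer_size = 0;
  //   TF_RETURN_IF_ERROR(cuda_sparse.Gtsv2BufferSizeExt(m, n, dl, d, du, B,
  //                                                     ldb, &buffer_size));
  //   Tensor temp;
  //   TF_RETURN_IF_ERROR(context->allocate_temp(
  //       DT_INT8, TensorShape({static_cast<int64>(buffer_size)}), &temp));
  //   TF_RETURN_IF_ERROR(cuda_sparse.Gtsv2(m, n, dl, d, du, B, ldb,
  //                                        temp.flat<int8>().data()));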

  // Solves a tridiagonal system of equations without partial pivoting.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot
  template <typename Scalar>
  Status Gtsv2NoPivot(int m, int n, const Scalar* dl, const Scalar* d,
                      const Scalar* du, Scalar* B, int ldb,
                      void* pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2NoPivot.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot_bufferSize
  template <typename Scalar>
  Status Gtsv2NoPivotBufferSizeExt(int m, int n, const Scalar* dl,
                                   const Scalar* d, const Scalar* du,
                                   const Scalar* B, int ldb,
                                   size_t* bufferSizeInBytes) const;

  // Solves a batch of tridiagonal systems of equations. Doesn't support
  // multiple right-hand sides per system. Doesn't do pivoting.
  // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch
  template <typename Scalar>
  Status Gtsv2StridedBatch(int m, const Scalar* dl, const Scalar* d,
                           const Scalar* du, Scalar* x, int batchCount,
                           int batchStride, void* pBuffer) const;

  // Computes the size of a temporary buffer used by Gtsv2StridedBatch.
  // See:
  // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch_bufferSize
  template <typename Scalar>
  Status Gtsv2StridedBatchBufferSizeExt(int m, const Scalar* dl,
                                        const Scalar* d, const Scalar* du,
                                        const Scalar* x, int batchCount,
                                        int batchStride,
                                        size_t* bufferSizeInBytes) const;
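  // Example (illustrative sketch): batched variant, where each of the
  // `batch_count` systems is laid out `batch_stride` elements apart in
  // dl/d/du/x. Workspace sizing mirrors the single-system case above; the
  // pointers and the scratch buffer are hypothetical.
  //
  //   size_t bytes = 0;
  //   TF_RETURN_IF_ERROR(cuda_sparse.Gtsv2StridedBatchBufferSizeExt(
  //       m, dl, d, du, x, batch_count, batch_stride, &bytes));
  //   // ... allocate `bytes` of device scratch as `buffer` ...
  //   TF_RETURN_IF_ERROR(cuda_sparse.Gtsv2StridedBatch(
  //       m, dl, d, du, x, batch_count, batch_stride, buffer));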

  // Uncompresses the row indices of a matrix: converts a CSR row-pointer
  // array into a COO row-index array. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csr2coo.
  Status Csr2coo(const int* CsrRowPtr, int nnz, int m, int* cooRowInd) const;

  // Compresses the row indices of a matrix: converts a COO row-index array
  // into a CSR row-pointer array. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
  Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;
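  // Example (illustrative sketch): round-tripping between the two row-index
  // representations. csr_row_ptr has m + 1 entries and coo_row_ind has nnz
  // entries; both are hypothetical device pointers.
  //
  //   TF_RETURN_IF_ERROR(cuda_sparse.Csr2coo(csr_row_ptr, nnz, m, coo_row_ind));
  //   TF_RETURN_IF_ERROR(cuda_sparse.Coo2csr(coo_row_ind, nnz, m, csr_row_ptr));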

#if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM
  // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
  // where A is a sparse matrix in CSR format, B and C are dense tall
  // matrices.  This routine allows transposition of matrix B, which
  // may improve performance.  See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmm2
  //
  // **NOTE** Matrices B and C are expected to be in column-major
  // order; to make them consistent with TensorFlow they
  // must be transposed (or the matmul op's pre/post-processing must take this
  // into account).
  //
  // **NOTE** This is an in-place operation for data in C.
  template <typename Scalar>
  Status Csrmm(gpusparseOperation_t transA, gpusparseOperation_t transB, int m,
               int n, int k, int nnz, const Scalar* alpha_host,
               const gpusparseMatDescr_t descrA, const Scalar* csrSortedValA,
               const int* csrSortedRowPtrA, const int* csrSortedColIndA,
               const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
               int ldc) const;
#else
  // Workspace size query for sparse-dense matrix multiplication. Helper
  // function for SpMM, which computes C = alpha * op(A) * op(B) + beta * C,
  // where A is a sparse matrix in CSR format, B and C are dense matrices in
  // column-major format. Returns the needed workspace size in bytes.
  template <typename Scalar>
  Status SpMMBufferSize(gpusparseOperation_t transA,
                        gpusparseOperation_t transB, const Scalar* alpha,
                        const gpusparseSpMatDescr_t matA,
                        const gpusparseDnMatDescr_t matB, const Scalar* beta,
                        gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
                        size_t* bufferSize) const;

  // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
  // where A is a sparse matrix in CSR format, B and C are dense matrices in
  // column-major format. The buffer is assumed to be at least as large as the
  // workspace size returned by SpMMBufferSize().
  //
  // **NOTE** This is an in-place operation for data in C.
  template <typename Scalar>
  Status SpMM(gpusparseOperation_t transA, gpusparseOperation_t transB,
              const Scalar* alpha, const gpusparseSpMatDescr_t matA,
              const gpusparseDnMatDescr_t matB, const Scalar* beta,
              gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg,
              int8* buffer) const;
#endif
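  // Example (illustrative sketch, CUDA >= 10.2 path): workspace query followed
  // by the actual SpMM call. matA/matB/matC are assumed to be descriptors
  // already created through the cuSPARSE generic API, alpha/beta are host
  // scalars, and `alg` is a gpusparseSpMMAlg_t chosen by the caller.
  //
  //   size_t buffer_size = 0;
  //   TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize(
  //       transA, transB, &alpha, matA, matB, &beta, matC, alg, &buffer_size));
  //   // ... allocate `buffer_size` bytes of device scratch as `buffer` ...
  //   TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB,
  //                                       &beta, matC, alg, buffer));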

  // Sparse-dense vector multiplication y = alpha * op(A) * x + beta * y,
  // where A is a sparse matrix in CSR format, x and y are dense vectors. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
  //
  // **NOTE** This is an in-place operation for data in y.
#if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM
  template <typename Scalar>
  Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
               const Scalar* alpha_host, const gpusparseMatDescr_t descrA,
               const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
               const int* csrSortedColIndA, const Scalar* x,
               const Scalar* beta_host, Scalar* y) const;
#else
  template <typename Scalar>
  Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz,
               const Scalar* alpha_host, const Scalar* csrSortedValA,
               const int* csrSortedRowPtrA, const int* csrSortedColIndA,
               const Scalar* x, const Scalar* beta_host, Scalar* y) const;
#endif  // CUDA_VERSION < 10020

  // Computes the workspace size for sparse-sparse matrix addition of matrices
  // stored in CSR format.
  template <typename Scalar>
  Status CsrgeamBufferSizeExt(
      int m, int n, const Scalar* alpha, const gpusparseMatDescr_t descrA,
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
      const int* csrSortedColIndA, const Scalar* beta,
      const gpusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
      const int* csrSortedRowPtrB, const int* csrSortedColIndB,
      const gpusparseMatDescr_t descrC, Scalar* csrSortedValC,
      int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize);

  // Computes sparse-sparse matrix addition of matrices
  // stored in CSR format.  This is part one: calculate nnz of the
  // output.  csrSortedRowPtrC must be preallocated on device with
  // m + 1 entries.  See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
  Status CsrgeamNnz(int m, int n, const gpusparseMatDescr_t descrA, int nnzA,
                    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr, void* workspace);

  // Computes sparse-sparse matrix addition of matrices
  // stored in CSR format.  This is part two: perform sparse-sparse
  // addition.  csrValC and csrColIndC must be allocated on the device
  // with nnzTotalDevHostPtr entries (as calculated by CsrgeamNnz).  See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
  template <typename Scalar>
  Status Csrgeam(int m, int n, const Scalar* alpha,
                 const gpusparseMatDescr_t descrA, int nnzA,
                 const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
                 const int* csrSortedColIndA, const Scalar* beta,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC, void* workspace);
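  // Example (illustrative sketch): adding two CSR matrices A and B into C.
  // The usual order is (1) CsrgeamBufferSizeExt to size the scratch space,
  // (2) CsrgeamNnz to compute nnz(C) and fill csrSortedRowPtrC, (3) allocate
  // the value/column-index arrays of C with nnz(C) entries and call Csrgeam.
  // All pointers and descriptors below are hypothetical.
  //
  //   size_t bytes = 0;
  //   TF_RETURN_IF_ERROR(cuda_sparse.CsrgeamBufferSizeExt(
  //       m, n, &alpha, descrA, nnzA, valA, rowPtrA, colIndA, &beta, descrB,
  //       nnzB, valB, rowPtrB, colIndB, descrC, valC, rowPtrC, colIndC,
  //       &bytes));
  //   // ... allocate `bytes` of device scratch as `workspace` ...
  //   int nnzC = 0;
  //   TF_RETURN_IF_ERROR(cuda_sparse.CsrgeamNnz(
  //       m, n, descrA, nnzA, rowPtrA, colIndA, descrB, nnzB, rowPtrB,
  //       colIndB, descrC, rowPtrC, &nnzC, workspace));
  //   // ... allocate valC / colIndC with nnzC entries ...
  //   TF_RETURN_IF_ERROR(cuda_sparse.Csrgeam(
  //       m, n, &alpha, descrA, nnzA, valA, rowPtrA, colIndA, &beta, descrB,
  //       nnzB, valB, rowPtrB, colIndB, descrC, valC, rowPtrC, colIndC,
  //       workspace));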

#if GOOGLE_CUDA && (CUDA_VERSION >= 10000)
  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format.  This is part zero: calculate required workspace
  // size.
  template <typename Scalar>
  Status CsrgemmBufferSize(
      int m, int n, int k, const gpusparseMatDescr_t descrA, int nnzA,
      const int* csrSortedRowPtrA, const int* csrSortedColIndA,
      const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB,
      const int* csrSortedColIndB, csrgemm2Info_t info, size_t* workspaceBytes);
#endif

  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format.  This is part one: calculate nnz of the
  // output.  csrSortedRowPtrC must be preallocated on device with
  // m + 1 entries.  See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
  Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB,
                    int m, int k, int n, const gpusparseMatDescr_t descrA,
                    int nnzA, const int* csrSortedRowPtrA,
                    const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr);
#else
  Status CsrgemmNnz(int m, int n, int k, const gpusparseMatDescr_t descrA,
                    int nnzA, const int* csrSortedRowPtrA,
                    const int* csrSortedColIndA,
                    const gpusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr, csrgemm2Info_t info,
                    void* workspace);
#endif

  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format.  This is part two: perform the sparse-sparse
  // multiplication.  csrValC and csrColIndC must be allocated on the device
  // with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz).  See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
#if (GOOGLE_CUDA && (CUDA_VERSION < 10000)) || TENSORFLOW_USE_ROCM
  template <typename Scalar>
  Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB,
                 int m, int k, int n, const gpusparseMatDescr_t descrA,
                 int nnzA, const Scalar* csrSortedValA,
                 const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC);
#else
  template <typename Scalar>
  Status Csrgemm(int m, int n, int k, const gpusparseMatDescr_t descrA,
                 int nnzA, const Scalar* csrSortedValA,
                 const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                 const gpusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const gpusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC, const csrgemm2Info_t info,
                 void* workspace);
#endif
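  // Example (illustrative sketch, CUDA >= 10.0 path): multiplying two CSR
  // matrices, C = A * B. A csrgemm2Info_t (created with
  // cusparseCreateCsrgemm2Info) is threaded through the buffer-size query,
  // the nnz computation, and the final multiply. All pointers below are
  // hypothetical.
  //
  //   csrgemm2Info_t info;
  //   TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsrgemm2Info(&info));
  //   size_t bytes = 0;
  //   TF_RETURN_IF_ERROR(cuda_sparse.CsrgemmBufferSize<float>(
  //       m, n, k, descrA, nnzA, rowPtrA, colIndA, descrB, nnzB, rowPtrB,
  //       colIndB, info, &bytes));
  //   // ... allocate `bytes` of device scratch as `workspace` ...
  //   int nnzC = 0;
  //   TF_RETURN_IF_ERROR(cuda_sparse.CsrgemmNnz(
  //       m, n, k, descrA, nnzA, rowPtrA, colIndA, descrB, nnzB, rowPtrB,
  //       colIndB, descrC, rowPtrC, &nnzC, info, workspace));
  //   // ... allocate valC / colIndC with nnzC entries ...
  //   TF_RETURN_IF_ERROR(cuda_sparse.Csrgemm(
  //       m, n, k, descrA, nnzA, valA, rowPtrA, colIndA, descrB, nnzB, valB,
  //       rowPtrB, colIndB, descrC, valC, rowPtrC, colIndC, info, workspace));
  //   TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyCsrgemm2Info(info));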

  // In-place reordering of unsorted CSR to sorted CSR.
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr
  template <typename Scalar>
  Status Csru2csr(int m, int n, int nnz, const gpusparseMatDescr_t descrA,
                  Scalar* csrVal, const int* csrRowPtr, int* csrColInd);

  // Converts from CSR to CSC format (equivalently, transpose).
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-csr2cscEx
  template <typename Scalar>
  Status Csr2csc(int m, int n, int nnz, const Scalar* csrVal,
                 const int* csrRowPtr, const int* csrColInd, Scalar* cscVal,
                 int* cscRowInd, int* cscColPtr,
                 const gpusparseAction_t copyValues);
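  // Example (illustrative sketch): transposing a CSR matrix by converting it
  // to CSC, copying the numerical values along with the structure. The
  // pointers below are hypothetical.
  //
  //   TF_RETURN_IF_ERROR(cuda_sparse.Csr2csc(
  //       m, n, nnz, csr_val, csr_row_ptr, csr_col_ind, csc_val, csc_row_ind,
  //       csc_col_ptr, GPUSPARSE(ACTION_NUMERIC)));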

 private:
  bool initialized_;
  OpKernelContext* context_;  // not owned.
  gpuStream_t gpu_stream_;
  gpusparseHandle_t* gpusparse_handle_;  // not owned.

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparse);
};

// A wrapper class to ensure that a CUDA sparse matrix descriptor is initialized
// only once. For more details on the descriptor (gpusparseMatDescr_t), see:
// https://docs.nvidia.com/cuda/cusparse/index.html#cusparsematdescrt
class GpuSparseMatrixDescriptor {
 public:
  explicit GpuSparseMatrixDescriptor() : initialized_(false) {}

  GpuSparseMatrixDescriptor(GpuSparseMatrixDescriptor&& rhs)
      : initialized_(rhs.initialized_), descr_(std::move(rhs.descr_)) {
    rhs.initialized_ = false;
  }

  GpuSparseMatrixDescriptor& operator=(GpuSparseMatrixDescriptor&& rhs) {
    if (this == &rhs) return *this;
    Release();
    initialized_ = rhs.initialized_;
    descr_ = std::move(rhs.descr_);
    rhs.initialized_ = false;
    return *this;
  }

  ~GpuSparseMatrixDescriptor() { Release(); }

  // Initializes the underlying descriptor.  Will fail if called more than
  // once.
  Status Initialize() {
    DCHECK(!initialized_);
#if GOOGLE_CUDA
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
#elif TENSORFLOW_USE_ROCM
    TF_RETURN_IF_GPUSPARSE_ERROR(wrap::hipsparseCreateMatDescr(&descr_));
#endif
    initialized_ = true;
    return Status::OK();
  }

  gpusparseMatDescr_t& descr() {
    DCHECK(initialized_);
    return descr_;
  }

  const gpusparseMatDescr_t& descr() const {
    DCHECK(initialized_);
    return descr_;
  }

 private:
  void Release() {
    if (initialized_) {
#if GOOGLE_CUDA
      cusparseDestroyMatDescr(descr_);
#elif TENSORFLOW_USE_ROCM
      wrap::hipsparseDestroyMatDescr(descr_);
#endif
      initialized_ = false;
    }
  }

  bool initialized_;
  gpusparseMatDescr_t descr_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseMatrixDescriptor);
};
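// Example (illustrative sketch): creating a matrix descriptor once and passing
// it to the CSR wrappers above.
//
//   GpuSparseMatrixDescriptor descrA;
//   TF_RETURN_IF_ERROR(descrA.Initialize());
//   // ... cuda_sparse.CsrgeamNnz(m, n, descrA.descr(), ...) ...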

#if GOOGLE_CUDA

// A wrapper class to ensure that an unsorted/sorted CSR conversion information
// struct (csru2csrInfo_t) is initialized only once. See:
// https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
class GpuSparseCsrSortingConversionInfo {
 public:
  explicit GpuSparseCsrSortingConversionInfo() : initialized_(false) {}

  GpuSparseCsrSortingConversionInfo(GpuSparseCsrSortingConversionInfo&& rhs)
      : initialized_(rhs.initialized_), info_(std::move(rhs.info_)) {
    rhs.initialized_ = false;
  }

  GpuSparseCsrSortingConversionInfo& operator=(
      GpuSparseCsrSortingConversionInfo&& rhs) {
    if (this == &rhs) return *this;
    Release();
    initialized_ = rhs.initialized_;
    info_ = std::move(rhs.info_);
    rhs.initialized_ = false;
    return *this;
  }

  ~GpuSparseCsrSortingConversionInfo() { Release(); }

  // Initializes the underlying info struct. Will fail if called more than
  // once.
  Status Initialize() {
    DCHECK(!initialized_);
    TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_));
    initialized_ = true;
    return Status::OK();
  }

  csru2csrInfo_t& info() {
    DCHECK(initialized_);
    return info_;
  }

  const csru2csrInfo_t& info() const {
    DCHECK(initialized_);
    return info_;
  }

 private:
  void Release() {
    if (initialized_) {
      cusparseDestroyCsru2csrInfo(info_);
      initialized_ = false;
    }
  }

  bool initialized_;
  csru2csrInfo_t info_;

  TF_DISALLOW_COPY_AND_ASSIGN(GpuSparseCsrSortingConversionInfo);
};
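// Example (illustrative sketch): RAII management of the csru2csrInfo_t used by
// the unsorted-to-sorted CSR conversion. The underlying cuSPARSE info object
// is destroyed automatically when `info` goes out of scope.
//
//   GpuSparseCsrSortingConversionInfo info;
//   TF_RETURN_IF_ERROR(info.Initialize());
//   // info.info() can now be passed to the csru2csr sorting routines.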

#endif  // GOOGLE_CUDA

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_