1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_SPARSE_SPARSE_MATRIX_H_
17 #define TENSORFLOW_CORE_KERNELS_SPARSE_SPARSE_MATRIX_H_
18 
19 #define EIGEN_USE_THREADS
20 
21 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
22 #define EIGEN_USE_GPU
23 #endif
24 
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
32 
33 namespace tensorflow {
34 
35 class CSRSparseMatrix {
36   // CreateCSRSparseMatrix is the main method used to construct a
37   // CSRSparseMatrix.  The representations for both 2D and 3D
38   // (batched) CSR Sparse Matrices are the same:
39   //
40   // dtype: The datatype of the values.
41   // dense_shape: The dense shape of the matrix.
42   //   * Host int64 vector, size 2 or 3.
43   //   * Takes on values: (rows, cols) or (batch_size, rows, cols).
44   // batch_pointers: Batch offset pointers into col_indices and values.
45   //   * Host int32 vector, size (batch_size + 1).
46   //   * Takes on values: (0, nnz[0], nnz[0] + nnz[1], ..., total_nnz).
47   // row_pointers: Row offset pointers into col_indices and values.
48   //   * Device int32 vector, size ((rows + 1) * batch_size).
49   //   * Each block of size (rows + 1) takes on values:
50   //     (0, num_rows{b}[0], num_rows{b}[0] + num_rows{b}[1], ..., nnz[b]).
51   //     for b = 0 .. batch_size - 1.
52   // col_indices: Column values for the given row and column index.
53   //   * Device int32 vector, size total_nnz.
54   // values: Actual values for the given row and column index.
55   //   * Device dtype vector, size total_nnz.
56   //
57   // The storage agreement is such that for a given (batch, row, ix):
58   //   offset = batch_pointers(batch) + row_pointers(batch * (rows + 1) + row)
59   //   col = col_indices(offset + ix)
60   //   val = values(offset + ix)
61   // where ix < #nnz columns in (batch, row).
62   // Then:
63   //   matrix(batch, row, col) = val.
64   //
65   // All other elements in the dense representation are treated as 0 / empty.
66   //
67   // For example, for a 2D sparse matrix m shaped (3, 4) such that:
68   //
69   //   m[0, 0] = 1.0
70   //   m[0, 1] = 2.0
71   //   m[0, 2] = 3.0
72   //   m[2, 2] = 4.0
73   //   m[2, 3] = 5.0
74   //
75   // The corresponding representation is:
76   //
77   //   dtype: DT_FLOAT
78   //   dense_shape: (3, 4)
79   //   batch_pointers: (0, 5)
80   //   row_pointers: (0, 3, 3, 5)
81   //   col_indices: concat((0, 1, 2), (), (2, 3))
82   //   values: concat((1.0, 2.0, 3.0), (), (4.0, 5.0))
83   //
84   // For a 3D sparse matrix m shaped (2, 3, 4) such that:
85   //
86   //   m[0, 0, 0] = 1.0
87   //   m[0, 0, 2] = 2.0
88   //   m[0, 2, 3] = 3.0
89   //   m[1, 0, 3] = 4.0
90   //   m[1, 1, 0] = 5.0
91   //
92   // The corresponding representation is:
93   //   dtype: DT_FLOAT
94   //   dense_shape: (2, 3, 4)
95   //   batch_pointers: (0, 3, 5)
96   //   row_pointers: concat((0, 2, 2, 3), (0, 1, 2, 2))
97   //   col_indices: concat(concat((0, 2), (), (3,)),
98   //                       concat((3,),   (), (0,)))
99   //   values: concat(concat((1.0, 2.0), (3.0,), ()),
100   ///                 concat((4.0,),     (5.0,), ()))
101   //
102  public:
103   static constexpr const char kTypeName[] = "tensorflow::CSRSparseMatrix";
104 
CSRSparseMatrix()105   CSRSparseMatrix() : metadata_{false, DT_INVALID} {}
106 
CSRSparseMatrix(const CSRSparseMatrix & rhs)107   CSRSparseMatrix(const CSRSparseMatrix& rhs)
108       : metadata_(rhs.metadata_),
109         dense_shape_(rhs.dense_shape_),
110         batch_pointers_(rhs.batch_pointers_),
111         row_pointers_(rhs.row_pointers_),
112         col_indices_(rhs.col_indices_),
113         values_(rhs.values_) {
114     SetupVecs();
115   }
116 
CSRSparseMatrix(CSRSparseMatrix && rhs)117   CSRSparseMatrix(CSRSparseMatrix&& rhs)
118       : metadata_(rhs.metadata_),
119         dense_shape_(std::move(rhs.dense_shape_)),
120         batch_pointers_(std::move(rhs.batch_pointers_)),
121         row_pointers_(std::move(rhs.row_pointers_)),
122         col_indices_(std::move(rhs.col_indices_)),
123         values_(std::move(rhs.values_)) {
124     SetupVecs();
125     rhs.metadata_.validated = false;
126     rhs.metadata_.dtype = DT_INVALID;
127     rhs.ClearVecs();
128   }
129 
130   CSRSparseMatrix& operator=(CSRSparseMatrix&& rhs) {
131     if (this == &rhs) return *this;
132     metadata_ = rhs.metadata_;
133     metadata_.validated = rhs.metadata_.validated;
134     dense_shape_ = std::move(rhs.dense_shape_);
135     batch_pointers_ = std::move(rhs.batch_pointers_);
136     row_pointers_ = std::move(rhs.row_pointers_);
137     col_indices_ = std::move(rhs.col_indices_);
138     values_ = std::move(rhs.values_);
139     SetupVecs();
140     rhs.metadata_ = {false, DT_INVALID};
141     rhs.ClearVecs();
142     return *this;
143   }
144 
CreateCSRSparseMatrix(DataType dtype,const Tensor & dense_shape,const Tensor & batch_pointers,const Tensor & row_pointers,const Tensor & col_indices,const Tensor & values,CSRSparseMatrix * matrix)145   static Status CreateCSRSparseMatrix(DataType dtype,
146                                       const Tensor& dense_shape,     // on host
147                                       const Tensor& batch_pointers,  // on host
148                                       const Tensor& row_pointers,
149                                       const Tensor& col_indices,
150                                       const Tensor& values,
151                                       CSRSparseMatrix* matrix) {
152     *matrix = CSRSparseMatrix(dtype, dense_shape, batch_pointers, row_pointers,
153                               col_indices, values);
154     Status s = matrix->Validate();
155     matrix->metadata_.validated = s.ok();
156     matrix->SetupVecs();
157     return s;
158   }
159 
Validate()160   Status Validate() const {
161     return ValidateTypesAndShapes(metadata_.dtype, dense_shape_,
162                                   batch_pointers_, row_pointers_, col_indices_,
163                                   values_);
164   }
165 
Clear()166   void Clear() {
167     metadata_ = {false, DT_INVALID};
168     dense_shape_ = Tensor();
169     batch_pointers_ = Tensor();
170     row_pointers_ = Tensor();
171     col_indices_ = Tensor();
172     values_ = Tensor();
173     ClearVecs();
174   }
175 
valid()176   bool valid() const {
177     return metadata_.validated && dense_shape_.IsInitialized() &&
178            batch_pointers_.IsInitialized() && row_pointers_.IsInitialized() &&
179            col_indices_.IsInitialized() && values_.IsInitialized() &&
180            dense_shape_.NumElements() > 1 &&
181            batch_pointers_.NumElements() > 0 && row_pointers_.NumElements() > 0;
182   }
183 
dtype()184   DataType dtype() const {
185     DCHECK(valid());
186     return metadata_.dtype;
187   }
188 
dims()189   inline int dims() const {
190     DCHECK(valid());
191     return dense_shape_.NumElements();
192   }
193 
nnz(int batch)194   inline int nnz(int batch) const {
195     DCHECK_LT(batch, batch_size());
196     return (*batch_pointers_vec_)(batch + 1) - (*batch_pointers_vec_)(batch);
197   }
198 
batch_offset(int batch)199   inline int batch_offset(int batch) const {
200     DCHECK_LT(batch, batch_size());
201     return (*batch_pointers_vec_)(batch);
202   }
203 
total_nnz()204   inline int total_nnz() const {
205     DCHECK(valid());
206     return (*batch_pointers_vec_)(batch_size());
207   }
208 
dense_shape()209   inline Tensor& dense_shape() {
210     DCHECK(valid());
211     return dense_shape_;
212   }
213 
dense_shape()214   inline const Tensor& dense_shape() const {
215     DCHECK(valid());
216     return dense_shape_;
217   }
218 
row_pointers_vec(int batch)219   inline TTypes<int32>::UnalignedVec row_pointers_vec(int batch) {
220     DCHECK(valid());
221     DCHECK_LT(batch, batch_size());
222     const int64 rows = dense_shape().vec<int64>()((dims() == 2) ? 0 : 1);
223     const int offset = batch * (rows + 1);
224     return TTypes<int32>::UnalignedVec(row_pointers_vec_->data() + offset,
225                                        rows + 1);
226   }
227 
row_pointers_vec(int batch)228   inline TTypes<int32>::UnalignedConstVec row_pointers_vec(int batch) const {
229     DCHECK(valid());
230     DCHECK_LT(batch, batch_size());
231     const int64 rows = dense_shape().vec<int64>()((dims() == 2) ? 0 : 1);
232     const int offset = batch * (rows + 1);
233     return TTypes<int32>::UnalignedConstVec(row_pointers_vec_->data() + offset,
234                                             rows + 1);
235   }
236 
col_indices_vec(int batch)237   inline TTypes<int32>::UnalignedVec col_indices_vec(int batch) {
238     DCHECK(valid());
239     DCHECK_LT(batch, batch_size());
240     const int offset = (*batch_pointers_vec_)(batch);
241     const int nnz_in_batch = nnz(batch);
242     return TTypes<int32>::UnalignedVec(col_indices_vec_->data() + offset,
243                                        nnz_in_batch);
244   }
245 
col_indices_vec(int batch)246   inline TTypes<int32>::UnalignedConstVec col_indices_vec(int batch) const {
247     DCHECK(valid());
248     DCHECK_LT(batch, batch_size());
249     const int offset = (*batch_pointers_vec_)(batch);
250     const int nnz_in_batch = nnz(batch);
251     return TTypes<int32>::UnalignedConstVec(col_indices_vec_->data() + offset,
252                                             nnz_in_batch);
253   }
254 
255   template <typename T>
values_vec(int batch)256   inline typename TTypes<T>::UnalignedVec values_vec(int batch) {
257     DCHECK(valid());
258     DCHECK_LT(batch, batch_size());
259     const int offset = (*batch_pointers_vec_)(batch);
260     const int nnz_in_batch = nnz(batch);
261     return typename TTypes<T>::UnalignedVec(&(values().vec<T>()(offset)),
262                                             nnz_in_batch);
263   }
264 
265   template <typename T>
values_vec(int batch)266   inline typename TTypes<T>::UnalignedConstVec values_vec(int batch) const {
267     DCHECK(valid());
268     DCHECK_LT(batch, batch_size());
269     const int offset = (*batch_pointers_vec_)(batch);
270     const int nnz_in_batch = nnz(batch);
271     return typename TTypes<T>::UnalignedConstVec(&(values().vec<T>()(offset)),
272                                                  nnz_in_batch);
273   }
274 
row_pointers()275   inline Tensor& row_pointers() {
276     DCHECK(valid());
277     return row_pointers_;
278   }
279 
row_pointers()280   inline const Tensor& row_pointers() const {
281     DCHECK(valid());
282     return row_pointers_;
283   }
284 
col_indices()285   inline Tensor& col_indices() {
286     DCHECK(valid());
287     return col_indices_;
288   }
289 
col_indices()290   inline const Tensor& col_indices() const {
291     DCHECK(valid());
292     return col_indices_;
293   }
294 
values()295   inline Tensor& values() {
296     DCHECK(valid());
297     return values_;
298   }
299 
values()300   inline const Tensor& values() const {
301     DCHECK(valid());
302     return values_;
303   }
304 
batch_pointers()305   inline Tensor& batch_pointers() {
306     DCHECK(valid());
307     return batch_pointers_;
308   }
309 
batch_pointers()310   inline const Tensor& batch_pointers() const {
311     DCHECK(valid());
312     return batch_pointers_;
313   }
314 
TypeName()315   std::string TypeName() const { return kTypeName; }
316 
317   // TODO(ebrevdo): A better debug string.
DebugString()318   std::string DebugString() const { return dense_shape_.DebugString(); }
319 
320   // Returns the number of elements.  This is equal to 1 if the
321   // CSRSparseMatrix is a singleton matrix (dense_shape is length 2).
batch_size()322   int batch_size() const {
323     DCHECK(valid());
324     return batch_pointers_.NumElements() - 1;
325   }
326 
Decode(const VariantTensorData & p)327   bool Decode(const VariantTensorData& p) {
328     if (p.tensors_.empty()) return false;
329     Metadata metadata;
330     if (!p.get_metadata(&metadata)) return false;
331     const bool validated = metadata.validated;
332     const DataType dtype = metadata.dtype;
333 
334     // p.tensors_ should contain tensors {dense_shape, batch_pointers,
335     // row_pointers, col_indices, values}.
336     if (p.tensors_.size() != 5) return false;
337 
338     Tensor dense_shape = p.tensors_[0];
339     if (dense_shape.dtype() != DT_INT64) return false;
340     if (dense_shape.dims() != 1) return false;
341     int rank = dense_shape.dim_size(0);
342     if (rank < 2 || rank > 3) return false;
343 
344     Tensor batch_pointers(p.tensors_[1]);
345     Tensor row_pointers(p.tensors_[2]);
346     Tensor col_indices(p.tensors_[3]);
347     Tensor values(p.tensors_[4]);
348 
349     // Check that the validated bool is consistent with the data.
350     Status s = ValidateTypesAndShapes(dtype, dense_shape, batch_pointers,
351                                       row_pointers, col_indices, values);
352     if (s.ok() != validated) return false;
353 
354     // Save to this object.
355     metadata_ = metadata;
356     dense_shape_ = std::move(dense_shape);
357     batch_pointers_ = std::move(batch_pointers);
358     row_pointers_ = std::move(row_pointers);
359     col_indices_ = std::move(col_indices);
360     values_ = std::move(values);
361     SetupVecs();
362     return true;
363   }
364 
Encode(VariantTensorData * p)365   void Encode(VariantTensorData* p) const {
366     DCHECK(valid());
367 
368     // Store metadata_ to p's metadata
369     p->set_metadata(metadata_);
370 
371     // Store dense_shape, row_pointers, col_indices, and values to p->tensors_.
372     p->tensors_.reserve(5);
373     p->tensors_.push_back(dense_shape_);
374     p->tensors_.push_back(batch_pointers_);
375     p->tensors_.push_back(row_pointers_);
376     p->tensors_.push_back(col_indices_);
377     p->tensors_.push_back(values_);
378   }
379 
380   // This static method copies CSRSparseMatrices in all directions:
381   //   Host->Device, Device->Host, and Device->Device.
DeviceCopy(const CSRSparseMatrix & from,CSRSparseMatrix * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy)382   static Status DeviceCopy(
383       const CSRSparseMatrix& from, CSRSparseMatrix* to,
384       const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
385     VLOG(2) << "DeviceCopy from type: " << DataTypeString(from.dtype())
386             << " and shape: " << from.dense_shape().DebugString();
387     Tensor to_row_ptr(DT_INT32);
388     Tensor to_col_ind(DT_INT32);
389     Tensor to_values(from.dtype());
390     TF_RETURN_IF_ERROR(copy(from.row_pointers(), &to_row_ptr));
391     TF_RETURN_IF_ERROR(copy(from.col_indices(), &to_col_ind));
392     TF_RETURN_IF_ERROR(copy(from.values(), &to_values));
393     return CreateCSRSparseMatrix(from.dtype(),
394                                  from.dense_shape(),     // Always on host.
395                                  from.batch_pointers(),  // Always on host.
396                                  to_row_ptr, to_col_ind, to_values, to);
397   }
398 
399  private:
CSRSparseMatrix(DataType dtype,const Tensor & dense_shape,const Tensor & batch_pointers,const Tensor & row_pointers,const Tensor & col_indices,const Tensor & values)400   CSRSparseMatrix(DataType dtype, const Tensor& dense_shape,
401                   const Tensor& batch_pointers, const Tensor& row_pointers,
402                   const Tensor& col_indices, const Tensor& values)
403       : metadata_{false, dtype},
404         dense_shape_(dense_shape),
405         batch_pointers_(batch_pointers),
406         row_pointers_(row_pointers),
407         col_indices_(col_indices),
408         values_(values) {}
409 
SetupVecs()410   void SetupVecs() {
411     if (!metadata_.validated) return;
412     batch_pointers_vec_.reset(
413         new TTypes<int32>::Vec(batch_pointers_.vec<int32>()));
414     row_pointers_vec_.reset(new TTypes<int32>::Vec(row_pointers_.vec<int32>()));
415     col_indices_vec_.reset(new TTypes<int32>::Vec(col_indices_.vec<int32>()));
416   }
417 
ClearVecs()418   void ClearVecs() {
419     batch_pointers_vec_.reset();
420     row_pointers_vec_.reset();
421     col_indices_vec_.reset();
422   }
423 
ValidateTypesAndShapes(DataType dtype,const Tensor & dense_shape,const Tensor & batch_pointers,const Tensor & row_pointers,const Tensor & col_indices,const Tensor & values)424   static Status ValidateTypesAndShapes(DataType dtype,
425                                        const Tensor& dense_shape,
426                                        const Tensor& batch_pointers,
427                                        const Tensor& row_pointers,
428                                        const Tensor& col_indices,
429                                        const Tensor& values) {
430     // TODO(ebrevdo): Consider adding support for other floating point types
431     // (namely, float16).
432     if (dtype != DT_FLOAT && dtype != DT_DOUBLE && dtype != DT_COMPLEX64 &&
433         dtype != DT_COMPLEX128) {
434       return errors::InvalidArgument(
435           "CSRSparseMatrix::Validate: dtype = ", DataTypeString(dtype),
436           " not in {float32, float64, complex64, complex128}");
437     }
438     // dense_shape checks
439     if (dense_shape.dtype() != DT_INT64) {
440       return errors::InvalidArgument(
441           "CSRSparseMatrix::Validate: dense_shape.dtype() = ",
442           DataTypeString(dense_shape.dtype()), " != int64");
443     }
444     if (dense_shape.dims() != 1) {
445       return errors::InvalidArgument(
446           "CSRSparseMatrix::Validate: dense_shape should be a vector, but saw "
447           "tensor: ",
448           dense_shape.DebugString());
449     }
450     int rank = dense_shape.dim_size(0);
451     if (rank < 2 || rank > 3) {
452       return errors::InvalidArgument(
453           "CSRSparseMatrix::Validate: dense_shape should be a 2- or 3- vector, "
454           "but saw: ",
455           dense_shape.SummarizeValue(5));
456     }
457     auto dense_shape_t = dense_shape.vec<int64>();
458     const int64 batch_size = (rank == 2) ? 1 : dense_shape_t(0);
459     const int64 num_rows = (rank == 2) ? dense_shape_t(0) : dense_shape_t(1);
460 
461     if (batch_pointers.dtype() != DT_INT32) {
462       return errors::InvalidArgument(
463           "CSRSparseMatrix::Validate: batch_pointers.dtype() = ",
464           DataTypeString(batch_pointers.dtype()), " != int32");
465     }
466     if (batch_pointers.dims() != 1) {
467       return errors::InvalidArgument(
468           "CSRSparseMatrix::Validate: batch_indices is not a vector, saw "
469           "shape: ",
470           batch_pointers.shape().DebugString());
471     }
472 
473     // batch size checks
474     if (batch_size != batch_pointers.NumElements() - 1) {
475       return errors::InvalidArgument(
476           "CSRSparseMatrix::Validate: dense_shape is ",
477           dense_shape.SummarizeValue(5),
478           " but batch pointers implies batch size is ",
479           batch_pointers.NumElements() - 1);
480     }
481 
482     if (row_pointers.dtype() != DT_INT32) {
483       return errors::InvalidArgument(
484           "CSRSparseMatrix::Validate: row_pointers.dtype() = ",
485           DataTypeString(row_pointers.dtype()), " != int32");
486     }
487     if (row_pointers.dims() != 1) {
488       return errors::InvalidArgument(
489           "CSRSparseMatrix::Validate: row_pointers is not a vector, saw "
490           "shape: ",
491           row_pointers.shape().DebugString());
492     }
493     if (row_pointers.dim_size(0) != batch_size * (num_rows + 1)) {
494       return errors::InvalidArgument(
495           "CSRSparseMatrix::Validate: row_pointers should have size batch_size "
496           "* (num_rows + 1), saw shapes: ",
497           dense_shape.DebugString(), " vs. ",
498           row_pointers.shape().DebugString());
499     }
500     if (col_indices.dtype() != DT_INT32) {
501       return errors::InvalidArgument(
502           "CSRSparseMatrix::Validate: col_indices.dtype() = ",
503           DataTypeString(col_indices.dtype()), " != int32");
504     }
505     if (col_indices.dims() != 1) {
506       return errors::InvalidArgument(
507           "CSRSparseMatrix::Validate: col_indices is not a vector, saw shape: ",
508           col_indices.shape().DebugString());
509     }
510     if (values.dtype() != dtype) {
511       return errors::InvalidArgument(
512           "CSRSparseMatrix::Validate: values.dtype() = ",
513           DataTypeString(values.dtype()),
514           " != dtype = ", DataTypeString(dtype));
515     }
516     if (values.dims() != 1) {
517       return errors::InvalidArgument(
518           "CSRSparseMatrix::Validate: values is not a vector, saw shape: ",
519           values.shape().DebugString());
520     }
521     if (col_indices.dim_size(0) != values.dim_size(0)) {
522       return errors::InvalidArgument(
523           "CSRSparseMatrix::Validate: size(col_indices) = ",
524           col_indices.dim_size(0), " != size(values) = ", values.dim_size(0));
525     }
526     return Status::OK();
527   }
528 
529   struct Metadata {
530     bool validated;
531     DataType dtype;
532   };
533   Metadata metadata_;
534   Tensor dense_shape_;
535   Tensor batch_pointers_;
536   Tensor row_pointers_;
537   Tensor col_indices_;
538   Tensor values_;
539   std::unique_ptr<TTypes<int32>::Vec> batch_pointers_vec_;
540   std::unique_ptr<TTypes<int32>::Vec> row_pointers_vec_;
541   std::unique_ptr<TTypes<int32>::Vec> col_indices_vec_;
542 };
543 
544 // Call BinaryFunctor<Device, T>()(ctx, a, b, c)
545 // where T depends on a.dtype().  T will be one of: float, double,
546 // complex64, complex128.
547 template <typename Device, template <typename, typename> class BinaryFunctor>
CSRSparseMatrixBinaryHelper(OpKernelContext * ctx,const CSRSparseMatrix & a,const CSRSparseMatrix & b,CSRSparseMatrix * c)548 Status CSRSparseMatrixBinaryHelper(OpKernelContext* ctx,
549                                    const CSRSparseMatrix& a,
550                                    const CSRSparseMatrix& b,
551                                    CSRSparseMatrix* c) {
552   DataType dt = a.dtype();
553   if (dt != b.dtype()) {
554     return errors::InvalidArgument(
555         "CSRSparseMatrixBinaryHelper: Inconsistent dtypes for input matrices, "
556         "a "
557         "dtype: ",
558         DataTypeString(dt), ", b dtype: ", DataTypeString(b.dtype()));
559   }
560   switch (dt) {
561     case DT_FLOAT: {
562       BinaryFunctor<Device, float> functor(ctx);
563       return functor(a, b, c);
564     }
565     case DT_DOUBLE: {
566       BinaryFunctor<Device, double> functor(ctx);
567       return functor(a, b, c);
568     }
569     case DT_COMPLEX64: {
570       BinaryFunctor<Device, complex64> functor(ctx);
571       return functor(a, b, c);
572     }
573     case DT_COMPLEX128: {
574       BinaryFunctor<Device, complex128> functor(ctx);
575       return functor(a, b, c);
576     }
577     default:
578       return errors::InvalidArgument(
579           "CSRSparseMatrixBinaryHelper: a.dtype (", DataTypeString(dt),
580           ") is not one of: float, double, complex64, complex128");
581   }
582 }
583 
584 // Call UnaryFunctor<Device, T>()(ctx, a, b)
585 // where T depends on a.dtype().  T will be one of: float, double,
586 // complex64, complex128.
587 template <typename Device, template <typename, typename> class UnaryFunctor>
CSRSparseMatrixUnaryHelper(OpKernelContext * ctx,const CSRSparseMatrix & a,CSRSparseMatrix * b)588 Status CSRSparseMatrixUnaryHelper(OpKernelContext* ctx,
589                                   const CSRSparseMatrix& a,
590                                   CSRSparseMatrix* b) {
591   DataType dt = a.dtype();
592   switch (dt) {
593     case DT_FLOAT: {
594       UnaryFunctor<Device, float> functor(ctx);
595       return functor(a, b);
596     }
597     case DT_DOUBLE: {
598       UnaryFunctor<Device, double> functor(ctx);
599       return functor(a, b);
600     }
601     case DT_COMPLEX64: {
602       UnaryFunctor<Device, complex64> functor(ctx);
603       return functor(a, b);
604     }
605     case DT_COMPLEX128: {
606       UnaryFunctor<Device, complex128> functor(ctx);
607       return functor(a, b);
608     }
609     default:
610       return errors::InvalidArgument(
611           "CSRSparseMatrixUnaryHelper: a.dtype (", DataTypeString(dt),
612           ") is not one of: float, double, complex64, complex128");
613   }
614 }
615 
// Read-only bundle of Eigen views over the components of a CSR matrix whose
// values have type T.
// NOTE(review): the field types match what CSRSparseMatrix's const
// row_pointers_vec / col_indices_vec / values_vec return for a single batch,
// so this presumably describes one batch of a CSRSparseMatrix — confirm at
// call sites.
template <typename T>
struct ConstCSRComponent {
  TTypes<int32>::UnalignedConstVec row_ptr;  // int32 row offsets.
  TTypes<int32>::UnalignedConstVec col_ind;  // int32 column indices.
  typename TTypes<T>::UnalignedConstVec values;  // Stored values of type T.
  // int64 dense shape; the name suggests it lives in host memory.
  TTypes<int64>::ConstVec dense_shape_host;
};
623 
// Mutable bundle of Eigen views over the components of a CSR matrix whose
// values have type T.
// NOTE(review): the field types match what CSRSparseMatrix's non-const
// row_pointers_vec / col_indices_vec / values_vec return for a single batch,
// so this presumably describes one batch of a CSRSparseMatrix — confirm at
// call sites.
template <typename T>
struct CSRComponent {
  TTypes<int32>::UnalignedVec row_ptr;  // int32 row offsets.
  TTypes<int32>::UnalignedVec col_ind;  // int32 column indices.
  typename TTypes<T>::UnalignedVec values;  // Stored values of type T.
  // int64 dense shape; the name suggests it lives in host memory.
  TTypes<int64>::Vec dense_shape_host;
};
631 
632 template <typename T>
ExtractVariantFromInput(OpKernelContext * ctx,int index,const T ** value)633 Status ExtractVariantFromInput(OpKernelContext* ctx, int index,
634                                const T** value) {
635   const Tensor& input_t = ctx->input(index);
636   const Variant& input_variant = input_t.scalar<Variant>()();
637   *value = input_variant.get<T>();
638   if (*value == nullptr) {
639     return errors::InvalidArgument("Could not retrieve Variant input ", index);
640   }
641   if (!(*value)->valid()) {
642     return errors::InvalidArgument("Variant input ", index, " is not valid.");
643   }
644   return Status::OK();
645 }
646 
647 }  // namespace tensorflow
648 
649 #endif  // TENSORFLOW_CORE_KERNELS_SPARSE_SPARSE_MATRIX_H_
650