/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_
#define TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace functor {

// Calculates the number of nonzero entries per batch of a sorted rank-3
// SparseTensor's indices.  indices is expected to have columns
// corresponding to [batch, row, column], where indices[:, 0] < B.
//
// REQUIRES:
//  indices.dimension(1) == 3
//  nnz_per_batch.dimension(0) == B
template <typename Device>
struct CalculateNNZPerBatchMatrixFromIndices {
  Status operator()(OpKernelContext* c, TTypes<int64>::ConstMatrix indices,
                    TTypes<int32>::Vec nnz_per_batch);
};
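
// The following is an illustrative, single-threaded sketch of the counting
// step described above.  The function name and raw-pointer interface are
// hypothetical and not part of this header's API: `indices` is a row-major
// buffer with 3 columns and `nnz_per_batch` has B entries.
inline void CalculateNNZPerBatchSketch(const int64* indices, int64 nnz,
                                       int64 batch_size,
                                       int32* nnz_per_batch) {
  for (int64 b = 0; b < batch_size; ++b) nnz_per_batch[b] = 0;
  for (int64 i = 0; i < nnz; ++i) {
    // indices[i, 0] holds the batch id of the i-th nonzero.
    ++nnz_per_batch[indices[i * 3]];
  }
}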

// Split a subset of a SparseTensor's indices into two vectors:
// COO row inds and COO col inds.  Outputs are:
//
//   coo_row_ind = indices[:, row_dim]
//   coo_col_ind = indices[:, row_dim + 1]
//
// where n = coo_row_ind.size()
// and row_dim = #cols(indices) - 2 (the row column; the last column holds
// the column index).
//
// REQUIRES:
//   host_dense_shape.size() in [2, 3]
//   indices.dim_size(1) == host_dense_shape.size()
//   coo_row_ind.size() == coo_col_ind.size()
//   coo_row_ind.size() == indices.dim_size(0)
template <typename Device>
struct SparseTensorToCOOSparseMatrix {
  void operator()(const Device& d, TTypes<int64>::ConstVec host_dense_shape,
                  TTypes<int64>::ConstMatrix indices,
                  TTypes<int32>::Vec coo_row_ind,
                  TTypes<int32>::Vec coo_col_ind);
};
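
// Illustrative sketch (hypothetical name, not this header's API) of the split
// described above for a row-major `indices` buffer with `ndims` columns,
// where ndims is 2 or 3.
inline void SparseIndicesToCOOSketch(const int64* indices, int64 nnz,
                                     int ndims, int32* coo_row_ind,
                                     int32* coo_col_ind) {
  const int row_dim = ndims - 2;
  for (int64 i = 0; i < nnz; ++i) {
    coo_row_ind[i] = static_cast<int32>(indices[i * ndims + row_dim]);
    coo_col_ind[i] = static_cast<int32>(indices[i * ndims + row_dim + 1]);
  }
}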

// Write coo batch, row, and column vectors to output matrix indices:
//
//   indices[:, row_dim] = coo_row_ind
//   indices[:, col_dim] = coo_col_ind
//
// where row_dim = #cols(indices) - 2, col_dim = #cols(indices) - 1, and
// n = coo_row_ind.size().
// In addition, if #cols(indices) == 3, also store the batch:
//
//   indices[i, 0] = batch_of(i) where
//      host_batch_ptrs(batch_of(i)) <= i < host_batch_ptrs(batch_of(i) + 1)
//
// REQUIRES:
//
//   host_dense_shape.size() in [2, 3]
//   indices.dim_size(1) == host_dense_shape.size()
//   host_batch_ptrs.size() == B + 1, where B = host_dense_shape(0) if
//     rank == 3, else 1
//   coo_row_ind.size() == coo_col_ind.size()
//
template <typename Device>
struct COOSparseMatrixToSparseTensor {
  Status operator()(OpKernelContext* c,
                    TTypes<int64>::ConstVec host_dense_shape,
                    TTypes<int32>::ConstVec host_batch_ptrs,
                    TTypes<int32>::Vec coo_row_ind,
                    TTypes<int32>::ConstVec coo_col_ind,
                    TTypes<int64>::Matrix indices);
};
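
// Illustrative sketch (hypothetical names, not this header's API) of the
// rank-3 write-back described above.  host_batch_ptrs has B + 1 entries and
// its last entry equals nnz, so the scan below always terminates.
inline void COOToSparseTensorIndicesSketch(const int32* host_batch_ptrs,
                                           const int32* coo_row_ind,
                                           const int32* coo_col_ind, int64 nnz,
                                           int64* indices /* nnz x 3 */) {
  int64 b = 0;
  for (int64 i = 0; i < nnz; ++i) {
    // Advance to the batch whose half-open range [ptr[b], ptr[b + 1]) holds i.
    while (i >= host_batch_ptrs[b + 1]) ++b;
    indices[i * 3 + 0] = b;
    indices[i * 3 + 1] = coo_row_ind[i];
    indices[i * 3 + 2] = coo_col_ind[i];
  }
}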

// Convert a vector of coo row indices to csr row pointers.
//
// REQUIRES:
//
//   csr_row_ptr.size() == rows + 1.
//   max(coo_row_ind) < rows.
//
template <typename Device>
struct COOSparseMatrixToCSRSparseMatrix {
  Status operator()(OpKernelContext* c, const int rows, const int cols,
                    TTypes<int32>::UnalignedVec coo_row_ind,
                    TTypes<int32>::UnalignedVec csr_row_ptr);
};
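
// Illustrative sketch (hypothetical name, not this header's API) of the
// conversion described above: bucket-count the row of each nonzero, then take
// a cumulative sum so that csr_row_ptr[r] is the offset of the first nonzero
// of row r and csr_row_ptr[rows] == nnz.
inline void COORowIndToCSRRowPtrSketch(const int32* coo_row_ind, int64 nnz,
                                       int rows, int32* csr_row_ptr) {
  for (int r = 0; r <= rows; ++r) csr_row_ptr[r] = 0;
  for (int64 i = 0; i < nnz; ++i) ++csr_row_ptr[coo_row_ind[i] + 1];
  for (int r = 0; r < rows; ++r) csr_row_ptr[r + 1] += csr_row_ptr[r];
}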

// Convert a matrix of (batched) coo row and column indices to CSR SparseMatrix
// batch pointers, csr row pointers, and csr column indices.
//
// REQUIRES:
//   batch_ptr.size() == batch_size + 1
//   csr_row_ptr.size() == batch_size * (num_rows + 1)
//   csr_col_ind.size() == total_nnz
//   batch_size == 1 if rank == 2
//
//   where
//     total_nnz = indices.dim_size(0)
//     rank = indices.dim_size(1)
//   Also csr_row_ptr should be initially filled with zeros.
//
struct SparseTensorToCSRSparseMatrixCPUFunctor {
  Status operator()(const int64 batch_size, const int num_rows,
                    TTypes<int64>::ConstMatrix indices,
                    TTypes<int32>::Vec batch_ptr,
                    TTypes<int32>::Vec csr_row_ptr,
                    TTypes<int32>::Vec csr_col_ind);
};
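
// Illustrative single-threaded sketch (hypothetical names, not this header's
// API) of the batched conversion described above.  `indices` is a row-major
// [total_nnz, rank] buffer of sorted indices; batch_ptr and csr_row_ptr are
// assumed zero-initialized, matching the requirements listed above.
inline void SparseTensorToCSRSketch(int64 batch_size, int num_rows,
                                    const int64* indices, int64 total_nnz,
                                    int rank, int32* batch_ptr,
                                    int32* csr_row_ptr, int32* csr_col_ind) {
  // Count nonzeros per batch and per row, shifted by one for the scans below;
  // column indices pass through unchanged because the input is sorted.
  for (int64 i = 0; i < total_nnz; ++i) {
    const int64 b = (rank == 3) ? indices[i * rank] : 0;
    const int64 row = indices[i * rank + rank - 2];
    ++batch_ptr[b + 1];
    ++csr_row_ptr[b * (num_rows + 1) + row + 1];
    csr_col_ind[i] = static_cast<int32>(indices[i * rank + rank - 1]);
  }
  // Cumulative sums turn the counts into batch and row pointers.
  for (int64 b = 0; b < batch_size; ++b) {
    batch_ptr[b + 1] += batch_ptr[b];
    for (int r = 0; r < num_rows; ++r) {
      csr_row_ptr[b * (num_rows + 1) + r + 1] +=
          csr_row_ptr[b * (num_rows + 1) + r];
    }
  }
}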

// Convert a vector of csr row pointers to coo row indices.
//
// REQUIRES:
//
//   coo_row_ind.size() == nnz.
//   csr_row_ptr[-1] == nnz.
//
template <typename Device>
struct CSRSparseMatrixToCOOSparseMatrix {
  Status operator()(OpKernelContext* c,
                    TTypes<int32>::UnalignedConstVec csr_row_ptr,
                    TTypes<int32>::UnalignedVec coo_row_ind);
};
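
// Illustrative sketch (hypothetical name, not this header's API) of the
// expansion described above: row r owns the nonzeros in the half-open range
// [csr_row_ptr[r], csr_row_ptr[r + 1]).
inline void CSRRowPtrToCOORowIndSketch(const int32* csr_row_ptr, int rows,
                                       int32* coo_row_ind) {
  for (int r = 0; r < rows; ++r) {
    for (int32 p = csr_row_ptr[r]; p < csr_row_ptr[r + 1]; ++p) {
      coo_row_ind[p] = r;
    }
  }
}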

// Calculates C = matmul(A, B) or C = matmul(A, B)^T, where A is in CSR format
// and B and C are dense.
template <typename Device, typename T>
struct CSRSparseMatrixMatMul {
  explicit CSRSparseMatrixMatMul(const bool transpose_output);
  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 typename TTypes<T>::ConstMatrix b,
                 typename TTypes<T>::Matrix c);
};
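
// Illustrative sketch (hypothetical names, not this functor's API) of the
// untransposed-output product for one non-batched CSR component, with
// row-major dense buffers: c[i, j] += a[i, k] * b[k, j] for each stored
// a[i, k].  `c` is assumed zero-initialized.
template <typename T>
inline void CSRMatMulSketch(const int32* a_row_ptr, const int32* a_col_ind,
                            const T* a_values, int a_rows, const T* b,
                            int b_cols, T* c) {
  for (int i = 0; i < a_rows; ++i) {
    for (int32 p = a_row_ptr[i]; p < a_row_ptr[i + 1]; ++p) {
      const int32 k = a_col_ind[p];
      const T a_ik = a_values[p];
      for (int j = 0; j < b_cols; ++j) {
        c[i * b_cols + j] += a_ik * b[k * b_cols + j];
      }
    }
  }
}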

// Calculates y = A * x, y = A^T * x, or y = A^H * x, where A is in CSR format
// and x and y are dense vectors.
template <typename Device, typename T>
class CSRSparseMatrixMatVec {
 public:
  CSRSparseMatrixMatVec(bool transpose_a, bool adjoint_a);
  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 const T* x, T* y);
};
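
// Illustrative sketch (hypothetical name, not this class's API) of the
// non-transposed case y = A * x for a single CSR component.
template <typename T>
inline void CSRMatVecSketch(const int32* row_ptr, const int32* col_ind,
                            const T* values, int rows, const T* x, T* y) {
  for (int i = 0; i < rows; ++i) {
    T sum = T(0);
    for (int32 p = row_ptr[i]; p < row_ptr[i + 1]; ++p) {
      sum += values[p] * x[col_ind[p]];
    }
    y[i] = sum;
  }
}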

// Calculates C = functor(A, B) where A and B are CSR and C is CSR
// with a different sparsity pattern.
template <typename Device, typename T>
struct CSRStructureModifyingFunctor {
  virtual ~CSRStructureModifyingFunctor() {}

  virtual Status Initialize() = 0;

  virtual Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
                                  const ConstCSRComponent<T>& b,
                                  size_t* bufferSize) = 0;

  virtual Status GetOutputStructure(const ConstCSRComponent<T>& a,
                                    const ConstCSRComponent<T>& b,
                                    TTypes<int32>::UnalignedVec c_row_ptr,
                                    int* output_nnz, void* workspace) = 0;

  virtual Status Compute(const ConstCSRComponent<T>& a,
                         const ConstCSRComponent<T>& b, CSRComponent<T>* c,
                         void* workspace) = 0;
};
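
// Illustrative sketch (hypothetical helper, not this header's API) of the
// expected call order for the interface above; workspace allocation and
// sizing of the output component are left to the caller.
template <typename Device, typename T>
inline Status RunStructureModifyingSketch(
    CSRStructureModifyingFunctor<Device, T>* functor,
    const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
    TTypes<int32>::UnalignedVec c_row_ptr, void* workspace,
    CSRComponent<T>* c) {
  Status s = functor->Initialize();
  if (!s.ok()) return s;
  size_t buffer_size = 0;
  s = functor->GetWorkspaceSize(a, b, &buffer_size);
  if (!s.ok()) return s;
  // The caller ensures `workspace` holds at least `buffer_size` bytes here.
  int output_nnz = 0;
  s = functor->GetOutputStructure(a, b, c_row_ptr, &output_nnz, workspace);
  if (!s.ok()) return s;
  // The caller sizes the output's column indices and values to `output_nnz`
  // entries here, then runs the numeric phase.
  return functor->Compute(a, b, c, workspace);
}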

// Calculates C = alpha * A + beta * B, where A and B are in CSR
// format, and alpha and beta are scalars on the host.
template <typename Device, typename T>
struct CSRSparseMatrixAdd : public CSRStructureModifyingFunctor<Device, T> {
  explicit CSRSparseMatrixAdd(OpKernelContext* ctx, const T alpha,
                              const T beta);
};

// Calculates C = matmul(A, B), where A, B, and C are in CSR format.
template <typename Device, typename T>
struct CSRSparseSparseMatrixMatMul
    : public CSRStructureModifyingFunctor<Device, T> {
  explicit CSRSparseSparseMatrixMatMul(OpKernelContext* ctx, bool transpose_a,
                                       bool transpose_b);
};

// Calculates Y = transpose(X) where X and Y are CSR format components.
template <typename Device, typename T>
struct CSRSparseMatrixTransposeComponent {
  Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x,
                    CSRComponent<T>* y);
};

// Calculates Y = transpose(X) where X and Y are in CSR format.
template <typename Device, typename T>
struct CSRSparseMatrixTranspose {
  Status operator()(OpKernelContext* ctx, bool conjugate,
                    const CSRSparseMatrix& input_matrix,
                    CSRSparseMatrix* output_matrix);
};

// Calculates Y = softmax(X) where X and Y are in CSR format;
// missing coefficients in X are treated as -inf (logits of zero probability).
template <typename Device, typename T>
struct CSRSparseMatrixSoftmax {
  Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& logits,
                    typename TTypes<T>::Vec softmax_values);
};
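
// Illustrative sketch (hypothetical name, not this functor's API) of the
// per-row computation for one CSR component: only stored coefficients
// participate, which is equivalent to treating the missing ones as -inf.
template <typename T>
inline void CSRRowSoftmaxSketch(const int32* row_ptr, int rows,
                                const T* logits, T* softmax) {
  for (int i = 0; i < rows; ++i) {
    const int32 begin = row_ptr[i];
    const int32 end = row_ptr[i + 1];
    if (begin == end) continue;  // Empty row: nothing to normalize.
    // Subtract the row maximum for numerical stability.
    T row_max = logits[begin];
    for (int32 p = begin + 1; p < end; ++p) {
      if (logits[p] > row_max) row_max = logits[p];
    }
    T sum = T(0);
    for (int32 p = begin; p < end; ++p) {
      softmax[p] = Eigen::numext::exp(logits[p] - row_max);
      sum += softmax[p];
    }
    for (int32 p = begin; p < end; ++p) softmax[p] /= sum;
  }
}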

// Calculates the gradient of CSRSparseMatrixSoftmax above: given the forward
// output `softmax` and the incoming gradient `grad_softmax`, both in CSR
// format, computes the gradient values with respect to the logits.
template <typename Device, typename T>
struct CSRSparseMatrixSoftmaxGrad {
  Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& softmax,
                    const CSRSparseMatrix& grad_softmax,
                    typename TTypes<T>::Vec gradient_values);
};

// Calculates C = b * A, where A and C are in CSR format and b is a scalar.
template <typename Device, typename T>
class CSRSparseMatrixMulScalar {
 public:
  explicit CSRSparseMatrixMulScalar() {}

  Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a,
                 typename TTypes<T>::ConstScalar b, CSRSparseMatrix* c);
};

template <typename Device, typename T>
class CSRSparseMatrixBatchMulVec {
 public:
  explicit CSRSparseMatrixBatchMulVec() {}

  Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a,
                 typename TTypes<T>::ConstFlat b, CSRSparseMatrix* c);
};

}  // namespace functor

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_