/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_
#define TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace functor {

// Calculates the number of nonzero entries per batch of a sorted rank-3
// SparseTensor's indices. indices is expected to have columns corresponding
// to [batch, row, column], where indices[:, 0] < B.
//
// REQUIRES:
//   indices.dimension(1) == 3
//   nnz_per_batch.dimension(0) == B
template <typename Device>
struct CalculateNNZPerBatchMatrixFromIndices {
  Status operator()(OpKernelContext* c, TTypes<int64>::ConstMatrix indices,
                    TTypes<int32>::Vec nnz_per_batch);
};

// Splits a subset of a SparseTensor's indices into two vectors:
// COO row indices and COO column indices. Outputs are:
//
//   coo_row_ind = indices[:, row_dim]
//   coo_col_ind = indices[:, row_dim + 1]
//
// where n = coo_row_ind.size()
// and row_dim = #cols(indices) - 2
//
// REQUIRES:
//   host_dense_shape.size() in [2, 3]
//   indices.dim_size(1) == host_dense_shape.size()
//   coo_row_ind.size() == coo_col_ind.size()
//   coo_row_ind.size() == indices.dim_size(0)
template <typename Device>
struct SparseTensorToCOOSparseMatrix {
  void operator()(const Device& d, TTypes<int64>::ConstVec host_dense_shape,
                  TTypes<int64>::ConstMatrix indices,
                  TTypes<int32>::Vec coo_row_ind,
                  TTypes<int32>::Vec coo_col_ind);
};

// Writes the COO batch, row, and column vectors to the output indices matrix:
//
//   indices[:, row_dim] = coo_row_ind
//   indices[:, col_dim] = coo_col_ind
//
// where row_dim = #cols(indices) - 2, col_dim = #cols(indices) - 1, and
// n = coo_row_ind.size(). In addition, if #cols(indices) == 3, also stores
// the batch:
//
//   indices[i, 0] = batch_of(i) where
//      host_batch_ptrs(batch_of(i)) <= i < host_batch_ptrs(batch_of(i) + 1)
//
// REQUIRES:
//
//   host_dense_shape.size() in [2, 3]
//   indices.dim_size(1) == host_dense_shape.size()
//   host_batch_ptrs.size() == batch_size + 1
//   coo_row_ind.size() == coo_col_ind.size()
//
template <typename Device>
struct COOSparseMatrixToSparseTensor {
  Status operator()(OpKernelContext* c,
                    TTypes<int64>::ConstVec host_dense_shape,
                    TTypes<int32>::ConstVec host_batch_ptrs,
                    TTypes<int32>::Vec coo_row_ind,
                    TTypes<int32>::ConstVec coo_col_ind,
                    TTypes<int64>::Matrix indices);
};

// Converts a vector of COO row indices to CSR row pointers.
//
// REQUIRES:
//
//   csr_row_ptr.size() == rows + 1.
//   max(coo_row_ind) < rows.
//
template <typename Device>
struct COOSparseMatrixToCSRSparseMatrix {
  Status operator()(OpKernelContext* c, const int rows, const int cols,
                    TTypes<int32>::UnalignedVec coo_row_ind,
                    TTypes<int32>::UnalignedVec csr_row_ptr);
};

// Converts a matrix of (batched) COO row and column indices to CSR
// SparseMatrix batch pointers, CSR row pointers, and CSR column indices.
//
// REQUIRES:
//   batch_ptr.size() == batch_size + 1
//   csr_row_ptr.size() == batch_size * (num_rows + 1)
//   csr_col_ind.size() == total_nnz
//   batch_size == 1 if rank == 2
//
// where
//   total_nnz = indices.dim_size(0)
//   rank = indices.dim_size(1)
// Also, csr_row_ptr should be initially filled with zeros.
//
struct SparseTensorToCSRSparseMatrixCPUFunctor {
  Status operator()(const int64 batch_size, const int num_rows,
                    TTypes<int64>::ConstMatrix indices,
                    TTypes<int32>::Vec batch_ptr,
                    TTypes<int32>::Vec csr_row_ptr,
                    TTypes<int32>::Vec csr_col_ind);
};
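
// A small worked example of the batched layout documented above. This is
// illustrative only; the values below are hypothetical. With batch_size = 2,
// num_rows = 3, and sorted indices
//
//   [[0, 0, 1],
//    [0, 2, 0],
//    [1, 1, 2]]
//
// (so total_nnz = 3), the outputs are laid out as:
//
//   batch_ptr   = [0, 2, 3]    // batch 0 owns nonzeros [0, 2), batch 1 [2, 3)
//   csr_row_ptr = [0, 1, 1, 2, 0, 0, 1, 1]
//                              // num_rows + 1 entries per batch; each batch's
//                              // row pointers restart from zero
//   csr_col_ind = [1, 0, 2]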

// Converts a vector of CSR row pointers to COO row indices.
//
// REQUIRES:
//
//   coo_row_ind.size() == nnz.
//   csr_row_ptr[-1] == nnz.
//
template <typename Device>
struct CSRSparseMatrixToCOOSparseMatrix {
  Status operator()(OpKernelContext* c,
                    TTypes<int32>::UnalignedConstVec csr_row_ptr,
                    TTypes<int32>::UnalignedVec coo_row_ind);
};

// Calculates C = matmul(A, B) or C = matmul(A, B)^T, where A is in CSR format
// and B and C are dense.
template <typename Device, typename T>
struct CSRSparseMatrixMatMul {
  explicit CSRSparseMatrixMatMul(const bool transpose_output);
  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 typename TTypes<T>::ConstMatrix b,
                 typename TTypes<T>::Matrix c);
};

// Calculates y = A * x, y = A^T * x, or y = A^H * x, where A is in CSR format
// and x and y are dense vectors.
template <typename Device, typename T>
class CSRSparseMatrixMatVec {
  CSRSparseMatrixMatVec(bool transpose_a, bool adjoint_a);
  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 const T* x, T* y);
};

// Calculates C = functor(A, B), where A and B are in CSR format and C is in
// CSR format with a possibly different sparsity pattern.
template <typename Device, typename T>
struct CSRStructureModifyingFunctor {
  virtual ~CSRStructureModifyingFunctor() {}

  virtual Status Initialize() = 0;

  virtual Status GetWorkspaceSize(const ConstCSRComponent<T>& a,
                                  const ConstCSRComponent<T>& b,
                                  size_t* bufferSize) = 0;

  virtual Status GetOutputStructure(const ConstCSRComponent<T>& a,
                                    const ConstCSRComponent<T>& b,
                                    TTypes<int32>::UnalignedVec c_row_ptr,
                                    int* output_nnz, void* workspace) = 0;

  virtual Status Compute(const ConstCSRComponent<T>& a,
                         const ConstCSRComponent<T>& b, CSRComponent<T>* c,
                         void* workspace) = 0;
};

// Calculates C = alpha * A + beta * B, where A and B are in CSR format, and
// alpha and beta are scalars on the host.
template <typename Device, typename T>
struct CSRSparseMatrixAdd : public CSRStructureModifyingFunctor<Device, T> {
  explicit CSRSparseMatrixAdd(OpKernelContext* ctx, const T alpha,
                              const T beta);
};
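
// A sketch of the intended call sequence for a CSRStructureModifyingFunctor
// such as CSRSparseMatrixAdd above. This is illustrative only: it assumes a
// device specialization implementing the pure virtual methods, and the names
// `a`, `b`, `c`, `c_row_ptr`, and `workspace` are hypothetical placeholders.
//
//   CSRSparseMatrixAdd<Device, T> add(ctx, alpha, beta);
//   TF_RETURN_IF_ERROR(add.Initialize());
//   size_t buffer_size = 0;
//   TF_RETURN_IF_ERROR(add.GetWorkspaceSize(a, b, &buffer_size));
//   // ... allocate `workspace` of buffer_size bytes ...
//   int output_nnz = 0;
//   TF_RETURN_IF_ERROR(
//       add.GetOutputStructure(a, b, c_row_ptr, &output_nnz, workspace));
//   // ... allocate the column indices and values of `c` using output_nnz ...
//   TF_RETURN_IF_ERROR(add.Compute(a, b, &c, workspace));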

// Calculates C = matmul(A, B), where A, B, and C are in CSR format.
template <typename Device, typename T>
struct CSRSparseSparseMatrixMatMul
    : public CSRStructureModifyingFunctor<Device, T> {
  explicit CSRSparseSparseMatrixMatMul(OpKernelContext* ctx, bool transpose_a,
                                       bool transpose_b);
};

// Calculates Y = transpose(X), where X and Y are CSR format components.
template <typename Device, typename T>
struct CSRSparseMatrixTransposeComponent {
  Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x,
                    CSRComponent<T>* y);
};

// Calculates Y = transpose(X), where X and Y are in CSR format.
template <typename Device, typename T>
struct CSRSparseMatrixTranspose {
  Status operator()(OpKernelContext* ctx, bool conjugate,
                    const CSRSparseMatrix& input_matrix,
                    CSRSparseMatrix* output_matrix);
};

// Calculates Y = softmax(X), where X and Y are in CSR format; missing
// coefficients in X are treated as -inf (logits of zero probability). For
// example, a row whose only stored logits are {1.0, 2.0} is normalized over
// those two entries alone, since exp(-inf) contributes nothing to the sum.
template <typename Device, typename T>
struct CSRSparseMatrixSoftmax {
  Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& logits,
                    typename TTypes<T>::Vec softmax_values);
};

// Calculates the gradient of Y = softmax(X) with respect to X, given the
// softmax values Y and the incoming gradient with respect to Y, both in CSR
// format.
template <typename Device, typename T>
struct CSRSparseMatrixSoftmaxGrad {
  Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& softmax,
                    const CSRSparseMatrix& grad_softmax,
                    typename TTypes<T>::Vec gradient_values);
};

// Calculates C = A * b, where A is in CSR format and b is a scalar.
template <typename Device, typename T>
class CSRSparseMatrixMulScalar {
 public:
  explicit CSRSparseMatrixMulScalar() {}

  Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a,
                 typename TTypes<T>::ConstScalar b, CSRSparseMatrix* c);
};

// Calculates C = A * b, where A is a (possibly batched) matrix in CSR format
// and b holds one scale value per batch of A.
template <typename Device, typename T>
class CSRSparseMatrixBatchMulVec {
 public:
  explicit CSRSparseMatrixBatchMulVec() {}

  Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a,
                 typename TTypes<T>::ConstFlat b, CSRSparseMatrix* c);
};

}  // namespace functor

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_