1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_SPARSEPRODUCT_H 11 #define EIGEN_SPARSEPRODUCT_H 12 13 namespace Eigen { 14 15 template<typename Lhs, typename Rhs> 16 struct SparseSparseProductReturnType 17 { 18 typedef typename internal::traits<Lhs>::Scalar Scalar; 19 typedef typename internal::traits<Lhs>::Index Index; 20 enum { 21 LhsRowMajor = internal::traits<Lhs>::Flags & RowMajorBit, 22 RhsRowMajor = internal::traits<Rhs>::Flags & RowMajorBit, 23 TransposeRhs = (!LhsRowMajor) && RhsRowMajor, 24 TransposeLhs = LhsRowMajor && (!RhsRowMajor) 25 }; 26 27 typedef typename internal::conditional<TransposeLhs, 28 SparseMatrix<Scalar,0,Index>, 29 typename internal::nested<Lhs,Rhs::RowsAtCompileTime>::type>::type LhsNested; 30 31 typedef typename internal::conditional<TransposeRhs, 32 SparseMatrix<Scalar,0,Index>, 33 typename internal::nested<Rhs,Lhs::RowsAtCompileTime>::type>::type RhsNested; 34 35 typedef SparseSparseProduct<LhsNested, RhsNested> Type; 36 }; 37 38 namespace internal { 39 template<typename LhsNested, typename RhsNested> 40 struct traits<SparseSparseProduct<LhsNested, RhsNested> > 41 { 42 typedef MatrixXpr XprKind; 43 // clean the nested types: 44 typedef typename remove_all<LhsNested>::type _LhsNested; 45 typedef typename remove_all<RhsNested>::type _RhsNested; 46 typedef typename _LhsNested::Scalar Scalar; 47 typedef typename promote_index_type<typename traits<_LhsNested>::Index, 48 typename traits<_RhsNested>::Index>::type Index; 49 50 enum { 51 LhsCoeffReadCost = _LhsNested::CoeffReadCost, 52 RhsCoeffReadCost = _RhsNested::CoeffReadCost, 53 LhsFlags = _LhsNested::Flags, 54 RhsFlags = _RhsNested::Flags, 55 56 RowsAtCompileTime = _LhsNested::RowsAtCompileTime, 57 ColsAtCompileTime = _RhsNested::ColsAtCompileTime, 58 MaxRowsAtCompileTime = _LhsNested::MaxRowsAtCompileTime, 59 MaxColsAtCompileTime = _RhsNested::MaxColsAtCompileTime, 60 61 InnerSize = EIGEN_SIZE_MIN_PREFER_FIXED(_LhsNested::ColsAtCompileTime, _RhsNested::RowsAtCompileTime), 62 63 EvalToRowMajor = (RhsFlags & LhsFlags & RowMajorBit), 64 65 RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), 66 67 Flags = (int(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) 68 | EvalBeforeAssigningBit 69 | EvalBeforeNestingBit, 70 71 CoeffReadCost = Dynamic 72 }; 73 74 typedef Sparse StorageKind; 75 }; 76 77 } // end namespace internal 78 79 template<typename LhsNested, typename RhsNested> 80 class SparseSparseProduct : internal::no_assignment_operator, 81 public SparseMatrixBase<SparseSparseProduct<LhsNested, RhsNested> > 82 { 83 public: 84 85 typedef SparseMatrixBase<SparseSparseProduct> Base; 86 EIGEN_DENSE_PUBLIC_INTERFACE(SparseSparseProduct) 87 88 private: 89 90 typedef typename internal::traits<SparseSparseProduct>::_LhsNested _LhsNested; 91 typedef typename internal::traits<SparseSparseProduct>::_RhsNested _RhsNested; 92 93 public: 94 95 template<typename Lhs, typename Rhs> 96 EIGEN_STRONG_INLINE SparseSparseProduct(const Lhs& lhs, const Rhs& rhs) 97 : m_lhs(lhs), m_rhs(rhs), m_tolerance(0), m_conservative(true) 98 { 99 init(); 100 } 101 102 template<typename Lhs, typename Rhs> 103 EIGEN_STRONG_INLINE SparseSparseProduct(const Lhs& lhs, const Rhs& rhs, const RealScalar& tolerance) 104 : m_lhs(lhs), m_rhs(rhs), m_tolerance(tolerance), m_conservative(false) 105 { 106 init(); 107 } 108 109 SparseSparseProduct pruned(const Scalar& reference = 0, const RealScalar& epsilon = NumTraits<RealScalar>::dummy_precision()) const 110 { 111 using std::abs; 112 return SparseSparseProduct(m_lhs,m_rhs,abs(reference)*epsilon); 113 } 114 115 template<typename Dest> 116 void evalTo(Dest& result) const 117 { 118 if(m_conservative) 119 internal::conservative_sparse_sparse_product_selector<_LhsNested, _RhsNested, Dest>::run(lhs(),rhs(),result); 120 else 121 internal::sparse_sparse_product_with_pruning_selector<_LhsNested, _RhsNested, Dest>::run(lhs(),rhs(),result,m_tolerance); 122 } 123 124 EIGEN_STRONG_INLINE Index rows() const { return m_lhs.rows(); } 125 EIGEN_STRONG_INLINE Index cols() const { return m_rhs.cols(); } 126 127 EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; } 128 EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; } 129 130 protected: 131 void init() 132 { 133 eigen_assert(m_lhs.cols() == m_rhs.rows()); 134 135 enum { 136 ProductIsValid = _LhsNested::ColsAtCompileTime==Dynamic 137 || _RhsNested::RowsAtCompileTime==Dynamic 138 || int(_LhsNested::ColsAtCompileTime)==int(_RhsNested::RowsAtCompileTime), 139 AreVectors = _LhsNested::IsVectorAtCompileTime && _RhsNested::IsVectorAtCompileTime, 140 SameSizes = EIGEN_PREDICATE_SAME_MATRIX_SIZE(_LhsNested,_RhsNested) 141 }; 142 // note to the lost user: 143 // * for a dot product use: v1.dot(v2) 144 // * for a coeff-wise product use: v1.cwise()*v2 145 EIGEN_STATIC_ASSERT(ProductIsValid || !(AreVectors && SameSizes), 146 INVALID_VECTOR_VECTOR_PRODUCT__IF_YOU_WANTED_A_DOT_OR_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTIONS) 147 EIGEN_STATIC_ASSERT(ProductIsValid || !(SameSizes && !AreVectors), 148 INVALID_MATRIX_PRODUCT__IF_YOU_WANTED_A_COEFF_WISE_PRODUCT_YOU_MUST_USE_THE_EXPLICIT_FUNCTION) 149 EIGEN_STATIC_ASSERT(ProductIsValid || SameSizes, INVALID_MATRIX_PRODUCT) 150 } 151 152 LhsNested m_lhs; 153 RhsNested m_rhs; 154 RealScalar m_tolerance; 155 bool m_conservative; 156 }; 157 158 // sparse = sparse * sparse 159 template<typename Derived> 160 template<typename Lhs, typename Rhs> 161 inline Derived& SparseMatrixBase<Derived>::operator=(const SparseSparseProduct<Lhs,Rhs>& product) 162 { 163 product.evalTo(derived()); 164 return derived(); 165 } 166 167 /** \returns an expression of the product of two sparse matrices. 168 * By default a conservative product preserving the symbolic non zeros is performed. 169 * The automatic pruning of the small values can be achieved by calling the pruned() function 170 * in which case a totally different product algorithm is employed: 171 * \code 172 * C = (A*B).pruned(); // supress numerical zeros (exact) 173 * C = (A*B).pruned(ref); 174 * C = (A*B).pruned(ref,epsilon); 175 * \endcode 176 * where \c ref is a meaningful non zero reference value. 177 * */ 178 template<typename Derived> 179 template<typename OtherDerived> 180 inline const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type 181 SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const 182 { 183 return typename SparseSparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived()); 184 } 185 186 } // end namespace Eigen 187 188 #endif // EIGEN_SPARSEPRODUCT_H 189