1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
11 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
17 
18 // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
19 template<typename Lhs, typename Rhs, typename ResultType>
sparse_sparse_product_with_pruning_impl(const Lhs & lhs,const Rhs & rhs,ResultType & res,const typename ResultType::RealScalar & tolerance)20 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance)
21 {
22   // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
23 
24   typedef typename remove_all<Lhs>::type::Scalar Scalar;
25   typedef typename remove_all<Lhs>::type::Index Index;
26 
27   // make sure to call innerSize/outerSize since we fake the storage order.
28   Index rows = lhs.innerSize();
29   Index cols = rhs.outerSize();
30   //Index size = lhs.outerSize();
31   eigen_assert(lhs.outerSize() == rhs.innerSize());
32 
33   // allocate a temporary buffer
34   AmbiVector<Scalar,Index> tempVector(rows);
35 
36   // estimate the number of non zero entries
37   // given a rhs column containing Y non zeros, we assume that the respective Y columns
38   // of the lhs differs in average of one non zeros, thus the number of non zeros for
39   // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
40   // per column of the lhs.
41   // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
42   Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros();
43 
44   // mimics a resizeByInnerOuter:
45   if(ResultType::IsRowMajor)
46     res.resize(cols, rows);
47   else
48     res.resize(rows, cols);
49 
50   res.reserve(estimated_nnz_prod);
51   double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols());
52   for (Index j=0; j<cols; ++j)
53   {
54     // FIXME:
55     //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
56     // let's do a more accurate determination of the nnz ratio for the current column j of res
57     tempVector.init(ratioColRes);
58     tempVector.setZero();
59     for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
60     {
61       // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
62       tempVector.restart();
63       Scalar x = rhsIt.value();
64       for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
65       {
66         tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
67       }
68     }
69     res.startVec(j);
70     for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it)
71       res.insertBackByOuterInner(j,it.index()) = it.value();
72   }
73   res.finalize();
74 }
75 
76 template<typename Lhs, typename Rhs, typename ResultType,
77   int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
78   int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
79   int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
80 struct sparse_sparse_product_with_pruning_selector;
81 
82 template<typename Lhs, typename Rhs, typename ResultType>
83 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
84 {
85   typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
86   typedef typename ResultType::RealScalar RealScalar;
87 
88   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
89   {
90     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
91     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
92     res.swap(_res);
93   }
94 };
95 
96 template<typename Lhs, typename Rhs, typename ResultType>
97 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
98 {
99   typedef typename ResultType::RealScalar RealScalar;
100   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
101   {
102     // we need a col-major matrix to hold the result
103     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::Index> SparseTemporaryType;
104     SparseTemporaryType _res(res.rows(), res.cols());
105     internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
106     res = _res;
107   }
108 };
109 
110 template<typename Lhs, typename Rhs, typename ResultType>
111 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
112 {
113   typedef typename ResultType::RealScalar RealScalar;
114   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
115   {
116     // let's transpose the product to get a column x column product
117     typename remove_all<ResultType>::type _res(res.rows(), res.cols());
118     internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
119     res.swap(_res);
120   }
121 };
122 
123 template<typename Lhs, typename Rhs, typename ResultType>
124 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
125 {
126   typedef typename ResultType::RealScalar RealScalar;
127   static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
128   {
129     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs;
130     typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs;
131     ColMajorMatrixLhs colLhs(lhs);
132     ColMajorMatrixRhs colRhs(rhs);
133     internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
134 
135     // let's transpose the product to get a column x column product
136 //     typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
137 //     SparseTemporaryType _res(res.cols(), res.rows());
138 //     sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
139 //     res = _res.transpose();
140   }
141 };
142 
// NOTE the 2 other cases (col row *) must never occur since they are caught
// by ProductReturnType, which transforms them into (col col *) by evaluating the rhs.
145 
146 } // end namespace internal
147 
148 } // end namespace Eigen
149 
150 #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
151