• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
11 #define EIGEN_TRIANGULARMATRIXVECTOR_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
17 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder, int Version=Specialized>
18 struct triangular_matrix_vector_product;
19 
20 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
21 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
22 {
23   typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
24   enum {
25     IsLower = ((Mode&Lower)==Lower),
26     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
27     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
28   };
29   static EIGEN_DONT_INLINE  void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
30                                      const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
31 };
32 
33 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
34 EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
35   ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
36         const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
37   {
38     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
39     Index size = (std::min)(_rows,_cols);
40     Index rows = IsLower ? _rows : (std::min)(_rows,_cols);
41     Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
42 
43     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
44     const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
45     typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
46 
47     typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
48     const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
49     typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
50 
51     typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
52     ResMap res(_res,rows);
53 
54     for (Index pi=0; pi<size; pi+=PanelWidth)
55     {
56       Index actualPanelWidth = (std::min)(PanelWidth, size-pi);
57       for (Index k=0; k<actualPanelWidth; ++k)
58       {
59         Index i = pi + k;
60         Index s = IsLower ? ((HasUnitDiag||HasZeroDiag) ? i+1 : i ) : pi;
61         Index r = IsLower ? actualPanelWidth-k : k+1;
62         if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
63           res.segment(s,r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s,r);
64         if (HasUnitDiag)
65           res.coeffRef(i) += alpha * cjRhs.coeff(i);
66       }
67       Index r = IsLower ? rows - pi - actualPanelWidth : pi;
68       if (r>0)
69       {
70         Index s = IsLower ? pi+actualPanelWidth : 0;
71         general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
72             r, actualPanelWidth,
73             &lhs.coeffRef(s,pi), lhsStride,
74             &rhs.coeffRef(pi), rhsIncr,
75             &res.coeffRef(s), resIncr, alpha);
76       }
77     }
78     if((!IsLower) && cols>size)
79     {
80       general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
81           rows, cols-size,
82           &lhs.coeffRef(0,size), lhsStride,
83           &rhs.coeffRef(size), rhsIncr,
84           _res, resIncr, alpha);
85     }
86   }
87 
88 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
89 struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
90 {
91   typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
92   enum {
93     IsLower = ((Mode&Lower)==Lower),
94     HasUnitDiag = (Mode & UnitDiag)==UnitDiag,
95     HasZeroDiag = (Mode & ZeroDiag)==ZeroDiag
96   };
97   static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
98                                     const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha);
99 };
100 
101 template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
102 EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
103   ::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride,
104         const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha)
105   {
106     static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
107     Index diagSize = (std::min)(_rows,_cols);
108     Index rows = IsLower ? _rows : diagSize;
109     Index cols = IsLower ? diagSize : _cols;
110 
111     typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
112     const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
113     typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
114 
115     typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
116     const RhsMap rhs(_rhs,cols);
117     typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
118 
119     typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
120     ResMap res(_res,rows,InnerStride<>(resIncr));
121 
122     for (Index pi=0; pi<diagSize; pi+=PanelWidth)
123     {
124       Index actualPanelWidth = (std::min)(PanelWidth, diagSize-pi);
125       for (Index k=0; k<actualPanelWidth; ++k)
126       {
127         Index i = pi + k;
128         Index s = IsLower ? pi  : ((HasUnitDiag||HasZeroDiag) ? i+1 : i);
129         Index r = IsLower ? k+1 : actualPanelWidth-k;
130         if ((!(HasUnitDiag||HasZeroDiag)) || (--r)>0)
131           res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s,r).cwiseProduct(cjRhs.segment(s,r).transpose())).sum();
132         if (HasUnitDiag)
133           res.coeffRef(i) += alpha * cjRhs.coeff(i);
134       }
135       Index r = IsLower ? pi : cols - pi - actualPanelWidth;
136       if (r>0)
137       {
138         Index s = IsLower ? 0 : pi + actualPanelWidth;
139         general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
140             actualPanelWidth, r,
141             &lhs.coeffRef(pi,s), lhsStride,
142             &rhs.coeffRef(s), rhsIncr,
143             &res.coeffRef(pi), resIncr, alpha);
144       }
145     }
146     if(IsLower && rows>diagSize)
147     {
148       general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
149             rows-diagSize, cols,
150             &lhs.coeffRef(diagSize,0), lhsStride,
151             &rhs.coeffRef(0), rhsIncr,
152             &res.coeffRef(diagSize), resIncr, alpha);
153     }
154   }
155 
156 /***************************************************************************
* Wrappers that dispatch TriangularProduct to triangular_matrix_vector_product
158 ***************************************************************************/
159 
160 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
161 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true> >
162  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,true>, Lhs, Rhs> >
163 {};
164 
165 template<int Mode, bool LhsIsTriangular, typename Lhs, typename Rhs>
166 struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
167  : traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
168 {};
169 
170 
/** \internal Dispatches a triangular matrix * vector TriangularProduct to the
  * triangular_matrix_vector_product kernel. Specialized on the storage order
  * (ColMajor / RowMajor) of the triangular factor. */
template<int StorageOrder> struct trmv_selector;
173 
174 } // end namespace internal
175 
176 template<int Mode, typename Lhs, typename Rhs>
177 struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
178   : public ProductBase<TriangularProduct<Mode,true,Lhs,false,Rhs,true>, Lhs, Rhs >
179 {
180   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
181 
182   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
183 
184   template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
185   {
186     eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
187 
188     internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
189   }
190 };
191 
192 template<int Mode, typename Lhs, typename Rhs>
193 struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
194   : public ProductBase<TriangularProduct<Mode,false,Lhs,true,Rhs,false>, Lhs, Rhs >
195 {
196   EIGEN_PRODUCT_PUBLIC_INTERFACE(TriangularProduct)
197 
198   TriangularProduct(const Lhs& lhs, const Rhs& rhs) : Base(lhs,rhs) {}
199 
200   template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
201   {
202     eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
203 
204     typedef TriangularProduct<(Mode & (UnitDiag|ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
205     Transpose<Dest> dstT(dst);
206     internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
207       TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
208   }
209 };
210 
211 namespace internal {
212 
// TODO: factor this code together with gemv_selector, since the logic is exactly the same.
214 
215 template<> struct trmv_selector<ColMajor>
216 {
217   template<int Mode, typename Lhs, typename Rhs, typename Dest>
218   static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, const typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar& alpha)
219   {
220     typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
221     typedef typename ProductType::Index Index;
222     typedef typename ProductType::LhsScalar   LhsScalar;
223     typedef typename ProductType::RhsScalar   RhsScalar;
224     typedef typename ProductType::Scalar      ResScalar;
225     typedef typename ProductType::RealScalar  RealScalar;
226     typedef typename ProductType::ActualLhsType ActualLhsType;
227     typedef typename ProductType::ActualRhsType ActualRhsType;
228     typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
229     typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
230     typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
231 
232     typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
233     typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
234 
235     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
236                                   * RhsBlasTraits::extractScalarFactor(prod.rhs());
237 
238     enum {
239       // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
240       // on, the other hand it is good for the cache to pack the vector anyways...
241       EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
242       ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
243       MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
244     };
245 
246     gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
247 
248     bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
249     bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
250 
251     RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
252 
253     ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
254                                                   evalToDest ? dest.data() : static_dest.data());
255 
256     if(!evalToDest)
257     {
258       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
259       Index size = dest.size();
260       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
261       #endif
262       if(!alphaIsCompatible)
263       {
264         MappedDest(actualDestPtr, dest.size()).setZero();
265         compatibleAlpha = RhsScalar(1);
266       }
267       else
268         MappedDest(actualDestPtr, dest.size()) = dest;
269     }
270 
271     internal::triangular_matrix_vector_product
272       <Index,Mode,
273        LhsScalar, LhsBlasTraits::NeedToConjugate,
274        RhsScalar, RhsBlasTraits::NeedToConjugate,
275        ColMajor>
276       ::run(actualLhs.rows(),actualLhs.cols(),
277             actualLhs.data(),actualLhs.outerStride(),
278             actualRhs.data(),actualRhs.innerStride(),
279             actualDestPtr,1,compatibleAlpha);
280 
281     if (!evalToDest)
282     {
283       if(!alphaIsCompatible)
284         dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
285       else
286         dest = MappedDest(actualDestPtr, dest.size());
287     }
288   }
289 };
290 
291 template<> struct trmv_selector<RowMajor>
292 {
293   template<int Mode, typename Lhs, typename Rhs, typename Dest>
294   static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, const typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar& alpha)
295   {
296     typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
297     typedef typename ProductType::LhsScalar LhsScalar;
298     typedef typename ProductType::RhsScalar RhsScalar;
299     typedef typename ProductType::Scalar    ResScalar;
300     typedef typename ProductType::Index Index;
301     typedef typename ProductType::ActualLhsType ActualLhsType;
302     typedef typename ProductType::ActualRhsType ActualRhsType;
303     typedef typename ProductType::_ActualRhsType _ActualRhsType;
304     typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
305     typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
306 
307     typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
308     typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
309 
310     ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
311                                   * RhsBlasTraits::extractScalarFactor(prod.rhs());
312 
313     enum {
314       DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
315     };
316 
317     gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
318 
319     ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
320         DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
321 
322     if(!DirectlyUseRhs)
323     {
324       #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
325       int size = actualRhs.size();
326       EIGEN_DENSE_STORAGE_CTOR_PLUGIN
327       #endif
328       Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
329     }
330 
331     internal::triangular_matrix_vector_product
332       <Index,Mode,
333        LhsScalar, LhsBlasTraits::NeedToConjugate,
334        RhsScalar, RhsBlasTraits::NeedToConjugate,
335        RowMajor>
336       ::run(actualLhs.rows(),actualLhs.cols(),
337             actualLhs.data(),actualLhs.outerStride(),
338             actualRhsPtr,1,
339             dest.data(),dest.innerStride(),
340             actualAlpha);
341   }
342 };
343 
344 } // end namespace internal
345 
346 } // end namespace Eigen
347 
348 #endif // EIGEN_TRIANGULARMATRIXVECTOR_H
349