1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
5 // Copyright (C) 2012 Désiré Nuentsa-Wakam <desire.nuentsa_wakam@inria.fr>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 #ifndef EIGEN_SPARSE_TRIANGULARVIEW_H
12 #define EIGEN_SPARSE_TRIANGULARVIEW_H
13 
14 namespace Eigen {
15 
16 /** \ingroup SparseCore_Module
17   *
18   * \brief Base class for a triangular part in a \b sparse matrix
19   *
20   * This class is an abstract base class of class TriangularView, and objects of type TriangularViewImpl cannot be instantiated.
21   * It extends class TriangularView with additional methods which are available for sparse expressions only.
22   *
23   * \sa class TriangularView, SparseMatrixBase::triangularView()
24   */
25 template<typename MatrixType, unsigned int Mode> class TriangularViewImpl<MatrixType,Mode,Sparse>
26   : public SparseMatrixBase<TriangularView<MatrixType,Mode> >
27 {
28     enum { SkipFirst = ((Mode&Lower) && !(MatrixType::Flags&RowMajorBit))
29                     || ((Mode&Upper) &&  (MatrixType::Flags&RowMajorBit)),
30            SkipLast = !SkipFirst,
31            SkipDiag = (Mode&ZeroDiag) ? 1 : 0,
32            HasUnitDiag = (Mode&UnitDiag) ? 1 : 0
33     };
34 
35     typedef TriangularView<MatrixType,Mode> TriangularViewType;
36 
37   protected:
38     // dummy solve function to make TriangularView happy.
39     void solve() const;
40 
41     typedef SparseMatrixBase<TriangularViewType> Base;
42   public:
43 
44     EIGEN_SPARSE_PUBLIC_INTERFACE(TriangularViewType)
45 
46     typedef typename MatrixType::Nested MatrixTypeNested;
47     typedef typename internal::remove_reference<MatrixTypeNested>::type MatrixTypeNestedNonRef;
48     typedef typename internal::remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned;
49 
50     template<typename RhsType, typename DstType>
51     EIGEN_DEVICE_FUNC
_solve_impl(const RhsType & rhs,DstType & dst)52     EIGEN_STRONG_INLINE void _solve_impl(const RhsType &rhs, DstType &dst) const {
53       if(!(internal::is_same<RhsType,DstType>::value && internal::extract_data(dst) == internal::extract_data(rhs)))
54         dst = rhs;
55       this->solveInPlace(dst);
56     }
57 
58     /** Applies the inverse of \c *this to the dense vector or matrix \a other, "in-place" */
59     template<typename OtherDerived> void solveInPlace(MatrixBase<OtherDerived>& other) const;
60 
61     /** Applies the inverse of \c *this to the sparse vector or matrix \a other, "in-place" */
62     template<typename OtherDerived> void solveInPlace(SparseMatrixBase<OtherDerived>& other) const;
63 
64 };
65 
66 namespace internal {
67 
68 template<typename ArgType, unsigned int Mode>
69 struct unary_evaluator<TriangularView<ArgType,Mode>, IteratorBased>
70  : evaluator_base<TriangularView<ArgType,Mode> >
71 {
72   typedef TriangularView<ArgType,Mode> XprType;
73 
74 protected:
75 
76   typedef typename XprType::Scalar Scalar;
77   typedef typename XprType::StorageIndex StorageIndex;
78   typedef typename evaluator<ArgType>::InnerIterator EvalIterator;
79 
80   enum { SkipFirst = ((Mode&Lower) && !(ArgType::Flags&RowMajorBit))
81                     || ((Mode&Upper) &&  (ArgType::Flags&RowMajorBit)),
82          SkipLast = !SkipFirst,
83          SkipDiag = (Mode&ZeroDiag) ? 1 : 0,
84          HasUnitDiag = (Mode&UnitDiag) ? 1 : 0
85   };
86 
87 public:
88 
89   enum {
90     CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
91     Flags = XprType::Flags
92   };
93 
94   explicit unary_evaluator(const XprType &xpr) : m_argImpl(xpr.nestedExpression()), m_arg(xpr.nestedExpression()) {}
95 
96   inline Index nonZerosEstimate() const {
97     return m_argImpl.nonZerosEstimate();
98   }
99 
100   class InnerIterator : public EvalIterator
101   {
102       typedef EvalIterator Base;
103     public:
104 
105       EIGEN_STRONG_INLINE InnerIterator(const unary_evaluator& xprEval, Index outer)
106         : Base(xprEval.m_argImpl,outer), m_returnOne(false), m_containsDiag(Base::outer()<xprEval.m_arg.innerSize())
107       {
108         if(SkipFirst)
109         {
110           while((*this) && ((HasUnitDiag||SkipDiag)  ? this->index()<=outer : this->index()<outer))
111             Base::operator++();
112           if(HasUnitDiag)
113             m_returnOne = m_containsDiag;
114         }
115         else if(HasUnitDiag && ((!Base::operator bool()) || Base::index()>=Base::outer()))
116         {
117           if((!SkipFirst) && Base::operator bool())
118             Base::operator++();
119           m_returnOne = m_containsDiag;
120         }
121       }
122 
123       EIGEN_STRONG_INLINE InnerIterator& operator++()
124       {
125         if(HasUnitDiag && m_returnOne)
126           m_returnOne = false;
127         else
128         {
129           Base::operator++();
130           if(HasUnitDiag && (!SkipFirst) && ((!Base::operator bool()) || Base::index()>=Base::outer()))
131           {
132             if((!SkipFirst) && Base::operator bool())
133               Base::operator++();
134             m_returnOne = m_containsDiag;
135           }
136         }
137         return *this;
138       }
139 
140       EIGEN_STRONG_INLINE operator bool() const
141       {
142         if(HasUnitDiag && m_returnOne)
143           return true;
144         if(SkipFirst) return  Base::operator bool();
145         else
146         {
147           if (SkipDiag) return (Base::operator bool() && this->index() < this->outer());
148           else return (Base::operator bool() && this->index() <= this->outer());
149         }
150       }
151 
152 //       inline Index row() const { return (ArgType::Flags&RowMajorBit ? Base::outer() : this->index()); }
153 //       inline Index col() const { return (ArgType::Flags&RowMajorBit ? this->index() : Base::outer()); }
154       inline StorageIndex index() const
155       {
156         if(HasUnitDiag && m_returnOne)  return internal::convert_index<StorageIndex>(Base::outer());
157         else                            return Base::index();
158       }
159       inline Scalar value() const
160       {
161         if(HasUnitDiag && m_returnOne)  return Scalar(1);
162         else                            return Base::value();
163       }
164 
165     protected:
166       bool m_returnOne;
167       bool m_containsDiag;
168     private:
169       Scalar& valueRef();
170   };
171 
172 protected:
173   evaluator<ArgType> m_argImpl;
174   const ArgType& m_arg;
175 };
176 
177 } // end namespace internal
178 
179 template<typename Derived>
180 template<int Mode>
181 inline const TriangularView<const Derived, Mode>
182 SparseMatrixBase<Derived>::triangularView() const
183 {
184   return TriangularView<const Derived, Mode>(derived());
185 }
186 
187 } // end namespace Eigen
188 
189 #endif // EIGEN_SPARSE_TRIANGULARVIEW_H
190