• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_ARRAYWRAPPER_H
11 #define EIGEN_ARRAYWRAPPER_H
12 
13 namespace Eigen {
14 
/** \class ArrayWrapper
  * \ingroup Core_Module
  *
  * \brief Expression of a mathematical vector or matrix as an array object
  *
  * This class is the return type of MatrixBase::array(), and most of the time
  * this is the only way it is used.
  *
  * \sa MatrixBase::array(), class MatrixWrapper
  */
25 
26 namespace internal {
27 template<typename ExpressionType>
28 struct traits<ArrayWrapper<ExpressionType> >
29   : public traits<typename remove_all<typename ExpressionType::Nested>::type >
30 {
31   typedef ArrayXpr XprKind;
32   // Let's remove NestByRefBit
33   enum {
34     Flags0 = traits<typename remove_all<typename ExpressionType::Nested>::type >::Flags,
35     Flags = Flags0 & ~NestByRefBit
36   };
37 };
38 }
39 
40 template<typename ExpressionType>
41 class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
42 {
43   public:
44     typedef ArrayBase<ArrayWrapper> Base;
45     EIGEN_DENSE_PUBLIC_INTERFACE(ArrayWrapper)
46     EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ArrayWrapper)
47 
48     typedef typename internal::conditional<
49                        internal::is_lvalue<ExpressionType>::value,
50                        Scalar,
51                        const Scalar
52                      >::type ScalarWithConstIfNotLvalue;
53 
54     typedef typename internal::nested<ExpressionType>::type NestedExpressionType;
55 
56     inline ArrayWrapper(ExpressionType& matrix) : m_expression(matrix) {}
57 
58     inline Index rows() const { return m_expression.rows(); }
59     inline Index cols() const { return m_expression.cols(); }
60     inline Index outerStride() const { return m_expression.outerStride(); }
61     inline Index innerStride() const { return m_expression.innerStride(); }
62 
63     inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); }
64     inline const Scalar* data() const { return m_expression.data(); }
65 
66     inline CoeffReturnType coeff(Index rowId, Index colId) const
67     {
68       return m_expression.coeff(rowId, colId);
69     }
70 
71     inline Scalar& coeffRef(Index rowId, Index colId)
72     {
73       return m_expression.const_cast_derived().coeffRef(rowId, colId);
74     }
75 
76     inline const Scalar& coeffRef(Index rowId, Index colId) const
77     {
78       return m_expression.const_cast_derived().coeffRef(rowId, colId);
79     }
80 
81     inline CoeffReturnType coeff(Index index) const
82     {
83       return m_expression.coeff(index);
84     }
85 
86     inline Scalar& coeffRef(Index index)
87     {
88       return m_expression.const_cast_derived().coeffRef(index);
89     }
90 
91     inline const Scalar& coeffRef(Index index) const
92     {
93       return m_expression.const_cast_derived().coeffRef(index);
94     }
95 
96     template<int LoadMode>
97     inline const PacketScalar packet(Index rowId, Index colId) const
98     {
99       return m_expression.template packet<LoadMode>(rowId, colId);
100     }
101 
102     template<int LoadMode>
103     inline void writePacket(Index rowId, Index colId, const PacketScalar& val)
104     {
105       m_expression.const_cast_derived().template writePacket<LoadMode>(rowId, colId, val);
106     }
107 
108     template<int LoadMode>
109     inline const PacketScalar packet(Index index) const
110     {
111       return m_expression.template packet<LoadMode>(index);
112     }
113 
114     template<int LoadMode>
115     inline void writePacket(Index index, const PacketScalar& val)
116     {
117       m_expression.const_cast_derived().template writePacket<LoadMode>(index, val);
118     }
119 
120     template<typename Dest>
121     inline void evalTo(Dest& dst) const { dst = m_expression; }
122 
123     const typename internal::remove_all<NestedExpressionType>::type&
124     nestedExpression() const
125     {
126       return m_expression;
127     }
128 
129     /** Forwards the resizing request to the nested expression
130       * \sa DenseBase::resize(Index)  */
131     void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); }
132     /** Forwards the resizing request to the nested expression
133       * \sa DenseBase::resize(Index,Index)*/
134     void resize(Index nbRows, Index nbCols) { m_expression.const_cast_derived().resize(nbRows,nbCols); }
135 
136   protected:
137     NestedExpressionType m_expression;
138 };
139 
/** \class MatrixWrapper
  * \ingroup Core_Module
  *
  * \brief Expression of an array as a mathematical vector or matrix
  *
  * This class is the return type of ArrayBase::matrix(), and most of the time
  * this is the only way it is used.
  *
  * \sa ArrayBase::matrix(), MatrixBase::matrix(), class ArrayWrapper
  */
150 
151 namespace internal {
152 template<typename ExpressionType>
153 struct traits<MatrixWrapper<ExpressionType> >
154  : public traits<typename remove_all<typename ExpressionType::Nested>::type >
155 {
156   typedef MatrixXpr XprKind;
157   // Let's remove NestByRefBit
158   enum {
159     Flags0 = traits<typename remove_all<typename ExpressionType::Nested>::type >::Flags,
160     Flags = Flags0 & ~NestByRefBit
161   };
162 };
163 }
164 
165 template<typename ExpressionType>
166 class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
167 {
168   public:
169     typedef MatrixBase<MatrixWrapper<ExpressionType> > Base;
170     EIGEN_DENSE_PUBLIC_INTERFACE(MatrixWrapper)
171     EIGEN_INHERIT_ASSIGNMENT_OPERATORS(MatrixWrapper)
172 
173     typedef typename internal::conditional<
174                        internal::is_lvalue<ExpressionType>::value,
175                        Scalar,
176                        const Scalar
177                      >::type ScalarWithConstIfNotLvalue;
178 
179     typedef typename internal::nested<ExpressionType>::type NestedExpressionType;
180 
181     inline MatrixWrapper(ExpressionType& a_matrix) : m_expression(a_matrix) {}
182 
183     inline Index rows() const { return m_expression.rows(); }
184     inline Index cols() const { return m_expression.cols(); }
185     inline Index outerStride() const { return m_expression.outerStride(); }
186     inline Index innerStride() const { return m_expression.innerStride(); }
187 
188     inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); }
189     inline const Scalar* data() const { return m_expression.data(); }
190 
191     inline CoeffReturnType coeff(Index rowId, Index colId) const
192     {
193       return m_expression.coeff(rowId, colId);
194     }
195 
196     inline Scalar& coeffRef(Index rowId, Index colId)
197     {
198       return m_expression.const_cast_derived().coeffRef(rowId, colId);
199     }
200 
201     inline const Scalar& coeffRef(Index rowId, Index colId) const
202     {
203       return m_expression.derived().coeffRef(rowId, colId);
204     }
205 
206     inline CoeffReturnType coeff(Index index) const
207     {
208       return m_expression.coeff(index);
209     }
210 
211     inline Scalar& coeffRef(Index index)
212     {
213       return m_expression.const_cast_derived().coeffRef(index);
214     }
215 
216     inline const Scalar& coeffRef(Index index) const
217     {
218       return m_expression.const_cast_derived().coeffRef(index);
219     }
220 
221     template<int LoadMode>
222     inline const PacketScalar packet(Index rowId, Index colId) const
223     {
224       return m_expression.template packet<LoadMode>(rowId, colId);
225     }
226 
227     template<int LoadMode>
228     inline void writePacket(Index rowId, Index colId, const PacketScalar& val)
229     {
230       m_expression.const_cast_derived().template writePacket<LoadMode>(rowId, colId, val);
231     }
232 
233     template<int LoadMode>
234     inline const PacketScalar packet(Index index) const
235     {
236       return m_expression.template packet<LoadMode>(index);
237     }
238 
239     template<int LoadMode>
240     inline void writePacket(Index index, const PacketScalar& val)
241     {
242       m_expression.const_cast_derived().template writePacket<LoadMode>(index, val);
243     }
244 
245     const typename internal::remove_all<NestedExpressionType>::type&
246     nestedExpression() const
247     {
248       return m_expression;
249     }
250 
251     /** Forwards the resizing request to the nested expression
252       * \sa DenseBase::resize(Index)  */
253     void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); }
254     /** Forwards the resizing request to the nested expression
255       * \sa DenseBase::resize(Index,Index)*/
256     void resize(Index nbRows, Index nbCols) { m_expression.const_cast_derived().resize(nbRows,nbCols); }
257 
258   protected:
259     NestedExpressionType m_expression;
260 };
261 
262 } // end namespace Eigen
263 
264 #endif // EIGEN_ARRAYWRAPPER_H
265