1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_REF_H
12 
13 namespace Eigen {
14 
15 namespace internal {
16 
17 template <typename Dimensions, typename Scalar>
18 class TensorLazyBaseEvaluator {
19  public:
TensorLazyBaseEvaluator()20   TensorLazyBaseEvaluator() : m_refcount(0) { }
~TensorLazyBaseEvaluator()21   virtual ~TensorLazyBaseEvaluator() { }
22 
23   EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const = 0;
24   EIGEN_DEVICE_FUNC virtual const Scalar* data() const = 0;
25 
26   EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const = 0;
27   EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) = 0;
28 
incrRefCount()29   void incrRefCount() { ++m_refcount; }
decrRefCount()30   void decrRefCount() { --m_refcount; }
refCount()31   int refCount() const { return m_refcount; }
32 
33  private:
34   // No copy, no assigment;
35   TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other);
36   TensorLazyBaseEvaluator& operator = (const TensorLazyBaseEvaluator& other);
37 
38   int m_refcount;
39 };
40 
41 
42 template <typename Dimensions, typename Expr, typename Device>
43 class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> {
44  public:
45   //  typedef typename TensorEvaluator<Expr, Device>::Dimensions Dimensions;
46   typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar;
47 
TensorLazyEvaluatorReadOnly(const Expr & expr,const Device & device)48   TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) {
49     m_dims = m_impl.dimensions();
50     m_impl.evalSubExprsIfNeeded(NULL);
51   }
~TensorLazyEvaluatorReadOnly()52   virtual ~TensorLazyEvaluatorReadOnly() {
53     m_impl.cleanup();
54   }
55 
dimensions()56   EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const {
57     return m_dims;
58   }
data()59   EIGEN_DEVICE_FUNC virtual const Scalar* data() const {
60     return m_impl.data();
61   }
62 
coeff(DenseIndex index)63   EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const {
64     return m_impl.coeff(index);
65   }
coeffRef(DenseIndex)66   EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex /*index*/) {
67     eigen_assert(false && "can't reference the coefficient of a rvalue");
68     return m_dummy;
69   };
70 
71  protected:
72   TensorEvaluator<Expr, Device> m_impl;
73   Dimensions m_dims;
74   Scalar m_dummy;
75 };
76 
77 template <typename Dimensions, typename Expr, typename Device>
78 class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> {
79  public:
80   typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base;
81   typedef typename Base::Scalar Scalar;
82 
TensorLazyEvaluatorWritable(const Expr & expr,const Device & device)83   TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) {
84   }
~TensorLazyEvaluatorWritable()85   virtual ~TensorLazyEvaluatorWritable() {
86   }
87 
coeffRef(DenseIndex index)88   EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) {
89     return this->m_impl.coeffRef(index);
90   }
91 };
92 
93 template <typename Dimensions, typename Expr, typename Device>
94 class TensorLazyEvaluator : public internal::conditional<bool(internal::is_lvalue<Expr>::value),
95                             TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
96                             TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type {
97  public:
98   typedef typename internal::conditional<bool(internal::is_lvalue<Expr>::value),
99                                          TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
100                                          TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type Base;
101   typedef typename Base::Scalar Scalar;
102 
TensorLazyEvaluator(const Expr & expr,const Device & device)103   TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) {
104   }
~TensorLazyEvaluator()105   virtual ~TensorLazyEvaluator() {
106   }
107 };
108 
109 }  // namespace internal
110 
111 
112 /** \class TensorRef
113   * \ingroup CXX11_Tensor_Module
114   *
115   * \brief A reference to a tensor expression
116   * The expression will be evaluated lazily (as much as possible).
117   *
118   */
119 template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef<PlainObjectType> >
120 {
121   public:
122     typedef TensorRef<PlainObjectType> Self;
123     typedef typename PlainObjectType::Base Base;
124     typedef typename Eigen::internal::nested<Self>::type Nested;
125     typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
126     typedef typename internal::traits<PlainObjectType>::Index Index;
127     typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
128     typedef typename NumTraits<Scalar>::Real RealScalar;
129     typedef typename Base::CoeffReturnType CoeffReturnType;
130     typedef Scalar* PointerType;
131     typedef PointerType PointerArgType;
132 
133     static const Index NumIndices = PlainObjectType::NumIndices;
134     typedef typename PlainObjectType::Dimensions Dimensions;
135 
136     enum {
137       IsAligned = false,
138       PacketAccess = false,
139       Layout = PlainObjectType::Layout,
140       CoordAccess = false,  // to be implemented
141       RawAccess = false
142     };
143 
TensorRef()144     EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
145     }
146 
147     template <typename Expression>
TensorRef(const Expression & expr)148     EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) {
149       m_evaluator->incrRefCount();
150     }
151 
152     template <typename Expression>
153     EIGEN_STRONG_INLINE TensorRef& operator = (const Expression& expr) {
154       unrefEvaluator();
155       m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice());
156       m_evaluator->incrRefCount();
157       return *this;
158     }
159 
~TensorRef()160     ~TensorRef() {
161       unrefEvaluator();
162     }
163 
TensorRef(const TensorRef & other)164     TensorRef(const TensorRef& other) : m_evaluator(other.m_evaluator) {
165       eigen_assert(m_evaluator->refCount() > 0);
166       m_evaluator->incrRefCount();
167     }
168 
169     TensorRef& operator = (const TensorRef& other) {
170       if (this != &other) {
171         unrefEvaluator();
172         m_evaluator = other.m_evaluator;
173         eigen_assert(m_evaluator->refCount() > 0);
174         m_evaluator->incrRefCount();
175       }
176       return *this;
177     }
178 
179     EIGEN_DEVICE_FUNC
rank()180     EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
181     EIGEN_DEVICE_FUNC
dimension(Index n)182     EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
183     EIGEN_DEVICE_FUNC
dimensions()184     EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
185     EIGEN_DEVICE_FUNC
size()186     EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); }
187     EIGEN_DEVICE_FUNC
data()188     EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); }
189 
190     EIGEN_DEVICE_FUNC
operator()191     EIGEN_STRONG_INLINE const Scalar operator()(Index index) const
192     {
193       return m_evaluator->coeff(index);
194     }
195 
196 #if EIGEN_HAS_VARIADIC_TEMPLATES
197     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
operator()198     EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const
199     {
200       const std::size_t num_indices = (sizeof...(otherIndices) + 1);
201       const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
202       return coeff(indices);
203     }
204     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
coeffRef(Index firstIndex,IndexTypes...otherIndices)205     EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices)
206     {
207       const std::size_t num_indices = (sizeof...(otherIndices) + 1);
208       const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
209       return coeffRef(indices);
210     }
211 #else
212 
213     EIGEN_DEVICE_FUNC
operator()214     EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1) const
215     {
216       array<Index, 2> indices;
217       indices[0] = i0;
218       indices[1] = i1;
219       return coeff(indices);
220     }
221     EIGEN_DEVICE_FUNC
operator()222     EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2) const
223     {
224       array<Index, 3> indices;
225       indices[0] = i0;
226       indices[1] = i1;
227       indices[2] = i2;
228       return coeff(indices);
229     }
230     EIGEN_DEVICE_FUNC
operator()231     EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3) const
232     {
233       array<Index, 4> indices;
234       indices[0] = i0;
235       indices[1] = i1;
236       indices[2] = i2;
237       indices[3] = i3;
238       return coeff(indices);
239     }
240     EIGEN_DEVICE_FUNC
operator()241     EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
242     {
243       array<Index, 5> indices;
244       indices[0] = i0;
245       indices[1] = i1;
246       indices[2] = i2;
247       indices[3] = i3;
248       indices[4] = i4;
249       return coeff(indices);
250     }
251     EIGEN_DEVICE_FUNC
coeffRef(Index i0,Index i1)252     EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1)
253     {
254       array<Index, 2> indices;
255       indices[0] = i0;
256       indices[1] = i1;
257       return coeffRef(indices);
258     }
259     EIGEN_DEVICE_FUNC
coeffRef(Index i0,Index i1,Index i2)260     EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2)
261     {
262       array<Index, 3> indices;
263       indices[0] = i0;
264       indices[1] = i1;
265       indices[2] = i2;
266       return coeffRef(indices);
267     }
268     EIGEN_DEVICE_FUNC
operator()269     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
270     {
271       array<Index, 4> indices;
272       indices[0] = i0;
273       indices[1] = i1;
274       indices[2] = i2;
275       indices[3] = i3;
276       return coeffRef(indices);
277     }
278     EIGEN_DEVICE_FUNC
coeffRef(Index i0,Index i1,Index i2,Index i3,Index i4)279     EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4)
280     {
281       array<Index, 5> indices;
282       indices[0] = i0;
283       indices[1] = i1;
284       indices[2] = i2;
285       indices[3] = i3;
286       indices[4] = i4;
287       return coeffRef(indices);
288     }
289 #endif
290 
291     template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
coeff(const array<Index,NumIndices> & indices)292     EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const
293     {
294       const Dimensions& dims = this->dimensions();
295       Index index = 0;
296       if (PlainObjectType::Options & RowMajor) {
297         index += indices[0];
298         for (size_t i = 1; i < NumIndices; ++i) {
299           index = index * dims[i] + indices[i];
300         }
301       } else {
302         index += indices[NumIndices-1];
303         for (int i = NumIndices-2; i >= 0; --i) {
304           index = index * dims[i] + indices[i];
305         }
306       }
307       return m_evaluator->coeff(index);
308     }
309     template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
coeffRef(const array<Index,NumIndices> & indices)310     EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
311     {
312       const Dimensions& dims = this->dimensions();
313       Index index = 0;
314       if (PlainObjectType::Options & RowMajor) {
315         index += indices[0];
316         for (size_t i = 1; i < NumIndices; ++i) {
317           index = index * dims[i] + indices[i];
318         }
319       } else {
320         index += indices[NumIndices-1];
321         for (int i = NumIndices-2; i >= 0; --i) {
322           index = index * dims[i] + indices[i];
323         }
324       }
325       return m_evaluator->coeffRef(index);
326     }
327 
328     EIGEN_DEVICE_FUNC
coeff(Index index)329     EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
330     {
331       return m_evaluator->coeff(index);
332     }
333 
334     EIGEN_DEVICE_FUNC
coeffRef(Index index)335     EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
336     {
337       return m_evaluator->coeffRef(index);
338     }
339 
340   private:
unrefEvaluator()341     EIGEN_STRONG_INLINE void unrefEvaluator() {
342       if (m_evaluator) {
343         m_evaluator->decrRefCount();
344         if (m_evaluator->refCount() == 0) {
345           delete m_evaluator;
346         }
347       }
348     }
349 
350   internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
351 };
352 
353 
354 // evaluator for rvalues
355 template<typename Derived, typename Device>
356 struct TensorEvaluator<const TensorRef<Derived>, Device>
357 {
358   typedef typename Derived::Index Index;
359   typedef typename Derived::Scalar Scalar;
360   typedef typename Derived::Scalar CoeffReturnType;
361   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
362   typedef typename Derived::Dimensions Dimensions;
363 
364   enum {
365     IsAligned = false,
366     PacketAccess = false,
367     Layout = TensorRef<Derived>::Layout,
368     CoordAccess = false,  // to be implemented
369     RawAccess = false
370   };
371 
372   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)
373       : m_ref(m)
374   { }
375 
376   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); }
377 
378   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
379     return true;
380   }
381 
382   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }
383 
384   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
385     return m_ref.coeff(index);
386   }
387 
388   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
389     return m_ref.coeffRef(index);
390   }
391 
392   EIGEN_DEVICE_FUNC Scalar* data() const { return m_ref.data(); }
393 
394  protected:
395   TensorRef<Derived> m_ref;
396 };
397 
398 
399 // evaluator for lvalues
400 template<typename Derived, typename Device>
401 struct TensorEvaluator<TensorRef<Derived>, Device> : public TensorEvaluator<const TensorRef<Derived>, Device>
402 {
403   typedef typename Derived::Index Index;
404   typedef typename Derived::Scalar Scalar;
405   typedef typename Derived::Scalar CoeffReturnType;
406   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
407   typedef typename Derived::Dimensions Dimensions;
408 
409   typedef TensorEvaluator<const TensorRef<Derived>, Device> Base;
410 
411   enum {
412     IsAligned = false,
413     PacketAccess = false,
414     RawAccess = false
415   };
416 
417   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m, const Device& d) : Base(m, d)
418   { }
419 
420   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
421     return this->m_ref.coeffRef(index);
422   }
423 };
424 
425 
426 
427 } // end namespace Eigen
428 
429 #endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H
430