1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
5 // Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
6 // Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
7 //
8 // This Source Code Form is subject to the terms of the Mozilla
9 // Public License v. 2.0. If a copy of the MPL was not distributed
10 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
11 
12 #ifndef EIGEN_CWISE_TERNARY_OP_H
13 #define EIGEN_CWISE_TERNARY_OP_H
14 
15 namespace Eigen {
16 
17 namespace internal {
18 template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3>
19 struct traits<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > {
20   // we must not inherit from traits<Arg1> since it has
21   // the potential to cause problems with MSVC
22   typedef typename remove_all<Arg1>::type Ancestor;
23   typedef typename traits<Ancestor>::XprKind XprKind;
24   enum {
25     RowsAtCompileTime = traits<Ancestor>::RowsAtCompileTime,
26     ColsAtCompileTime = traits<Ancestor>::ColsAtCompileTime,
27     MaxRowsAtCompileTime = traits<Ancestor>::MaxRowsAtCompileTime,
28     MaxColsAtCompileTime = traits<Ancestor>::MaxColsAtCompileTime
29   };
30 
31   // even though we require Arg1, Arg2, and Arg3 to have the same scalar type
32   // (see CwiseTernaryOp constructor),
33   // we still want to handle the case when the result type is different.
34   typedef typename result_of<TernaryOp(
35       const typename Arg1::Scalar&, const typename Arg2::Scalar&,
36       const typename Arg3::Scalar&)>::type Scalar;
37 
38   typedef typename internal::traits<Arg1>::StorageKind StorageKind;
39   typedef typename internal::traits<Arg1>::StorageIndex StorageIndex;
40 
41   typedef typename Arg1::Nested Arg1Nested;
42   typedef typename Arg2::Nested Arg2Nested;
43   typedef typename Arg3::Nested Arg3Nested;
44   typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
45   typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
46   typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
47   enum { Flags = _Arg1Nested::Flags & RowMajorBit };
48 };
49 }  // end namespace internal
50 
// Forward declaration of the base class that CwiseTernaryOp derives from.
// CwiseTernaryOp instantiates it with the StorageKind of its first argument
// as the last template parameter.
template <typename Op, typename A1, typename A2, typename A3, typename Kind>
class CwiseTernaryOpImpl;
54 
/** \class CwiseTernaryOp
  * \ingroup Core_Module
  *
  * \brief Generic expression where a coefficient-wise ternary operator is
  * applied to three expressions
  *
  * \tparam TernaryOp template functor implementing the operator
  * \tparam Arg1Type the type of the first argument
  * \tparam Arg2Type the type of the second argument
  * \tparam Arg3Type the type of the third argument
  *
  * This class represents an expression where a coefficient-wise ternary
  * operator is applied to three expressions.
  * It is the return type of ternary operators, by which we mean only those
  * ternary operators where all three arguments are Eigen expressions.
  * For example, the return type of betainc(matrix1, matrix2, matrix3) is a
  * CwiseTernaryOp.
  *
  * Most of the time, this is the only way that it is used, so you typically
  * don't have to name CwiseTernaryOp types explicitly.
  *
  * \sa MatrixBase::ternaryExpr(const MatrixBase<Argument2> &, const
  * MatrixBase<Argument3> &, const CustomTernaryOp &) const, class CwiseBinaryOp,
  * class CwiseUnaryOp, class CwiseNullaryOp
  */
82 template <typename TernaryOp, typename Arg1Type, typename Arg2Type,
83           typename Arg3Type>
84 class CwiseTernaryOp : public CwiseTernaryOpImpl<
85                            TernaryOp, Arg1Type, Arg2Type, Arg3Type,
86                            typename internal::traits<Arg1Type>::StorageKind>,
87                        internal::no_assignment_operator
88 {
89  public:
90   typedef typename internal::remove_all<Arg1Type>::type Arg1;
91   typedef typename internal::remove_all<Arg2Type>::type Arg2;
92   typedef typename internal::remove_all<Arg3Type>::type Arg3;
93 
94   typedef typename CwiseTernaryOpImpl<
95       TernaryOp, Arg1Type, Arg2Type, Arg3Type,
96       typename internal::traits<Arg1Type>::StorageKind>::Base Base;
97   EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseTernaryOp)
98 
99   typedef typename internal::ref_selector<Arg1Type>::type Arg1Nested;
100   typedef typename internal::ref_selector<Arg2Type>::type Arg2Nested;
101   typedef typename internal::ref_selector<Arg3Type>::type Arg3Nested;
102   typedef typename internal::remove_reference<Arg1Nested>::type _Arg1Nested;
103   typedef typename internal::remove_reference<Arg2Nested>::type _Arg2Nested;
104   typedef typename internal::remove_reference<Arg3Nested>::type _Arg3Nested;
105 
106   EIGEN_DEVICE_FUNC
107   EIGEN_STRONG_INLINE CwiseTernaryOp(const Arg1& a1, const Arg2& a2,
108                                      const Arg3& a3,
109                                      const TernaryOp& func = TernaryOp())
110       : m_arg1(a1), m_arg2(a2), m_arg3(a3), m_functor(func) {
111     // require the sizes to match
112     EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg2)
113     EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg3)
114 
115     // The index types should match
116     EIGEN_STATIC_ASSERT((internal::is_same<
117                          typename internal::traits<Arg1Type>::StorageKind,
118                          typename internal::traits<Arg2Type>::StorageKind>::value),
119                         STORAGE_KIND_MUST_MATCH)
120     EIGEN_STATIC_ASSERT((internal::is_same<
121                          typename internal::traits<Arg1Type>::StorageKind,
122                          typename internal::traits<Arg3Type>::StorageKind>::value),
123                         STORAGE_KIND_MUST_MATCH)
124 
125     eigen_assert(a1.rows() == a2.rows() && a1.cols() == a2.cols() &&
126                  a1.rows() == a3.rows() && a1.cols() == a3.cols());
127   }
128 
129   EIGEN_DEVICE_FUNC
130   EIGEN_STRONG_INLINE Index rows() const {
131     // return the fixed size type if available to enable compile time
132     // optimizations
133     if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
134                 RowsAtCompileTime == Dynamic &&
135         internal::traits<typename internal::remove_all<Arg2Nested>::type>::
136                 RowsAtCompileTime == Dynamic)
137       return m_arg3.rows();
138     else if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
139                      RowsAtCompileTime == Dynamic &&
140              internal::traits<typename internal::remove_all<Arg3Nested>::type>::
141                      RowsAtCompileTime == Dynamic)
142       return m_arg2.rows();
143     else
144       return m_arg1.rows();
145   }
146   EIGEN_DEVICE_FUNC
147   EIGEN_STRONG_INLINE Index cols() const {
148     // return the fixed size type if available to enable compile time
149     // optimizations
150     if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
151                 ColsAtCompileTime == Dynamic &&
152         internal::traits<typename internal::remove_all<Arg2Nested>::type>::
153                 ColsAtCompileTime == Dynamic)
154       return m_arg3.cols();
155     else if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
156                      ColsAtCompileTime == Dynamic &&
157              internal::traits<typename internal::remove_all<Arg3Nested>::type>::
158                      ColsAtCompileTime == Dynamic)
159       return m_arg2.cols();
160     else
161       return m_arg1.cols();
162   }
163 
164   /** \returns the first argument nested expression */
165   EIGEN_DEVICE_FUNC
166   const _Arg1Nested& arg1() const { return m_arg1; }
167   /** \returns the first argument nested expression */
168   EIGEN_DEVICE_FUNC
169   const _Arg2Nested& arg2() const { return m_arg2; }
170   /** \returns the third argument nested expression */
171   EIGEN_DEVICE_FUNC
172   const _Arg3Nested& arg3() const { return m_arg3; }
173   /** \returns the functor representing the ternary operation */
174   EIGEN_DEVICE_FUNC
175   const TernaryOp& functor() const { return m_functor; }
176 
177  protected:
178   Arg1Nested m_arg1;
179   Arg2Nested m_arg2;
180   Arg3Nested m_arg3;
181   const TernaryOp m_functor;
182 };
183 
184 // Generic API dispatcher
185 template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3,
186           typename StorageKind>
187 class CwiseTernaryOpImpl
188     : public internal::generic_xpr_base<
189           CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type {
190  public:
191   typedef typename internal::generic_xpr_base<
192       CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type Base;
193 };
194 
195 }  // end namespace Eigen
196 
197 #endif  // EIGEN_CWISE_TERNARY_OP_H
198