1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Mehdi Goli Codeplay Software Ltd. 5 // Ralph Potter Codeplay Software Ltd. 6 // Luke Iwanski Codeplay Software Ltd. 7 // Contact: <eigen@codeplay.com> 8 // 9 // This Source Code Form is subject to the terms of the Mozilla 10 // Public License v. 2.0. If a copy of the MPL was not distributed 11 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 12 13 /***************************************************************** 14 * TensorSyclextractFunctors.h 15 * 16 * \brief: 17 * Used to extract all the functors allocated to each node of the expression 18 *tree. 19 * 20 *****************************************************************/ 21 22 #ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP 23 #define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP 24 25 namespace Eigen { 26 namespace TensorSycl { 27 namespace internal { 28 /// \struct FunctorExtractor: This struct is used to extract the functors 29 /// constructed on 30 /// the host-side, to pack them and reuse them in reconstruction of the 31 /// expression on the device. 32 /// We have to do that as in Eigen the functors are not stateless so we cannot 33 /// re-instantiate them on the device. 34 /// We have to pass instantiated functors to the device. 35 // This struct is used for leafNode (TensorMap) and nodes behaving like leafNode (TensorForcedEval). 36 template <typename Evaluator> struct FunctorExtractor{ 37 typedef typename Evaluator::Dimensions Dimensions; 38 const Dimensions m_dimensions; dimensionsFunctorExtractor39 const Dimensions& dimensions() const { return m_dimensions; } FunctorExtractorFunctorExtractor40 FunctorExtractor(const Evaluator& expr) 41 : m_dimensions(expr.dimensions()) {} 42 43 }; 44 45 /// specialisation of the \ref FunctorExtractor struct when the node type is 46 /// const TensorCwiseNullaryOp, const TensorCwiseUnaryOp, and const TensorBroadcastingOp 47 template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev> 48 struct FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> > { 49 FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr; 50 OP func; 51 FunctorExtractor(const TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev>& expr) 52 : rhsExpr(expr.impl()), func(expr.functor()) {} 53 }; 54 /// specialisation of the \ref FunctorExtractor struct when the node type is 55 /// TensorCwiseNullaryOp, TensorCwiseUnaryOp, and TensorBroadcastingOp 56 template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev> 57 struct FunctorExtractor<TensorEvaluator<UnaryCategory<OP, RHSExpr>, Dev> > 58 : FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> >{}; 59 60 /// specialisation of the \ref FunctorExtractor struct when the node type is 61 /// const TensorCwiseBinaryOp 62 template <template<class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev> 63 struct FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> > { 64 FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr; 65 FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr; 66 OP func; 67 FunctorExtractor(const TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev>& expr) 68 : lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.functor()) {} 69 }; 70 71 /// specialisation of the \ref FunctorExtractor struct when the node type is 72 /// const TensorCwiseBinaryOp 73 template <template <class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev> 74 struct FunctorExtractor<TensorEvaluator<BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> > 75 : FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> >{}; 76 77 /// specialisation of the \ref FunctorExtractor struct when the node type is 78 /// const TensorCwiseTernaryOp 79 template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr,typename Dev> 80 struct FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> > { 81 FunctorExtractor<TensorEvaluator<Arg1Expr, Dev> > arg1Expr; 82 FunctorExtractor<TensorEvaluator<Arg2Expr, Dev> > arg2Expr; 83 FunctorExtractor<TensorEvaluator<Arg3Expr, Dev> > arg3Expr; 84 OP func; 85 FunctorExtractor(const TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>& expr) 86 : arg1Expr(expr.arg1Impl()), arg2Expr(expr.arg2Impl()), arg3Expr(expr.arg3Impl()), func(expr.functor()) {} 87 }; 88 89 /// specialisation of the \ref FunctorExtractor struct when the node type is 90 /// TensorCwiseTernaryOp 91 template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, typename Dev> 92 struct FunctorExtractor<TensorEvaluator< TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> > 93 :FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >{}; 94 95 /// specialisation of the \ref FunctorExtractor struct when the node type is 96 /// const TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated. 97 template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev> 98 struct FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > { 99 FunctorExtractor<TensorEvaluator<IfExpr, Dev> > ifExpr; 100 FunctorExtractor<TensorEvaluator<ThenExpr, Dev> > thenExpr; 101 FunctorExtractor<TensorEvaluator<ElseExpr, Dev> > elseExpr; 102 FunctorExtractor(const TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>& expr) 103 : ifExpr(expr.cond_impl()), thenExpr(expr.then_impl()), elseExpr(expr.else_impl()) {} 104 }; 105 106 /// specialisation of the \ref FunctorExtractor struct when the node type is 107 /// TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated 108 template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev> 109 struct FunctorExtractor<TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > 110 :FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {}; 111 112 /// specialisation of the \ref FunctorExtractor struct when the node type is 113 /// const TensorAssignOp. This is an specialisation without OP so it has to be separated. 114 template <typename LHSExpr, typename RHSExpr, typename Dev> 115 struct FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> > { 116 FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr; 117 FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr; 118 FunctorExtractor(const TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr) 119 : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {} 120 }; 121 122 /// specialisation of the \ref FunctorExtractor struct when the node type is 123 /// TensorAssignOp. This is an specialisation without OP so it has to be separated. 124 template <typename LHSExpr, typename RHSExpr, typename Dev> 125 struct FunctorExtractor<TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev> > 126 :FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> >{}; 127 128 129 /// specialisation of the \ref FunctorExtractor struct when the node type is 130 /// const TensorEvalToOp, This is an specialisation without OP so it has to be separated. 131 template <typename RHSExpr, typename Dev> 132 struct FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > { 133 FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr; 134 FunctorExtractor(const TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>& expr) 135 : rhsExpr(expr.impl()) {} 136 }; 137 138 /// specialisation of the \ref FunctorExtractor struct when the node type is 139 /// TensorEvalToOp. This is a specialisation without OP so it has to be separated. 140 template <typename RHSExpr, typename Dev> 141 struct FunctorExtractor<TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev> > 142 : FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {}; 143 144 template<typename Dim, size_t NumOutputDim> struct DimConstr { 145 template<typename InDim> 146 static inline Dim getDim(InDim dims ) {return dims;} 147 }; 148 149 template<typename Dim> struct DimConstr<Dim, 0> { 150 template<typename InDim> 151 static inline Dim getDim(InDim dims ) {return Dim(dims.TotalSize());} 152 }; 153 154 template<typename Op, typename Dims, typename ArgType, template <class> class MakePointer_, typename Device> 155 struct FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{ 156 typedef TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device> Evaluator; 157 typedef typename Eigen::internal::conditional<Evaluator::NumOutputDims==0, DSizes<typename Evaluator::Index, 1>, typename Evaluator::Dimensions >::type Dimensions; 158 const Dimensions m_dimensions; 159 const Dimensions& dimensions() const { return m_dimensions; } 160 FunctorExtractor(const TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>& expr) 161 : m_dimensions(DimConstr<Dimensions, Evaluator::NumOutputDims>::getDim(expr.dimensions())) {} 162 }; 163 164 165 template<typename Op, typename Dims, typename ArgType, template <class> class MakePointer_, typename Device> 166 struct FunctorExtractor<TensorEvaluator<TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>> 167 : FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{}; 168 /// template deduction function for FunctorExtractor 169 template <typename Evaluator> 170 auto inline extractFunctors(const Evaluator& evaluator)-> FunctorExtractor<Evaluator> { 171 return FunctorExtractor<Evaluator>(evaluator); 172 } 173 } // namespace internal 174 } // namespace TensorSycl 175 } // namespace Eigen 176 177 #endif // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP 178