1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
12 
13 namespace Eigen {
14 
15 /** \class TensorDevice
16   * \ingroup CXX11_Tensor_Module
17   *
18   * \brief Pseudo expression providing an operator = that will evaluate its argument
19   * on the specified computing 'device' (GPU, thread pool, ...)
20   *
21   * Example:
22   *    C.device(EIGEN_GPU) = A + B;
23   *
  * Todo: add operator*= and operator/= (only =, += and -= are implemented so far).
25   */
26 
27 template <typename ExpressionType, typename DeviceType> class TensorDevice {
28   public:
TensorDevice(const DeviceType & device,ExpressionType & expression)29     TensorDevice(const DeviceType& device, ExpressionType& expression) : m_device(device), m_expression(expression) {}
30 
31     template<typename OtherDerived>
32     EIGEN_STRONG_INLINE TensorDevice& operator=(const OtherDerived& other) {
33       typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
34       Assign assign(m_expression, other);
35       internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
36       return *this;
37     }
38 
39     template<typename OtherDerived>
40     EIGEN_STRONG_INLINE TensorDevice& operator+=(const OtherDerived& other) {
41       typedef typename OtherDerived::Scalar Scalar;
42       typedef TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const ExpressionType, const OtherDerived> Sum;
43       Sum sum(m_expression, other);
44       typedef TensorAssignOp<ExpressionType, const Sum> Assign;
45       Assign assign(m_expression, sum);
46       internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
47       return *this;
48     }
49 
50     template<typename OtherDerived>
51     EIGEN_STRONG_INLINE TensorDevice& operator-=(const OtherDerived& other) {
52       typedef typename OtherDerived::Scalar Scalar;
53       typedef TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const ExpressionType, const OtherDerived> Difference;
54       Difference difference(m_expression, other);
55       typedef TensorAssignOp<ExpressionType, const Difference> Assign;
56       Assign assign(m_expression, difference);
57       internal::TensorExecutor<const Assign, DeviceType>::run(assign, m_device);
58       return *this;
59     }
60 
61   protected:
62     const DeviceType& m_device;
63     ExpressionType& m_expression;
64 };
65 
66 } // end namespace Eigen
67 
68 #endif // EIGEN_CXX11_TENSOR_TENSOR_DEVICE_H
69