1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Igor Babuschkin <igor@babuschk.in>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 #include <limits>
12 #include <numeric>
13 #include <Eigen/CXX11/Tensor>
14 
15 using Eigen::Tensor;
16 
17 template <int DataLayout, typename Type=float, bool Exclusive = false>
test_1d_scan()18 static void test_1d_scan()
19 {
20   int size = 50;
21   Tensor<Type, 1, DataLayout> tensor(size);
22   tensor.setRandom();
23   Tensor<Type, 1, DataLayout> result = tensor.cumsum(0, Exclusive);
24 
25   VERIFY_IS_EQUAL(tensor.dimension(0), result.dimension(0));
26 
27   float accum = 0;
28   for (int i = 0; i < size; i++) {
29     if (Exclusive) {
30       VERIFY_IS_EQUAL(result(i), accum);
31       accum += tensor(i);
32     } else {
33       accum += tensor(i);
34       VERIFY_IS_EQUAL(result(i), accum);
35     }
36   }
37 
38   accum = 1;
39   result = tensor.cumprod(0, Exclusive);
40   for (int i = 0; i < size; i++) {
41     if (Exclusive) {
42       VERIFY_IS_EQUAL(result(i), accum);
43       accum *= tensor(i);
44     } else {
45       accum *= tensor(i);
46       VERIFY_IS_EQUAL(result(i), accum);
47     }
48   }
49 }
50 
51 template <int DataLayout, typename Type=float>
test_4d_scan()52 static void test_4d_scan()
53 {
54   int size = 5;
55   Tensor<Type, 4, DataLayout> tensor(size, size, size, size);
56   tensor.setRandom();
57 
58   Tensor<Type, 4, DataLayout> result(size, size, size, size);
59 
60   result = tensor.cumsum(0);
61   float accum = 0;
62   for (int i = 0; i < size; i++) {
63     accum += tensor(i, 1, 2, 3);
64     VERIFY_IS_EQUAL(result(i, 1, 2, 3), accum);
65   }
66   result = tensor.cumsum(1);
67   accum = 0;
68   for (int i = 0; i < size; i++) {
69     accum += tensor(1, i, 2, 3);
70     VERIFY_IS_EQUAL(result(1, i, 2, 3), accum);
71   }
72   result = tensor.cumsum(2);
73   accum = 0;
74   for (int i = 0; i < size; i++) {
75     accum += tensor(1, 2, i, 3);
76     VERIFY_IS_EQUAL(result(1, 2, i, 3), accum);
77   }
78   result = tensor.cumsum(3);
79   accum = 0;
80   for (int i = 0; i < size; i++) {
81     accum += tensor(1, 2, 3, i);
82     VERIFY_IS_EQUAL(result(1, 2, 3, i), accum);
83   }
84 }
85 
86 template <int DataLayout>
test_tensor_maps()87 static void test_tensor_maps() {
88   int inputs[20];
89   TensorMap<Tensor<int, 1, DataLayout> > tensor_map(inputs, 20);
90   tensor_map.setRandom();
91 
92   Tensor<int, 1, DataLayout> result = tensor_map.cumsum(0);
93 
94   int accum = 0;
95   for (int i = 0; i < 20; ++i) {
96     accum += tensor_map(i);
97     VERIFY_IS_EQUAL(result(i), accum);
98   }
99 }
100 
test_cxx11_tensor_scan()101 void test_cxx11_tensor_scan() {
102   CALL_SUBTEST((test_1d_scan<ColMajor, float, true>()));
103   CALL_SUBTEST((test_1d_scan<ColMajor, float, false>()));
104   CALL_SUBTEST((test_1d_scan<RowMajor, float, true>()));
105   CALL_SUBTEST((test_1d_scan<RowMajor, float, false>()));
106   CALL_SUBTEST(test_4d_scan<ColMajor>());
107   CALL_SUBTEST(test_4d_scan<RowMajor>());
108   CALL_SUBTEST(test_tensor_maps<ColMajor>());
109   CALL_SUBTEST(test_tensor_maps<RowMajor>());
110 }
111