1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Ke Yang <yangke@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10 #include "main.h"
11
12 #include <Eigen/CXX11/Tensor>
13
14 using Eigen::Tensor;
15
16 template<int DataLayout>
test_simple_inflation()17 static void test_simple_inflation()
18 {
19 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
20 tensor.setRandom();
21 array<ptrdiff_t, 4> strides;
22
23 strides[0] = 1;
24 strides[1] = 1;
25 strides[2] = 1;
26 strides[3] = 1;
27
28 Tensor<float, 4, DataLayout> no_stride;
29 no_stride = tensor.inflate(strides);
30
31 VERIFY_IS_EQUAL(no_stride.dimension(0), 2);
32 VERIFY_IS_EQUAL(no_stride.dimension(1), 3);
33 VERIFY_IS_EQUAL(no_stride.dimension(2), 5);
34 VERIFY_IS_EQUAL(no_stride.dimension(3), 7);
35
36 for (int i = 0; i < 2; ++i) {
37 for (int j = 0; j < 3; ++j) {
38 for (int k = 0; k < 5; ++k) {
39 for (int l = 0; l < 7; ++l) {
40 VERIFY_IS_EQUAL(tensor(i,j,k,l), no_stride(i,j,k,l));
41 }
42 }
43 }
44 }
45
46 strides[0] = 2;
47 strides[1] = 4;
48 strides[2] = 2;
49 strides[3] = 3;
50 Tensor<float, 4, DataLayout> inflated;
51 inflated = tensor.inflate(strides);
52
53 VERIFY_IS_EQUAL(inflated.dimension(0), 3);
54 VERIFY_IS_EQUAL(inflated.dimension(1), 9);
55 VERIFY_IS_EQUAL(inflated.dimension(2), 9);
56 VERIFY_IS_EQUAL(inflated.dimension(3), 19);
57
58 for (int i = 0; i < 3; ++i) {
59 for (int j = 0; j < 9; ++j) {
60 for (int k = 0; k < 9; ++k) {
61 for (int l = 0; l < 19; ++l) {
62 if (i % 2 == 0 &&
63 j % 4 == 0 &&
64 k % 2 == 0 &&
65 l % 3 == 0) {
66 VERIFY_IS_EQUAL(inflated(i,j,k,l),
67 tensor(i/2, j/4, k/2, l/3));
68 } else {
69 VERIFY_IS_EQUAL(0, inflated(i,j,k,l));
70 }
71 }
72 }
73 }
74 }
75 }
76
test_cxx11_tensor_inflation()77 void test_cxx11_tensor_inflation()
78 {
79 CALL_SUBTEST(test_simple_inflation<ColMajor>());
80 CALL_SUBTEST(test_simple_inflation<RowMajor>());
81 }
82