1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10 #include "main.h"
11
12 #include <Eigen/CXX11/Tensor>
13
14 using Eigen::Tensor;
15
16 template<int DataLayout>
test_simple_striding()17 static void test_simple_striding()
18 {
19 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
20 tensor.setRandom();
21 array<ptrdiff_t, 4> strides;
22 strides[0] = 1;
23 strides[1] = 1;
24 strides[2] = 1;
25 strides[3] = 1;
26
27 Tensor<float, 4, DataLayout> no_stride;
28 no_stride = tensor.stride(strides);
29
30 VERIFY_IS_EQUAL(no_stride.dimension(0), 2);
31 VERIFY_IS_EQUAL(no_stride.dimension(1), 3);
32 VERIFY_IS_EQUAL(no_stride.dimension(2), 5);
33 VERIFY_IS_EQUAL(no_stride.dimension(3), 7);
34
35 for (int i = 0; i < 2; ++i) {
36 for (int j = 0; j < 3; ++j) {
37 for (int k = 0; k < 5; ++k) {
38 for (int l = 0; l < 7; ++l) {
39 VERIFY_IS_EQUAL(tensor(i,j,k,l), no_stride(i,j,k,l));
40 }
41 }
42 }
43 }
44
45 strides[0] = 2;
46 strides[1] = 4;
47 strides[2] = 2;
48 strides[3] = 3;
49 Tensor<float, 4, DataLayout> stride;
50 stride = tensor.stride(strides);
51
52 VERIFY_IS_EQUAL(stride.dimension(0), 1);
53 VERIFY_IS_EQUAL(stride.dimension(1), 1);
54 VERIFY_IS_EQUAL(stride.dimension(2), 3);
55 VERIFY_IS_EQUAL(stride.dimension(3), 3);
56
57 for (int i = 0; i < 1; ++i) {
58 for (int j = 0; j < 1; ++j) {
59 for (int k = 0; k < 3; ++k) {
60 for (int l = 0; l < 3; ++l) {
61 VERIFY_IS_EQUAL(tensor(2*i,4*j,2*k,3*l), stride(i,j,k,l));
62 }
63 }
64 }
65 }
66 }
67
68
69 template<int DataLayout>
test_striding_as_lvalue()70 static void test_striding_as_lvalue()
71 {
72 Tensor<float, 4, DataLayout> tensor(2,3,5,7);
73 tensor.setRandom();
74 array<ptrdiff_t, 4> strides;
75 strides[0] = 2;
76 strides[1] = 4;
77 strides[2] = 2;
78 strides[3] = 3;
79
80 Tensor<float, 4, DataLayout> result(3, 12, 10, 21);
81 result.stride(strides) = tensor;
82
83 for (int i = 0; i < 2; ++i) {
84 for (int j = 0; j < 3; ++j) {
85 for (int k = 0; k < 5; ++k) {
86 for (int l = 0; l < 7; ++l) {
87 VERIFY_IS_EQUAL(tensor(i,j,k,l), result(2*i,4*j,2*k,3*l));
88 }
89 }
90 }
91 }
92
93 array<ptrdiff_t, 4> no_strides;
94 no_strides[0] = 1;
95 no_strides[1] = 1;
96 no_strides[2] = 1;
97 no_strides[3] = 1;
98 Tensor<float, 4, DataLayout> result2(3, 12, 10, 21);
99 result2.stride(strides) = tensor.stride(no_strides);
100
101 for (int i = 0; i < 2; ++i) {
102 for (int j = 0; j < 3; ++j) {
103 for (int k = 0; k < 5; ++k) {
104 for (int l = 0; l < 7; ++l) {
105 VERIFY_IS_EQUAL(tensor(i,j,k,l), result2(2*i,4*j,2*k,3*l));
106 }
107 }
108 }
109 }
110 }
111
112
// Entry point for the tensor striding test: runs each striding check for both
// column-major and row-major layouts.
// The subtests run in a fixed order. Each one calls setRandom(), so the order
// also determines which random values each subtest receives.
void test_cxx11_tensor_striding()
{
  // Striding as an rvalue (read-only expression).
  CALL_SUBTEST(test_simple_striding<ColMajor>());
  CALL_SUBTEST(test_simple_striding<RowMajor>());
  // Striding as an lvalue (assignment target).
  CALL_SUBTEST(test_striding_as_lvalue<ColMajor>());
  CALL_SUBTEST(test_striding_as_lvalue<RowMajor>());
}
120