1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #include "main.h" 11 12 #include <Eigen/CXX11/Tensor> 13 14 using Eigen::Tensor; 15 16 template<int DataLayout> 17 static void test_simple_striding() 18 { 19 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 20 tensor.setRandom(); 21 array<ptrdiff_t, 4> strides; 22 strides[0] = 1; 23 strides[1] = 1; 24 strides[2] = 1; 25 strides[3] = 1; 26 27 Tensor<float, 4, DataLayout> no_stride; 28 no_stride = tensor.stride(strides); 29 30 VERIFY_IS_EQUAL(no_stride.dimension(0), 2); 31 VERIFY_IS_EQUAL(no_stride.dimension(1), 3); 32 VERIFY_IS_EQUAL(no_stride.dimension(2), 5); 33 VERIFY_IS_EQUAL(no_stride.dimension(3), 7); 34 35 for (int i = 0; i < 2; ++i) { 36 for (int j = 0; j < 3; ++j) { 37 for (int k = 0; k < 5; ++k) { 38 for (int l = 0; l < 7; ++l) { 39 VERIFY_IS_EQUAL(tensor(i,j,k,l), no_stride(i,j,k,l)); 40 } 41 } 42 } 43 } 44 45 strides[0] = 2; 46 strides[1] = 4; 47 strides[2] = 2; 48 strides[3] = 3; 49 Tensor<float, 4, DataLayout> stride; 50 stride = tensor.stride(strides); 51 52 VERIFY_IS_EQUAL(stride.dimension(0), 1); 53 VERIFY_IS_EQUAL(stride.dimension(1), 1); 54 VERIFY_IS_EQUAL(stride.dimension(2), 3); 55 VERIFY_IS_EQUAL(stride.dimension(3), 3); 56 57 for (int i = 0; i < 1; ++i) { 58 for (int j = 0; j < 1; ++j) { 59 for (int k = 0; k < 3; ++k) { 60 for (int l = 0; l < 3; ++l) { 61 VERIFY_IS_EQUAL(tensor(2*i,4*j,2*k,3*l), stride(i,j,k,l)); 62 } 63 } 64 } 65 } 66 } 67 68 69 template<int DataLayout> 70 static void test_striding_as_lvalue() 71 { 72 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 73 tensor.setRandom(); 74 array<ptrdiff_t, 4> strides; 75 strides[0] = 2; 76 strides[1] = 4; 77 strides[2] = 2; 78 strides[3] = 3; 79 80 Tensor<float, 4, DataLayout> result(3, 12, 10, 21); 81 result.stride(strides) = tensor; 82 83 for (int i = 0; i < 2; ++i) { 84 for (int j = 0; j < 3; ++j) { 85 for (int k = 0; k < 5; ++k) { 86 for (int l = 0; l < 7; ++l) { 87 VERIFY_IS_EQUAL(tensor(i,j,k,l), result(2*i,4*j,2*k,3*l)); 88 } 89 } 90 } 91 } 92 93 array<ptrdiff_t, 4> no_strides; 94 no_strides[0] = 1; 95 no_strides[1] = 1; 96 no_strides[2] = 1; 97 no_strides[3] = 1; 98 Tensor<float, 4, DataLayout> result2(3, 12, 10, 21); 99 result2.stride(strides) = tensor.stride(no_strides); 100 101 for (int i = 0; i < 2; ++i) { 102 for (int j = 0; j < 3; ++j) { 103 for (int k = 0; k < 5; ++k) { 104 for (int l = 0; l < 7; ++l) { 105 VERIFY_IS_EQUAL(tensor(i,j,k,l), result2(2*i,4*j,2*k,3*l)); 106 } 107 } 108 } 109 } 110 } 111 112 113 void test_cxx11_tensor_striding() 114 { 115 CALL_SUBTEST(test_simple_striding<ColMajor>()); 116 CALL_SUBTEST(test_simple_striding<RowMajor>()); 117 CALL_SUBTEST(test_striding_as_lvalue<ColMajor>()); 118 CALL_SUBTEST(test_striding_as_lvalue<RowMajor>()); 119 } 120