1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Navdeep Jaitly <ndjaitly@google.com and 5 // Benoit Steiner <benoit.steiner.goog@gmail.com> 6 // 7 // This Source Code Form is subject to the terms of the Mozilla 8 // Public License v. 2.0. If a copy of the MPL was not distributed 9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 10 11 #include "main.h" 12 13 #include <Eigen/CXX11/Tensor> 14 15 using Eigen::Tensor; 16 using Eigen::array; 17 18 template <int DataLayout> 19 static void test_simple_reverse() 20 { 21 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 22 tensor.setRandom(); 23 24 array<bool, 4> dim_rev; 25 dim_rev[0] = false; 26 dim_rev[1] = true; 27 dim_rev[2] = true; 28 dim_rev[3] = false; 29 30 Tensor<float, 4, DataLayout> reversed_tensor; 31 reversed_tensor = tensor.reverse(dim_rev); 32 33 VERIFY_IS_EQUAL(reversed_tensor.dimension(0), 2); 34 VERIFY_IS_EQUAL(reversed_tensor.dimension(1), 3); 35 VERIFY_IS_EQUAL(reversed_tensor.dimension(2), 5); 36 VERIFY_IS_EQUAL(reversed_tensor.dimension(3), 7); 37 38 for (int i = 0; i < 2; ++i) { 39 for (int j = 0; j < 3; ++j) { 40 for (int k = 0; k < 5; ++k) { 41 for (int l = 0; l < 7; ++l) { 42 VERIFY_IS_EQUAL(tensor(i,j,k,l), reversed_tensor(i,2-j,4-k,l)); 43 } 44 } 45 } 46 } 47 48 dim_rev[0] = true; 49 dim_rev[1] = false; 50 dim_rev[2] = false; 51 dim_rev[3] = false; 52 53 reversed_tensor = tensor.reverse(dim_rev); 54 55 VERIFY_IS_EQUAL(reversed_tensor.dimension(0), 2); 56 VERIFY_IS_EQUAL(reversed_tensor.dimension(1), 3); 57 VERIFY_IS_EQUAL(reversed_tensor.dimension(2), 5); 58 VERIFY_IS_EQUAL(reversed_tensor.dimension(3), 7); 59 60 61 for (int i = 0; i < 2; ++i) { 62 for (int j = 0; j < 3; ++j) { 63 for (int k = 0; k < 5; ++k) { 64 for (int l = 0; l < 7; ++l) { 65 VERIFY_IS_EQUAL(tensor(i,j,k,l), reversed_tensor(1-i,j,k,l)); 66 } 67 } 68 } 69 } 70 71 dim_rev[0] = true; 72 dim_rev[1] = false; 73 dim_rev[2] = false; 74 dim_rev[3] = true; 75 76 reversed_tensor = tensor.reverse(dim_rev); 77 78 VERIFY_IS_EQUAL(reversed_tensor.dimension(0), 2); 79 VERIFY_IS_EQUAL(reversed_tensor.dimension(1), 3); 80 VERIFY_IS_EQUAL(reversed_tensor.dimension(2), 5); 81 VERIFY_IS_EQUAL(reversed_tensor.dimension(3), 7); 82 83 84 for (int i = 0; i < 2; ++i) { 85 for (int j = 0; j < 3; ++j) { 86 for (int k = 0; k < 5; ++k) { 87 for (int l = 0; l < 7; ++l) { 88 VERIFY_IS_EQUAL(tensor(i,j,k,l), reversed_tensor(1-i,j,k,6-l)); 89 } 90 } 91 } 92 } 93 } 94 95 96 template <int DataLayout> 97 static void test_expr_reverse(bool LValue) 98 { 99 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 100 tensor.setRandom(); 101 102 array<bool, 4> dim_rev; 103 dim_rev[0] = false; 104 dim_rev[1] = true; 105 dim_rev[2] = false; 106 dim_rev[3] = true; 107 108 Tensor<float, 4, DataLayout> expected(2, 3, 5, 7); 109 if (LValue) { 110 expected.reverse(dim_rev) = tensor; 111 } else { 112 expected = tensor.reverse(dim_rev); 113 } 114 115 Tensor<float, 4, DataLayout> result(2,3,5,7); 116 117 array<ptrdiff_t, 4> src_slice_dim; 118 src_slice_dim[0] = 2; 119 src_slice_dim[1] = 3; 120 src_slice_dim[2] = 1; 121 src_slice_dim[3] = 7; 122 array<ptrdiff_t, 4> src_slice_start; 123 src_slice_start[0] = 0; 124 src_slice_start[1] = 0; 125 src_slice_start[2] = 0; 126 src_slice_start[3] = 0; 127 array<ptrdiff_t, 4> dst_slice_dim = src_slice_dim; 128 array<ptrdiff_t, 4> dst_slice_start = src_slice_start; 129 130 for (int i = 0; i < 5; ++i) { 131 if (LValue) { 132 result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) = 133 tensor.slice(src_slice_start, src_slice_dim); 134 } else { 135 result.slice(dst_slice_start, dst_slice_dim) = 136 tensor.slice(src_slice_start, src_slice_dim).reverse(dim_rev); 137 } 138 src_slice_start[2] += 1; 139 dst_slice_start[2] += 1; 140 } 141 142 VERIFY_IS_EQUAL(result.dimension(0), 2); 143 VERIFY_IS_EQUAL(result.dimension(1), 3); 144 VERIFY_IS_EQUAL(result.dimension(2), 5); 145 VERIFY_IS_EQUAL(result.dimension(3), 7); 146 147 for (int i = 0; i < expected.dimension(0); ++i) { 148 for (int j = 0; j < expected.dimension(1); ++j) { 149 for (int k = 0; k < expected.dimension(2); ++k) { 150 for (int l = 0; l < expected.dimension(3); ++l) { 151 VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l)); 152 } 153 } 154 } 155 } 156 157 dst_slice_start[2] = 0; 158 result.setRandom(); 159 for (int i = 0; i < 5; ++i) { 160 if (LValue) { 161 result.slice(dst_slice_start, dst_slice_dim).reverse(dim_rev) = 162 tensor.slice(dst_slice_start, dst_slice_dim); 163 } else { 164 result.slice(dst_slice_start, dst_slice_dim) = 165 tensor.reverse(dim_rev).slice(dst_slice_start, dst_slice_dim); 166 } 167 dst_slice_start[2] += 1; 168 } 169 170 for (int i = 0; i < expected.dimension(0); ++i) { 171 for (int j = 0; j < expected.dimension(1); ++j) { 172 for (int k = 0; k < expected.dimension(2); ++k) { 173 for (int l = 0; l < expected.dimension(3); ++l) { 174 VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l)); 175 } 176 } 177 } 178 } 179 } 180 181 182 void test_cxx11_tensor_reverse() 183 { 184 CALL_SUBTEST(test_simple_reverse<ColMajor>()); 185 CALL_SUBTEST(test_simple_reverse<RowMajor>()); 186 CALL_SUBTEST(test_expr_reverse<ColMajor>(true)); 187 CALL_SUBTEST(test_expr_reverse<RowMajor>(true)); 188 CALL_SUBTEST(test_expr_reverse<ColMajor>(false)); 189 CALL_SUBTEST(test_expr_reverse<RowMajor>(false)); 190 } 191