1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #include "main.h" 11 12 #include <Eigen/CXX11/Tensor> 13 14 using Eigen::Tensor; 15 using Eigen::RowMajor; 16 17 static void test_simple_lvalue_ref() 18 { 19 Tensor<int, 1> input(6); 20 input.setRandom(); 21 22 TensorRef<Tensor<int, 1>> ref3(input); 23 TensorRef<Tensor<int, 1>> ref4 = input; 24 25 VERIFY_IS_EQUAL(ref3.data(), input.data()); 26 VERIFY_IS_EQUAL(ref4.data(), input.data()); 27 28 for (int i = 0; i < 6; ++i) { 29 VERIFY_IS_EQUAL(ref3(i), input(i)); 30 VERIFY_IS_EQUAL(ref4(i), input(i)); 31 } 32 33 for (int i = 0; i < 6; ++i) { 34 ref3.coeffRef(i) = i; 35 } 36 for (int i = 0; i < 6; ++i) { 37 VERIFY_IS_EQUAL(input(i), i); 38 } 39 for (int i = 0; i < 6; ++i) { 40 ref4.coeffRef(i) = -i * 2; 41 } 42 for (int i = 0; i < 6; ++i) { 43 VERIFY_IS_EQUAL(input(i), -i*2); 44 } 45 } 46 47 48 static void test_simple_rvalue_ref() 49 { 50 Tensor<int, 1> input1(6); 51 input1.setRandom(); 52 Tensor<int, 1> input2(6); 53 input2.setRandom(); 54 55 TensorRef<Tensor<int, 1>> ref3(input1 + input2); 56 TensorRef<Tensor<int, 1>> ref4 = input1 + input2; 57 58 VERIFY_IS_NOT_EQUAL(ref3.data(), input1.data()); 59 VERIFY_IS_NOT_EQUAL(ref4.data(), input1.data()); 60 VERIFY_IS_NOT_EQUAL(ref3.data(), input2.data()); 61 VERIFY_IS_NOT_EQUAL(ref4.data(), input2.data()); 62 63 for (int i = 0; i < 6; ++i) { 64 VERIFY_IS_EQUAL(ref3(i), input1(i) + input2(i)); 65 VERIFY_IS_EQUAL(ref4(i), input1(i) + input2(i)); 66 } 67 } 68 69 70 static void test_multiple_dims() 71 { 72 Tensor<float, 3> input(3,5,7); 73 input.setRandom(); 74 75 TensorRef<Tensor<float, 3>> ref(input); 76 VERIFY_IS_EQUAL(ref.data(), input.data()); 77 VERIFY_IS_EQUAL(ref.dimension(0), 3); 78 VERIFY_IS_EQUAL(ref.dimension(1), 5); 79 VERIFY_IS_EQUAL(ref.dimension(2), 7); 80 81 for (int i = 0; i < 3; ++i) { 82 for (int j = 0; j < 5; ++j) { 83 for (int k = 0; k < 7; ++k) { 84 VERIFY_IS_EQUAL(ref(i,j,k), input(i,j,k)); 85 } 86 } 87 } 88 } 89 90 91 static void test_slice() 92 { 93 Tensor<float, 5> tensor(2,3,5,7,11); 94 tensor.setRandom(); 95 96 Eigen::DSizes<ptrdiff_t, 5> indices(1,2,3,4,5); 97 Eigen::DSizes<ptrdiff_t, 5> sizes(1,1,1,1,1); 98 TensorRef<Tensor<float, 5>> slice = tensor.slice(indices, sizes); 99 VERIFY_IS_EQUAL(slice(0,0,0,0,0), tensor(1,2,3,4,5)); 100 101 Eigen::DSizes<ptrdiff_t, 5> indices2(1,1,3,4,5); 102 Eigen::DSizes<ptrdiff_t, 5> sizes2(1,1,2,2,3); 103 slice = tensor.slice(indices2, sizes2); 104 for (int i = 0; i < 2; ++i) { 105 for (int j = 0; j < 2; ++j) { 106 for (int k = 0; k < 3; ++k) { 107 VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,3+i,4+j,5+k)); 108 } 109 } 110 } 111 112 Eigen::DSizes<ptrdiff_t, 5> indices3(0,0,0,0,0); 113 Eigen::DSizes<ptrdiff_t, 5> sizes3(2,3,1,1,1); 114 slice = tensor.slice(indices3, sizes3); 115 VERIFY_IS_EQUAL(slice.data(), tensor.data()); 116 } 117 118 119 static void test_ref_of_ref() 120 { 121 Tensor<float, 3> input(3,5,7); 122 input.setRandom(); 123 124 TensorRef<Tensor<float, 3>> ref(input); 125 TensorRef<Tensor<float, 3>> ref_of_ref(ref); 126 TensorRef<Tensor<float, 3>> ref_of_ref2; 127 ref_of_ref2 = ref; 128 129 VERIFY_IS_EQUAL(ref_of_ref.data(), input.data()); 130 VERIFY_IS_EQUAL(ref_of_ref.dimension(0), 3); 131 VERIFY_IS_EQUAL(ref_of_ref.dimension(1), 5); 132 VERIFY_IS_EQUAL(ref_of_ref.dimension(2), 7); 133 134 VERIFY_IS_EQUAL(ref_of_ref2.data(), input.data()); 135 VERIFY_IS_EQUAL(ref_of_ref2.dimension(0), 3); 136 VERIFY_IS_EQUAL(ref_of_ref2.dimension(1), 5); 137 VERIFY_IS_EQUAL(ref_of_ref2.dimension(2), 7); 138 139 for (int i = 0; i < 3; ++i) { 140 for (int j = 0; j < 5; ++j) { 141 for (int k = 0; k < 7; ++k) { 142 VERIFY_IS_EQUAL(ref_of_ref(i,j,k), input(i,j,k)); 143 VERIFY_IS_EQUAL(ref_of_ref2(i,j,k), input(i,j,k)); 144 } 145 } 146 } 147 } 148 149 150 static void test_ref_in_expr() 151 { 152 Tensor<float, 3> input(3,5,7); 153 input.setRandom(); 154 TensorRef<Tensor<float, 3>> input_ref(input); 155 156 Tensor<float, 3> result(3,5,7); 157 result.setRandom(); 158 TensorRef<Tensor<float, 3>> result_ref(result); 159 160 Tensor<float, 3> bias(3,5,7); 161 bias.setRandom(); 162 163 result_ref = input_ref + bias; 164 for (int i = 0; i < 3; ++i) { 165 for (int j = 0; j < 5; ++j) { 166 for (int k = 0; k < 7; ++k) { 167 VERIFY_IS_EQUAL(result_ref(i,j,k), input(i,j,k) + bias(i,j,k)); 168 VERIFY_IS_NOT_EQUAL(result(i,j,k), input(i,j,k) + bias(i,j,k)); 169 } 170 } 171 } 172 173 result = result_ref; 174 for (int i = 0; i < 3; ++i) { 175 for (int j = 0; j < 5; ++j) { 176 for (int k = 0; k < 7; ++k) { 177 VERIFY_IS_EQUAL(result(i,j,k), input(i,j,k) + bias(i,j,k)); 178 } 179 } 180 } 181 } 182 183 184 static void test_coeff_ref() 185 { 186 Tensor<float, 5> tensor(2,3,5,7,11); 187 tensor.setRandom(); 188 Tensor<float, 5> original = tensor; 189 190 TensorRef<Tensor<float, 4>> slice = tensor.chip(7, 4); 191 slice.coeffRef(0, 0, 0, 0) = 1.0f; 192 slice.coeffRef(1, 0, 0, 0) += 2.0f; 193 194 VERIFY_IS_EQUAL(tensor(0,0,0,0,7), 1.0f); 195 VERIFY_IS_EQUAL(tensor(1,0,0,0,7), original(1,0,0,0,7) + 2.0f); 196 } 197 198 199 static void test_nested_ops_with_ref() 200 { 201 Tensor<float, 4> t(2, 3, 5, 7); 202 t.setRandom(); 203 TensorMap<Tensor<const float, 4> > m(t.data(), 2, 3, 5, 7); 204 array<std::pair<ptrdiff_t, ptrdiff_t>, 4> paddings; 205 paddings[0] = std::make_pair(0, 0); 206 paddings[1] = std::make_pair(2, 1); 207 paddings[2] = std::make_pair(3, 4); 208 paddings[3] = std::make_pair(0, 0); 209 DSizes<Eigen::DenseIndex, 4> shuffle_dims(0, 1, 2, 3); 210 TensorRef<Tensor<const float, 4> > ref(m.pad(paddings)); 211 array<std::pair<ptrdiff_t, ptrdiff_t>, 4> trivial; 212 trivial[0] = std::make_pair(0, 0); 213 trivial[1] = std::make_pair(0, 0); 214 trivial[2] = std::make_pair(0, 0); 215 trivial[3] = std::make_pair(0, 0); 216 Tensor<float, 4> padded = ref.shuffle(shuffle_dims).pad(trivial); 217 VERIFY_IS_EQUAL(padded.dimension(0), 2+0); 218 VERIFY_IS_EQUAL(padded.dimension(1), 3+3); 219 VERIFY_IS_EQUAL(padded.dimension(2), 5+7); 220 VERIFY_IS_EQUAL(padded.dimension(3), 7+0); 221 222 for (int i = 0; i < 2; ++i) { 223 for (int j = 0; j < 6; ++j) { 224 for (int k = 0; k < 12; ++k) { 225 for (int l = 0; l < 7; ++l) { 226 if (j >= 2 && j < 5 && k >= 3 && k < 8) { 227 VERIFY_IS_EQUAL(padded(i,j,k,l), t(i,j-2,k-3,l)); 228 } else { 229 VERIFY_IS_EQUAL(padded(i,j,k,l), 0.0f); 230 } 231 } 232 } 233 } 234 } 235 } 236 237 238 void test_cxx11_tensor_ref() 239 { 240 CALL_SUBTEST(test_simple_lvalue_ref()); 241 CALL_SUBTEST(test_simple_rvalue_ref()); 242 CALL_SUBTEST(test_multiple_dims()); 243 CALL_SUBTEST(test_slice()); 244 CALL_SUBTEST(test_ref_of_ref()); 245 CALL_SUBTEST(test_ref_in_expr()); 246 CALL_SUBTEST(test_coeff_ref()); 247 CALL_SUBTEST(test_nested_ops_with_ref()); 248 } 249