1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #include "main.h" 11 12 #include <Eigen/CXX11/Tensor> 13 14 using Eigen::Tensor; 15 using Eigen::RowMajor; 16 17 static void test_1d() 18 { 19 Tensor<int, 1> vec1(6); 20 Tensor<int, 1, RowMajor> vec2(6); 21 vec1(0) = 4; vec2(0) = 0; 22 vec1(1) = 8; vec2(1) = 1; 23 vec1(2) = 15; vec2(2) = 2; 24 vec1(3) = 16; vec2(3) = 3; 25 vec1(4) = 23; vec2(4) = 4; 26 vec1(5) = 42; vec2(5) = 5; 27 28 int col_major[6]; 29 int row_major[6]; 30 memset(col_major, 0, 6*sizeof(int)); 31 memset(row_major, 0, 6*sizeof(int)); 32 TensorMap<Tensor<int, 1> > vec3(col_major, 6); 33 TensorMap<Tensor<int, 1, RowMajor> > vec4(row_major, 6); 34 35 vec3 = vec1; 36 vec4 = vec2; 37 38 VERIFY_IS_EQUAL(vec3(0), 4); 39 VERIFY_IS_EQUAL(vec3(1), 8); 40 VERIFY_IS_EQUAL(vec3(2), 15); 41 VERIFY_IS_EQUAL(vec3(3), 16); 42 VERIFY_IS_EQUAL(vec3(4), 23); 43 VERIFY_IS_EQUAL(vec3(5), 42); 44 45 VERIFY_IS_EQUAL(vec4(0), 0); 46 VERIFY_IS_EQUAL(vec4(1), 1); 47 VERIFY_IS_EQUAL(vec4(2), 2); 48 VERIFY_IS_EQUAL(vec4(3), 3); 49 VERIFY_IS_EQUAL(vec4(4), 4); 50 VERIFY_IS_EQUAL(vec4(5), 5); 51 52 vec1.setZero(); 53 vec2.setZero(); 54 vec1 = vec3; 55 vec2 = vec4; 56 57 VERIFY_IS_EQUAL(vec1(0), 4); 58 VERIFY_IS_EQUAL(vec1(1), 8); 59 VERIFY_IS_EQUAL(vec1(2), 15); 60 VERIFY_IS_EQUAL(vec1(3), 16); 61 VERIFY_IS_EQUAL(vec1(4), 23); 62 VERIFY_IS_EQUAL(vec1(5), 42); 63 64 VERIFY_IS_EQUAL(vec2(0), 0); 65 VERIFY_IS_EQUAL(vec2(1), 1); 66 VERIFY_IS_EQUAL(vec2(2), 2); 67 VERIFY_IS_EQUAL(vec2(3), 3); 68 VERIFY_IS_EQUAL(vec2(4), 4); 69 VERIFY_IS_EQUAL(vec2(5), 5); 70 } 71 72 static void test_2d() 73 { 74 Tensor<int, 2> mat1(2,3); 75 Tensor<int, 2, RowMajor> mat2(2,3); 76 77 mat1(0,0) = 0; 78 mat1(0,1) = 1; 79 mat1(0,2) = 2; 80 mat1(1,0) = 3; 81 mat1(1,1) = 4; 82 mat1(1,2) = 5; 83 84 mat2(0,0) = 0; 85 mat2(0,1) = 1; 86 mat2(0,2) = 2; 87 mat2(1,0) = 3; 88 mat2(1,1) = 4; 89 mat2(1,2) = 5; 90 91 int col_major[6]; 92 int row_major[6]; 93 memset(col_major, 0, 6*sizeof(int)); 94 memset(row_major, 0, 6*sizeof(int)); 95 TensorMap<Tensor<int, 2> > mat3(row_major, 2, 3); 96 TensorMap<Tensor<int, 2, RowMajor> > mat4(col_major, 2, 3); 97 98 mat3 = mat1; 99 mat4 = mat2; 100 101 VERIFY_IS_EQUAL(mat3(0,0), 0); 102 VERIFY_IS_EQUAL(mat3(0,1), 1); 103 VERIFY_IS_EQUAL(mat3(0,2), 2); 104 VERIFY_IS_EQUAL(mat3(1,0), 3); 105 VERIFY_IS_EQUAL(mat3(1,1), 4); 106 VERIFY_IS_EQUAL(mat3(1,2), 5); 107 108 VERIFY_IS_EQUAL(mat4(0,0), 0); 109 VERIFY_IS_EQUAL(mat4(0,1), 1); 110 VERIFY_IS_EQUAL(mat4(0,2), 2); 111 VERIFY_IS_EQUAL(mat4(1,0), 3); 112 VERIFY_IS_EQUAL(mat4(1,1), 4); 113 VERIFY_IS_EQUAL(mat4(1,2), 5); 114 115 mat1.setZero(); 116 mat2.setZero(); 117 mat1 = mat3; 118 mat2 = mat4; 119 120 VERIFY_IS_EQUAL(mat1(0,0), 0); 121 VERIFY_IS_EQUAL(mat1(0,1), 1); 122 VERIFY_IS_EQUAL(mat1(0,2), 2); 123 VERIFY_IS_EQUAL(mat1(1,0), 3); 124 VERIFY_IS_EQUAL(mat1(1,1), 4); 125 VERIFY_IS_EQUAL(mat1(1,2), 5); 126 127 VERIFY_IS_EQUAL(mat2(0,0), 0); 128 VERIFY_IS_EQUAL(mat2(0,1), 1); 129 VERIFY_IS_EQUAL(mat2(0,2), 2); 130 VERIFY_IS_EQUAL(mat2(1,0), 3); 131 VERIFY_IS_EQUAL(mat2(1,1), 4); 132 VERIFY_IS_EQUAL(mat2(1,2), 5); 133 } 134 135 static void test_3d() 136 { 137 Tensor<int, 3> mat1(2,3,7); 138 Tensor<int, 3, RowMajor> mat2(2,3,7); 139 140 int val = 0; 141 for (int i = 0; i < 2; ++i) { 142 for (int j = 0; j < 3; ++j) { 143 for (int k = 0; k < 7; ++k) { 144 mat1(i,j,k) = val; 145 mat2(i,j,k) = val; 146 val++; 147 } 148 } 149 } 150 151 int col_major[2*3*7]; 152 int row_major[2*3*7]; 153 memset(col_major, 0, 2*3*7*sizeof(int)); 154 memset(row_major, 0, 2*3*7*sizeof(int)); 155 TensorMap<Tensor<int, 3> > mat3(col_major, 2, 3, 7); 156 TensorMap<Tensor<int, 3, RowMajor> > mat4(row_major, 2, 3, 7); 157 158 mat3 = mat1; 159 mat4 = mat2; 160 161 val = 0; 162 for (int i = 0; i < 2; ++i) { 163 for (int j = 0; j < 3; ++j) { 164 for (int k = 0; k < 7; ++k) { 165 VERIFY_IS_EQUAL(mat3(i,j,k), val); 166 VERIFY_IS_EQUAL(mat4(i,j,k), val); 167 val++; 168 } 169 } 170 } 171 172 mat1.setZero(); 173 mat2.setZero(); 174 mat1 = mat3; 175 mat2 = mat4; 176 177 val = 0; 178 for (int i = 0; i < 2; ++i) { 179 for (int j = 0; j < 3; ++j) { 180 for (int k = 0; k < 7; ++k) { 181 VERIFY_IS_EQUAL(mat1(i,j,k), val); 182 VERIFY_IS_EQUAL(mat2(i,j,k), val); 183 val++; 184 } 185 } 186 } 187 } 188 189 static void test_same_type() 190 { 191 Tensor<int, 1> orig_tensor(5); 192 Tensor<int, 1> dest_tensor(5); 193 orig_tensor.setRandom(); 194 dest_tensor.setRandom(); 195 int* orig_data = orig_tensor.data(); 196 int* dest_data = dest_tensor.data(); 197 dest_tensor = orig_tensor; 198 VERIFY_IS_EQUAL(orig_tensor.data(), orig_data); 199 VERIFY_IS_EQUAL(dest_tensor.data(), dest_data); 200 for (int i = 0; i < 5; ++i) { 201 VERIFY_IS_EQUAL(dest_tensor(i), orig_tensor(i)); 202 } 203 204 TensorFixedSize<int, Sizes<5> > orig_array; 205 TensorFixedSize<int, Sizes<5> > dest_array; 206 orig_array.setRandom(); 207 dest_array.setRandom(); 208 orig_data = orig_array.data(); 209 dest_data = dest_array.data(); 210 dest_array = orig_array; 211 VERIFY_IS_EQUAL(orig_array.data(), orig_data); 212 VERIFY_IS_EQUAL(dest_array.data(), dest_data); 213 for (int i = 0; i < 5; ++i) { 214 VERIFY_IS_EQUAL(dest_array(i), orig_array(i)); 215 } 216 217 int orig[5] = {1, 2, 3, 4, 5}; 218 int dest[5] = {6, 7, 8, 9, 10}; 219 TensorMap<Tensor<int, 1> > orig_map(orig, 5); 220 TensorMap<Tensor<int, 1> > dest_map(dest, 5); 221 orig_data = orig_map.data(); 222 dest_data = dest_map.data(); 223 dest_map = orig_map; 224 VERIFY_IS_EQUAL(orig_map.data(), orig_data); 225 VERIFY_IS_EQUAL(dest_map.data(), dest_data); 226 for (int i = 0; i < 5; ++i) { 227 VERIFY_IS_EQUAL(dest[i], i+1); 228 } 229 } 230 231 static void test_auto_resize() 232 { 233 Tensor<int, 1> tensor1; 234 Tensor<int, 1> tensor2(3); 235 Tensor<int, 1> tensor3(5); 236 Tensor<int, 1> tensor4(7); 237 238 Tensor<int, 1> new_tensor(5); 239 new_tensor.setRandom(); 240 241 tensor1 = tensor2 = tensor3 = tensor4 = new_tensor; 242 243 VERIFY_IS_EQUAL(tensor1.dimension(0), new_tensor.dimension(0)); 244 VERIFY_IS_EQUAL(tensor2.dimension(0), new_tensor.dimension(0)); 245 VERIFY_IS_EQUAL(tensor3.dimension(0), new_tensor.dimension(0)); 246 VERIFY_IS_EQUAL(tensor4.dimension(0), new_tensor.dimension(0)); 247 for (int i = 0; i < new_tensor.dimension(0); ++i) { 248 VERIFY_IS_EQUAL(tensor1(i), new_tensor(i)); 249 VERIFY_IS_EQUAL(tensor2(i), new_tensor(i)); 250 VERIFY_IS_EQUAL(tensor3(i), new_tensor(i)); 251 VERIFY_IS_EQUAL(tensor4(i), new_tensor(i)); 252 } 253 } 254 255 256 static void test_compound_assign() 257 { 258 Tensor<int, 1> start_tensor(10); 259 Tensor<int, 1> offset_tensor(10); 260 start_tensor.setRandom(); 261 offset_tensor.setRandom(); 262 263 Tensor<int, 1> tensor = start_tensor; 264 tensor += offset_tensor; 265 for (int i = 0; i < 10; ++i) { 266 VERIFY_IS_EQUAL(tensor(i), start_tensor(i) + offset_tensor(i)); 267 } 268 269 tensor = start_tensor; 270 tensor -= offset_tensor; 271 for (int i = 0; i < 10; ++i) { 272 VERIFY_IS_EQUAL(tensor(i), start_tensor(i) - offset_tensor(i)); 273 } 274 275 tensor = start_tensor; 276 tensor *= offset_tensor; 277 for (int i = 0; i < 10; ++i) { 278 VERIFY_IS_EQUAL(tensor(i), start_tensor(i) * offset_tensor(i)); 279 } 280 281 tensor = start_tensor; 282 tensor /= offset_tensor; 283 for (int i = 0; i < 10; ++i) { 284 VERIFY_IS_EQUAL(tensor(i), start_tensor(i) / offset_tensor(i)); 285 } 286 } 287 288 static void test_std_initializers_tensor() { 289 #if EIGEN_HAS_VARIADIC_TEMPLATES 290 Tensor<int, 1> a(3); 291 a.setValues({0, 1, 2}); 292 VERIFY_IS_EQUAL(a(0), 0); 293 VERIFY_IS_EQUAL(a(1), 1); 294 VERIFY_IS_EQUAL(a(2), 2); 295 296 // It fills the top-left slice. 297 a.setValues({10, 20}); 298 VERIFY_IS_EQUAL(a(0), 10); 299 VERIFY_IS_EQUAL(a(1), 20); 300 VERIFY_IS_EQUAL(a(2), 2); 301 302 // Chaining. 303 Tensor<int, 1> a2(3); 304 a2 = a.setValues({100, 200, 300}); 305 VERIFY_IS_EQUAL(a(0), 100); 306 VERIFY_IS_EQUAL(a(1), 200); 307 VERIFY_IS_EQUAL(a(2), 300); 308 VERIFY_IS_EQUAL(a2(0), 100); 309 VERIFY_IS_EQUAL(a2(1), 200); 310 VERIFY_IS_EQUAL(a2(2), 300); 311 312 Tensor<int, 2> b(2, 3); 313 b.setValues({{0, 1, 2}, {3, 4, 5}}); 314 VERIFY_IS_EQUAL(b(0, 0), 0); 315 VERIFY_IS_EQUAL(b(0, 1), 1); 316 VERIFY_IS_EQUAL(b(0, 2), 2); 317 VERIFY_IS_EQUAL(b(1, 0), 3); 318 VERIFY_IS_EQUAL(b(1, 1), 4); 319 VERIFY_IS_EQUAL(b(1, 2), 5); 320 321 // It fills the top-left slice. 322 b.setValues({{10, 20}, {30}}); 323 VERIFY_IS_EQUAL(b(0, 0), 10); 324 VERIFY_IS_EQUAL(b(0, 1), 20); 325 VERIFY_IS_EQUAL(b(0, 2), 2); 326 VERIFY_IS_EQUAL(b(1, 0), 30); 327 VERIFY_IS_EQUAL(b(1, 1), 4); 328 VERIFY_IS_EQUAL(b(1, 2), 5); 329 330 Eigen::Tensor<int, 3> c(3, 2, 4); 331 c.setValues({{{0, 1, 2, 3}, {4, 5, 6, 7}}, 332 {{10, 11, 12, 13}, {14, 15, 16, 17}}, 333 {{20, 21, 22, 23}, {24, 25, 26, 27}}}); 334 VERIFY_IS_EQUAL(c(0, 0, 0), 0); 335 VERIFY_IS_EQUAL(c(0, 0, 1), 1); 336 VERIFY_IS_EQUAL(c(0, 0, 2), 2); 337 VERIFY_IS_EQUAL(c(0, 0, 3), 3); 338 VERIFY_IS_EQUAL(c(0, 1, 0), 4); 339 VERIFY_IS_EQUAL(c(0, 1, 1), 5); 340 VERIFY_IS_EQUAL(c(0, 1, 2), 6); 341 VERIFY_IS_EQUAL(c(0, 1, 3), 7); 342 VERIFY_IS_EQUAL(c(1, 0, 0), 10); 343 VERIFY_IS_EQUAL(c(1, 0, 1), 11); 344 VERIFY_IS_EQUAL(c(1, 0, 2), 12); 345 VERIFY_IS_EQUAL(c(1, 0, 3), 13); 346 VERIFY_IS_EQUAL(c(1, 1, 0), 14); 347 VERIFY_IS_EQUAL(c(1, 1, 1), 15); 348 VERIFY_IS_EQUAL(c(1, 1, 2), 16); 349 VERIFY_IS_EQUAL(c(1, 1, 3), 17); 350 VERIFY_IS_EQUAL(c(2, 0, 0), 20); 351 VERIFY_IS_EQUAL(c(2, 0, 1), 21); 352 VERIFY_IS_EQUAL(c(2, 0, 2), 22); 353 VERIFY_IS_EQUAL(c(2, 0, 3), 23); 354 VERIFY_IS_EQUAL(c(2, 1, 0), 24); 355 VERIFY_IS_EQUAL(c(2, 1, 1), 25); 356 VERIFY_IS_EQUAL(c(2, 1, 2), 26); 357 VERIFY_IS_EQUAL(c(2, 1, 3), 27); 358 #endif // EIGEN_HAS_VARIADIC_TEMPLATES 359 } 360 361 void test_cxx11_tensor_assign() 362 { 363 CALL_SUBTEST(test_1d()); 364 CALL_SUBTEST(test_2d()); 365 CALL_SUBTEST(test_3d()); 366 CALL_SUBTEST(test_same_type()); 367 CALL_SUBTEST(test_auto_resize()); 368 CALL_SUBTEST(test_compound_assign()); 369 CALL_SUBTEST(test_std_initializers_tensor()); 370 } 371