1 #include "main.h" 2 3 #include <Eigen/CXX11/Tensor> 4 5 using Eigen::Tensor; 6 7 static void test_single_voxel_patch() 8 { 9 Tensor<float, 5> tensor(4,2,3,5,7); 10 tensor.setRandom(); 11 Tensor<float, 5, RowMajor> tensor_row_major = tensor.swap_layout(); 12 13 Tensor<float, 6> single_voxel_patch; 14 single_voxel_patch = tensor.extract_volume_patches(1, 1, 1); 15 VERIFY_IS_EQUAL(single_voxel_patch.dimension(0), 4); 16 VERIFY_IS_EQUAL(single_voxel_patch.dimension(1), 1); 17 VERIFY_IS_EQUAL(single_voxel_patch.dimension(2), 1); 18 VERIFY_IS_EQUAL(single_voxel_patch.dimension(3), 1); 19 VERIFY_IS_EQUAL(single_voxel_patch.dimension(4), 2 * 3 * 5); 20 VERIFY_IS_EQUAL(single_voxel_patch.dimension(5), 7); 21 22 Tensor<float, 6, RowMajor> single_voxel_patch_row_major; 23 single_voxel_patch_row_major = tensor_row_major.extract_volume_patches(1, 1, 1); 24 VERIFY_IS_EQUAL(single_voxel_patch_row_major.dimension(0), 7); 25 VERIFY_IS_EQUAL(single_voxel_patch_row_major.dimension(1), 2 * 3 * 5); 26 VERIFY_IS_EQUAL(single_voxel_patch_row_major.dimension(2), 1); 27 VERIFY_IS_EQUAL(single_voxel_patch_row_major.dimension(3), 1); 28 VERIFY_IS_EQUAL(single_voxel_patch_row_major.dimension(4), 1); 29 VERIFY_IS_EQUAL(single_voxel_patch_row_major.dimension(5), 4); 30 31 for (int i = 0; i < tensor.size(); ++i) { 32 VERIFY_IS_EQUAL(tensor.data()[i], single_voxel_patch.data()[i]); 33 VERIFY_IS_EQUAL(tensor_row_major.data()[i], single_voxel_patch_row_major.data()[i]); 34 VERIFY_IS_EQUAL(tensor.data()[i], tensor_row_major.data()[i]); 35 } 36 } 37 38 39 static void test_entire_volume_patch() 40 { 41 const int depth = 4; 42 const int patch_z = 2; 43 const int patch_y = 3; 44 const int patch_x = 5; 45 const int batch = 7; 46 47 Tensor<float, 5> tensor(depth, patch_z, patch_y, patch_x, batch); 48 tensor.setRandom(); 49 Tensor<float, 5, RowMajor> tensor_row_major = tensor.swap_layout(); 50 51 Tensor<float, 6> entire_volume_patch; 52 entire_volume_patch = tensor.extract_volume_patches(patch_z, patch_y, patch_x); 53 VERIFY_IS_EQUAL(entire_volume_patch.dimension(0), depth); 54 VERIFY_IS_EQUAL(entire_volume_patch.dimension(1), patch_z); 55 VERIFY_IS_EQUAL(entire_volume_patch.dimension(2), patch_y); 56 VERIFY_IS_EQUAL(entire_volume_patch.dimension(3), patch_x); 57 VERIFY_IS_EQUAL(entire_volume_patch.dimension(4), patch_z * patch_y * patch_x); 58 VERIFY_IS_EQUAL(entire_volume_patch.dimension(5), batch); 59 60 Tensor<float, 6, RowMajor> entire_volume_patch_row_major; 61 entire_volume_patch_row_major = tensor_row_major.extract_volume_patches(patch_z, patch_y, patch_x); 62 VERIFY_IS_EQUAL(entire_volume_patch_row_major.dimension(0), batch); 63 VERIFY_IS_EQUAL(entire_volume_patch_row_major.dimension(1), patch_z * patch_y * patch_x); 64 VERIFY_IS_EQUAL(entire_volume_patch_row_major.dimension(2), patch_x); 65 VERIFY_IS_EQUAL(entire_volume_patch_row_major.dimension(3), patch_y); 66 VERIFY_IS_EQUAL(entire_volume_patch_row_major.dimension(4), patch_z); 67 VERIFY_IS_EQUAL(entire_volume_patch_row_major.dimension(5), depth); 68 69 const int dz = patch_z - 1; 70 const int dy = patch_y - 1; 71 const int dx = patch_x - 1; 72 73 const int forward_pad_z = dz - dz / 2; 74 const int forward_pad_y = dy - dy / 2; 75 const int forward_pad_x = dx - dx / 2; 76 77 for (int pz = 0; pz < patch_z; pz++) { 78 for (int py = 0; py < patch_y; py++) { 79 for (int px = 0; px < patch_x; px++) { 80 const int patchId = pz + patch_z * (py + px * patch_y); 81 for (int z = 0; z < patch_z; z++) { 82 for (int y = 0; y < patch_y; y++) { 83 for (int x = 0; x < patch_x; x++) { 84 for (int b = 0; b < batch; b++) { 85 for (int d = 0; d < depth; d++) { 86 float expected = 0.0f; 87 float expected_row_major = 0.0f; 88 const int eff_z = z - forward_pad_z + pz; 89 const int eff_y = y - forward_pad_y + py; 90 const int eff_x = x - forward_pad_x + px; 91 if (eff_z >= 0 && eff_y >= 0 && eff_x >= 0 && 92 eff_z < patch_z && eff_y < patch_y && eff_x < patch_x) { 93 expected = tensor(d, eff_z, eff_y, eff_x, b); 94 expected_row_major = tensor_row_major(b, eff_x, eff_y, eff_z, d); 95 } 96 VERIFY_IS_EQUAL(entire_volume_patch(d, z, y, x, patchId, b), expected); 97 VERIFY_IS_EQUAL(entire_volume_patch_row_major(b, patchId, x, y, z, d), expected_row_major); 98 } 99 } 100 } 101 } 102 } 103 } 104 } 105 } 106 } 107 108 void test_cxx11_tensor_volume_patch() 109 { 110 CALL_SUBTEST(test_single_voxel_patch()); 111 CALL_SUBTEST(test_entire_volume_patch()); 112 } 113