1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2015 Eugene Brevdo <ebrevdo@google.com> 5 // Benoit Steiner <benoit.steiner.goog@gmail.com> 6 // 7 // This Source Code Form is subject to the terms of the Mozilla 8 // Public License v. 2.0. If a copy of the MPL was not distributed 9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 10 11 #include "main.h" 12 13 #include <Eigen/CXX11/Tensor> 14 15 using Eigen::Tensor; 16 using Eigen::array; 17 using Eigen::Tuple; 18 19 template <int DataLayout> 20 static void test_simple_index_tuples() 21 { 22 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 23 tensor.setRandom(); 24 tensor = (tensor + tensor.constant(0.5)).log(); 25 26 Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7); 27 index_tuples = tensor.index_tuples(); 28 29 for (DenseIndex n = 0; n < 2*3*5*7; ++n) { 30 const Tuple<DenseIndex, float>& v = index_tuples.coeff(n); 31 VERIFY_IS_EQUAL(v.first, n); 32 VERIFY_IS_EQUAL(v.second, tensor.coeff(n)); 33 } 34 } 35 36 template <int DataLayout> 37 static void test_index_tuples_dim() 38 { 39 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 40 tensor.setRandom(); 41 tensor = (tensor + tensor.constant(0.5)).log(); 42 43 Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7); 44 45 index_tuples = tensor.index_tuples(); 46 47 for (Eigen::DenseIndex n = 0; n < tensor.size(); ++n) { 48 const Tuple<DenseIndex, float>& v = index_tuples(n); //(i, j, k, l); 49 VERIFY_IS_EQUAL(v.first, n); 50 VERIFY_IS_EQUAL(v.second, tensor(n)); 51 } 52 } 53 54 template <int DataLayout> 55 static void test_argmax_tuple_reducer() 56 { 57 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 58 tensor.setRandom(); 59 tensor = (tensor + tensor.constant(0.5)).log(); 60 61 Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7); 62 index_tuples = tensor.index_tuples(); 63 64 Tensor<Tuple<DenseIndex, float>, 0, DataLayout> reduced; 65 DimensionList<DenseIndex, 4> dims; 66 reduced = index_tuples.reduce( 67 dims, internal::ArgMaxTupleReducer<Tuple<DenseIndex, float> >()); 68 69 Tensor<float, 0, DataLayout> maxi = tensor.maximum(); 70 71 VERIFY_IS_EQUAL(maxi(), reduced(0).second); 72 73 array<DenseIndex, 3> reduce_dims; 74 for (int d = 0; d < 3; ++d) reduce_dims[d] = d; 75 Tensor<Tuple<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7); 76 reduced_by_dims = index_tuples.reduce( 77 reduce_dims, internal::ArgMaxTupleReducer<Tuple<DenseIndex, float> >()); 78 79 Tensor<float, 1, DataLayout> max_by_dims = tensor.maximum(reduce_dims); 80 81 for (int l = 0; l < 7; ++l) { 82 VERIFY_IS_EQUAL(max_by_dims(l), reduced_by_dims(l).second); 83 } 84 } 85 86 template <int DataLayout> 87 static void test_argmin_tuple_reducer() 88 { 89 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 90 tensor.setRandom(); 91 tensor = (tensor + tensor.constant(0.5)).log(); 92 93 Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7); 94 index_tuples = tensor.index_tuples(); 95 96 Tensor<Tuple<DenseIndex, float>, 0, DataLayout> reduced; 97 DimensionList<DenseIndex, 4> dims; 98 reduced = index_tuples.reduce( 99 dims, internal::ArgMinTupleReducer<Tuple<DenseIndex, float> >()); 100 101 Tensor<float, 0, DataLayout> mini = tensor.minimum(); 102 103 VERIFY_IS_EQUAL(mini(), reduced(0).second); 104 105 array<DenseIndex, 3> reduce_dims; 106 for (int d = 0; d < 3; ++d) reduce_dims[d] = d; 107 Tensor<Tuple<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7); 108 reduced_by_dims = index_tuples.reduce( 109 reduce_dims, internal::ArgMinTupleReducer<Tuple<DenseIndex, float> >()); 110 111 Tensor<float, 1, DataLayout> min_by_dims = tensor.minimum(reduce_dims); 112 113 for (int l = 0; l < 7; ++l) { 114 VERIFY_IS_EQUAL(min_by_dims(l), reduced_by_dims(l).second); 115 } 116 } 117 118 template <int DataLayout> 119 static void test_simple_argmax() 120 { 121 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 122 tensor.setRandom(); 123 tensor = (tensor + tensor.constant(0.5)).log(); 124 tensor(0,0,0,0) = 10.0; 125 126 Tensor<DenseIndex, 0, DataLayout> tensor_argmax; 127 128 tensor_argmax = tensor.argmax(); 129 130 VERIFY_IS_EQUAL(tensor_argmax(0), 0); 131 132 tensor(1,2,4,6) = 20.0; 133 134 tensor_argmax = tensor.argmax(); 135 136 VERIFY_IS_EQUAL(tensor_argmax(0), 2*3*5*7 - 1); 137 } 138 139 template <int DataLayout> 140 static void test_simple_argmin() 141 { 142 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 143 tensor.setRandom(); 144 tensor = (tensor + tensor.constant(0.5)).log(); 145 tensor(0,0,0,0) = -10.0; 146 147 Tensor<DenseIndex, 0, DataLayout> tensor_argmin; 148 149 tensor_argmin = tensor.argmin(); 150 151 VERIFY_IS_EQUAL(tensor_argmin(0), 0); 152 153 tensor(1,2,4,6) = -20.0; 154 155 tensor_argmin = tensor.argmin(); 156 157 VERIFY_IS_EQUAL(tensor_argmin(0), 2*3*5*7 - 1); 158 } 159 160 template <int DataLayout> 161 static void test_argmax_dim() 162 { 163 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 164 std::vector<int> dims {2, 3, 5, 7}; 165 166 for (int dim = 0; dim < 4; ++dim) { 167 tensor.setRandom(); 168 tensor = (tensor + tensor.constant(0.5)).log(); 169 170 Tensor<DenseIndex, 3, DataLayout> tensor_argmax; 171 array<DenseIndex, 4> ix; 172 for (int i = 0; i < 2; ++i) { 173 for (int j = 0; j < 3; ++j) { 174 for (int k = 0; k < 5; ++k) { 175 for (int l = 0; l < 7; ++l) { 176 ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l; 177 if (ix[dim] != 0) continue; 178 // suppose dim == 1, then for all i, k, l, set tensor(i, 0, k, l) = 10.0 179 tensor(ix) = 10.0; 180 } 181 } 182 } 183 } 184 185 tensor_argmax = tensor.argmax(dim); 186 187 VERIFY_IS_EQUAL(tensor_argmax.size(), 188 ptrdiff_t(2*3*5*7 / tensor.dimension(dim))); 189 for (ptrdiff_t n = 0; n < tensor_argmax.size(); ++n) { 190 // Expect max to be in the first index of the reduced dimension 191 VERIFY_IS_EQUAL(tensor_argmax.data()[n], 0); 192 } 193 194 for (int i = 0; i < 2; ++i) { 195 for (int j = 0; j < 3; ++j) { 196 for (int k = 0; k < 5; ++k) { 197 for (int l = 0; l < 7; ++l) { 198 ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l; 199 if (ix[dim] != tensor.dimension(dim) - 1) continue; 200 // suppose dim == 1, then for all i, k, l, set tensor(i, 2, k, l) = 20.0 201 tensor(ix) = 20.0; 202 } 203 } 204 } 205 } 206 207 tensor_argmax = tensor.argmax(dim); 208 209 VERIFY_IS_EQUAL(tensor_argmax.size(), 210 ptrdiff_t(2*3*5*7 / tensor.dimension(dim))); 211 for (ptrdiff_t n = 0; n < tensor_argmax.size(); ++n) { 212 // Expect max to be in the last index of the reduced dimension 213 VERIFY_IS_EQUAL(tensor_argmax.data()[n], tensor.dimension(dim) - 1); 214 } 215 } 216 } 217 218 template <int DataLayout> 219 static void test_argmin_dim() 220 { 221 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 222 std::vector<int> dims {2, 3, 5, 7}; 223 224 for (int dim = 0; dim < 4; ++dim) { 225 tensor.setRandom(); 226 tensor = (tensor + tensor.constant(0.5)).log(); 227 228 Tensor<DenseIndex, 3, DataLayout> tensor_argmin; 229 array<DenseIndex, 4> ix; 230 for (int i = 0; i < 2; ++i) { 231 for (int j = 0; j < 3; ++j) { 232 for (int k = 0; k < 5; ++k) { 233 for (int l = 0; l < 7; ++l) { 234 ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l; 235 if (ix[dim] != 0) continue; 236 // suppose dim == 1, then for all i, k, l, set tensor(i, 0, k, l) = -10.0 237 tensor(ix) = -10.0; 238 } 239 } 240 } 241 } 242 243 tensor_argmin = tensor.argmin(dim); 244 245 VERIFY_IS_EQUAL(tensor_argmin.size(), 246 ptrdiff_t(2*3*5*7 / tensor.dimension(dim))); 247 for (ptrdiff_t n = 0; n < tensor_argmin.size(); ++n) { 248 // Expect min to be in the first index of the reduced dimension 249 VERIFY_IS_EQUAL(tensor_argmin.data()[n], 0); 250 } 251 252 for (int i = 0; i < 2; ++i) { 253 for (int j = 0; j < 3; ++j) { 254 for (int k = 0; k < 5; ++k) { 255 for (int l = 0; l < 7; ++l) { 256 ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l; 257 if (ix[dim] != tensor.dimension(dim) - 1) continue; 258 // suppose dim == 1, then for all i, k, l, set tensor(i, 2, k, l) = -20.0 259 tensor(ix) = -20.0; 260 } 261 } 262 } 263 } 264 265 tensor_argmin = tensor.argmin(dim); 266 267 VERIFY_IS_EQUAL(tensor_argmin.size(), 268 ptrdiff_t(2*3*5*7 / tensor.dimension(dim))); 269 for (ptrdiff_t n = 0; n < tensor_argmin.size(); ++n) { 270 // Expect min to be in the last index of the reduced dimension 271 VERIFY_IS_EQUAL(tensor_argmin.data()[n], tensor.dimension(dim) - 1); 272 } 273 } 274 } 275 276 void test_cxx11_tensor_argmax() 277 { 278 CALL_SUBTEST(test_simple_index_tuples<RowMajor>()); 279 CALL_SUBTEST(test_simple_index_tuples<ColMajor>()); 280 CALL_SUBTEST(test_index_tuples_dim<RowMajor>()); 281 CALL_SUBTEST(test_index_tuples_dim<ColMajor>()); 282 CALL_SUBTEST(test_argmax_tuple_reducer<RowMajor>()); 283 CALL_SUBTEST(test_argmax_tuple_reducer<ColMajor>()); 284 CALL_SUBTEST(test_argmin_tuple_reducer<RowMajor>()); 285 CALL_SUBTEST(test_argmin_tuple_reducer<ColMajor>()); 286 CALL_SUBTEST(test_simple_argmax<RowMajor>()); 287 CALL_SUBTEST(test_simple_argmax<ColMajor>()); 288 CALL_SUBTEST(test_simple_argmin<RowMajor>()); 289 CALL_SUBTEST(test_simple_argmin<ColMajor>()); 290 CALL_SUBTEST(test_argmax_dim<RowMajor>()); 291 CALL_SUBTEST(test_argmax_dim<ColMajor>()); 292 CALL_SUBTEST(test_argmin_dim<RowMajor>()); 293 CALL_SUBTEST(test_argmin_dim<ColMajor>()); 294 } 295