1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <gtest/gtest.h>
16 #include "tensorflow/contrib/lite/interpreter.h"
17 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
18 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
19 #include "tensorflow/contrib/lite/kernels/register.h"
20 #include "tensorflow/contrib/lite/kernels/test_util.h"
21 #include "tensorflow/contrib/lite/model.h"
22
23 namespace tflite {
24 namespace {
25
26 using ::testing::ElementsAreArray;
27
RunTestPermutation(const std::vector<int> & shape,const std::vector<int> & perms,std::vector<float> * input_transposed)28 void RunTestPermutation(const std::vector<int>& shape,
29 const std::vector<int>& perms,
30 std::vector<float>* input_transposed) {
31 // Count elements and allocate output.
32 int count = 1;
33 for (auto factor : shape) count *= factor;
34 input_transposed->resize(count);
35
36 // Create the dummy data
37 std::vector<float> input(count);
38 for (int i = 0; i < input.size(); i++) {
39 input[i] = i;
40 }
41
42 // Create reversed and padded perms.
43 int reversed_perms[4];
44 for (int output_k = 0, input_k = shape.size() - 1; output_k < shape.size();
45 output_k++, input_k--) {
46 reversed_perms[output_k] = shape.size() - perms[input_k] - 1;
47 }
48 // Unused dimensions should not be permuted so pad with identity transform
49 // subset.
50 for (int k = shape.size(); k < 4; k++) {
51 reversed_perms[k] = k;
52 }
53
54 // Make input and output dims (i.e. reversed shape and dest_shape).
55 Dims<4> input_dims = GetTensorDims(shape);
56 Dims<4> output_dims;
57 for (int i = 0; i < 4; i++) {
58 output_dims.sizes[i] = input_dims.sizes[reversed_perms[i]];
59 }
60 output_dims.strides[0] = 1;
61 for (int k = 1; k < 4; k++) {
62 output_dims.strides[k] =
63 output_dims.strides[k - 1] * output_dims.sizes[k - 1];
64 }
65
66 reference_ops::Transpose<float>(input.data(), input_dims,
67 input_transposed->data(), output_dims,
68 reversed_perms);
69 }
70
// Reference kernel: a 3-element rank-1 tensor under the identity permutation
// comes back unchanged.
TEST(TransposeTest, TestRefOps1D) {
  const std::vector<float> expected({0, 1, 2});
  std::vector<float> actual;
  RunTestPermutation({3}, {0}, &actual);
  ASSERT_EQ(actual, expected);
}
77
// Reference kernel on a 3x2 tensor: axis swap and identity permutation.
TEST(TransposeTest, TestRefOps2D) {
  std::vector<float> actual;

  // Basic 2D: swap the two axes.
  const std::vector<float> swapped({0, 2, 4, 1, 3, 5});
  RunTestPermutation({3, 2}, {1, 0}, &actual);
  ASSERT_EQ(actual, swapped);

  // Identity: the data keeps its original order.
  const std::vector<float> unchanged({0, 1, 2, 3, 4, 5});
  RunTestPermutation({3, 2}, {0, 1}, &actual);
  ASSERT_EQ(actual, unchanged);
}
87
// Reference kernel on a 2x3x4 tensor: a non-trivial permutation followed by
// the identity permutation.
TEST(TransposeTest, TestRefOps3D) {
  std::vector<float> actual;

  // Permutation {2, 0, 1}.
  const std::vector<float> permuted({0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,
                                     2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23});
  RunTestPermutation({2, 3, 4}, {2, 0, 1}, &actual);
  ASSERT_EQ(actual, permuted);

  // Identity transform: the expected output is simply 0, 1, 2, ...
  RunTestPermutation({2, 3, 4}, {0, 1, 2}, &actual);
  std::vector<float> sequential(actual.size());
  float next_value = 0;
  for (float& element : sequential) element = next_value++;
  ASSERT_EQ(actual, sequential);
}
105
// Reference kernel on a 2x3x4x5 tensor: permutation {2, 0, 1, 3} followed by
// the identity permutation.
TEST(TransposeTest, TestRefOps4D) {
  std::vector<float> actual;

  // Basic 4d.
  const std::vector<float> permuted(
      {0,  1,  2,  3,  4,  20, 21, 22, 23, 24, 40,  41,  42,  43,  44,
       60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
       5,  6,  7,  8,  9,  25, 26, 27, 28, 29, 45,  46,  47,  48,  49,
       65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
       10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50,  51,  52,  53,  54,
       70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
       15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55,  56,  57,  58,  59,
       75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119});
  RunTestPermutation({2, 3, 4, 5}, {2, 0, 1, 3}, &actual);
  ASSERT_EQ(actual, permuted);

  // Basic identity: the expected output is simply 0, 1, 2, ...
  RunTestPermutation({2, 3, 4, 5}, {0, 1, 2, 3}, &actual);
  std::vector<float> sequential(actual.size());
  float next_value = 0;
  for (float& element : sequential) element = next_value++;
  ASSERT_EQ(actual, sequential);
}
127
128 class TransposeOpModel : public SingleOpModel {
129 public:
SetInput(std::initializer_list<float> data)130 void SetInput(std::initializer_list<float> data) {
131 PopulateTensor<float>(input_, data);
132 }
133
SetPerm(std::initializer_list<int> data)134 void SetPerm(std::initializer_list<int> data) {
135 PopulateTensor<int>(perm_, data);
136 }
137
GetOutput()138 std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()139 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
140
141 protected:
142 int input_;
143 int perm_;
144 int output_;
145 };
146
147 // Tests case where perm is a const tensor.
148 //
149 // Example usage is as follows:
150 // SpaceToBatchNDOpConstModel m(input_shape, perm_shape, perm_data);
151 // m.SetInput(input_data);
152 // m.Invoke();
153 class TransposeOpConstModel : public TransposeOpModel {
154 public:
TransposeOpConstModel(std::initializer_list<int> input_shape,std::initializer_list<int> perm_shape,std::initializer_list<int> perm)155 TransposeOpConstModel(std::initializer_list<int> input_shape,
156 std::initializer_list<int> perm_shape,
157 std::initializer_list<int> perm) {
158 input_ = AddInput(TensorType_FLOAT32);
159 perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
160 output_ = AddOutput(TensorType_FLOAT32);
161 SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
162 CreateTransposeOptions(builder_).Union());
163 BuildInterpreter({input_shape});
164 }
165 };
166
167 // Tests case where perm is a non-const tensor.
168 //
169 // Example usage is as follows:
170 // TransposeOpDynamicModel m(input_shape, perm_shape);
171 // m.SetInput(input_data);
172 // m.SetPerm(perm_data);
173 // m.Invoke();
174 class TransposeOpDynamicModel : public TransposeOpModel {
175 public:
TransposeOpDynamicModel(std::initializer_list<int> input_shape,std::initializer_list<int> perm_shape)176 TransposeOpDynamicModel(std::initializer_list<int> input_shape,
177 std::initializer_list<int> perm_shape) {
178 input_ = AddInput(TensorType_FLOAT32);
179 perm_ = AddInput(TensorType_INT32);
180 output_ = AddOutput(TensorType_FLOAT32);
181 SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
182 CreateTransposeOptions(builder_).Union());
183 BuildInterpreter({input_shape, perm_shape});
184 }
185 };
186
// Building a model whose perm tensor has 2 entries for a 4-D input must die
// with a message matching "2 != 4".
TEST(TransposeTest, TestUnequalPermSize) {
  const char* const expected_message = "2 != 4";
  EXPECT_DEATH(TransposeOpConstModel({1, 3, 3, 1}, {2}, {2, 2}),
               expected_message);
}
190
// Permutation entries outside the valid axis range of a 4-D input (negative
// values, or the value 4) must make model construction die.
TEST(TransposeTest, TestPermOutOfBounds) {
  const char* const expected_message =
      "Transpose op permutations array is out of bounds.";
  // Negative axis indices.
  EXPECT_DEATH(TransposeOpConstModel({1, 3, 3, 1}, {4}, {0, -1, -2, -3}),
               expected_message);
  // Axis index equal to the number of dimensions.
  EXPECT_DEATH(TransposeOpConstModel({1, 3, 3, 1}, {4}, {0, 1, 2, 4}),
               expected_message);
}
197
// 1-D input with a constant identity perm: shape and data are unchanged.
TEST(TransposeTest, Test1DInputConstTensor) {
  TransposeOpConstModel model(/*input_shape=*/{3}, /*perm_shape=*/{1},
                              /*perm=*/{0});
  model.SetInput({1, 2, 3});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3}));
}
205
// 1-D input with identity perm supplied at runtime: shape and data are
// unchanged.
TEST(TransposeTest, Test1DInputDynamicTensor) {
  TransposeOpDynamicModel model(/*input_shape=*/{3}, /*perm_shape=*/{1});
  model.SetInput({1, 2, 3});
  model.SetPerm({0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3}));
}
214
// 3x2 input with constant perm {1, 0}: expects a 2x3 output with the axes
// swapped.
TEST(TransposeTest, Test2DInputConstTensor) {
  TransposeOpConstModel model(/*input_shape=*/{3, 2}, /*perm_shape=*/{2},
                              /*perm=*/{1, 0});
  model.SetInput({0, 1, 2, 3, 4, 5});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5}));
}
222
// 3x2 input with perm {1, 0} supplied at runtime: expects a 2x3 output with
// the axes swapped.
TEST(TransposeTest, Test2DInputDynamicTensor) {
  TransposeOpDynamicModel model(/*input_shape=*/{3, 2}, /*perm_shape=*/{2});
  model.SetInput({0, 1, 2, 3, 4, 5});
  model.SetPerm({1, 0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5}));
}
231
// 2x3x4 input with constant perm {2, 0, 1}: expects a 4x2x3 output.
TEST(TransposeTest, Test3DInputConstTensor) {
  TransposeOpConstModel model(/*input_shape=*/{2, 3, 4}, /*perm_shape=*/{3},
                              /*perm=*/{2, 0, 1});
  model.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 3}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,
                                2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
}
242
// 2x3x4 input with perm {2, 0, 1} supplied at runtime: expects a 4x2x3
// output.
TEST(TransposeTest, Test3DInputDynamicTensor) {
  TransposeOpDynamicModel model(/*input_shape=*/{2, 3, 4},
                                /*perm_shape=*/{3});
  model.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  model.SetPerm({2, 0, 1});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 3}));
  EXPECT_THAT(model.GetOutput(),
              ElementsAreArray({0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,
                                2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
}
254
// A 5-D input must make model construction die with the rank-limit message.
TEST(TransposeTest, Test5DInputTensor) {
  const char* const expected_message =
      "Transpose op only supports 1D-4D input arrays.";
  EXPECT_DEATH(TransposeOpConstModel({1, 2, 3, 4, 5}, {5}, {0, 1, 2, 3, 4}),
               expected_message);
}
259
// 4-D input with constant identity perm: shape and data are unchanged.
TEST(TransposeTest, SimpleTestNoReorderConstTensor) {
  TransposeOpConstModel model(/*input_shape=*/{1, 2, 3, 1},
                              /*perm_shape=*/{4}, /*perm=*/{0, 1, 2, 3});
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
267
// 4-D input with identity perm supplied at runtime: shape and data are
// unchanged.
TEST(TransposeTest, SimpleTestNoReorderDynamicTensor) {
  TransposeOpDynamicModel model(/*input_shape=*/{1, 2, 3, 1},
                                /*perm_shape=*/{4});
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.SetPerm({0, 1, 2, 3});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
276
// 1x2x3x1 input with constant perm {2, 1, 3, 0}: expects a 3x2x1x1 output.
TEST(TransposeTest, SimpleTestWithReorderConstTensor) {
  TransposeOpConstModel model(/*input_shape=*/{1, 2, 3, 1},
                              /*perm_shape=*/{4}, /*perm=*/{2, 1, 3, 0});
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 2, 1, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
}
284
// 2x3x4x5 input (values 0..119) with constant perm {2, 0, 1, 3}: expects a
// 4x2x3x5 output.
TEST(TransposeTest, ComplexTestWithReorderConstTensor) {
  TransposeOpConstModel model(/*input_shape=*/{2, 3, 4, 5},
                              /*perm_shape=*/{4}, /*perm=*/{2, 0, 1, 3});
  model.SetInput({0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,
                  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,
                  24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,
                  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
                  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
                  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
                  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
                  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,
                  96,  97,  98,  99,  100, 101, 102, 103, 104, 105, 106, 107,
                  108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 3, 5}));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray(
          {0,  1,  2,  3,  4,  20, 21, 22, 23, 24, 40,  41,  42,  43,  44,
           60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
           5,  6,  7,  8,  9,  25, 26, 27, 28, 29, 45,  46,  47,  48,  49,
           65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
           10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50,  51,  52,  53,  54,
           70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
           15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55,  56,  57,  58,  59,
           75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119}));
}
311
// 2x3x4x5 input (values 0..119) with perm {2, 0, 1, 3} supplied at runtime:
// expects a 4x2x3x5 output.
TEST(TransposeTest, ComplexTestWithReorderDynamicTensor) {
  TransposeOpDynamicModel model(/*input_shape=*/{2, 3, 4, 5},
                                /*perm_shape=*/{4});
  model.SetInput({0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,
                  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,
                  24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,
                  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
                  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
                  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
                  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
                  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,
                  96,  97,  98,  99,  100, 101, 102, 103, 104, 105, 106, 107,
                  108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
  model.SetPerm({2, 0, 1, 3});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 3, 5}));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray(
          {0,  1,  2,  3,  4,  20, 21, 22, 23, 24, 40,  41,  42,  43,  44,
           60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
           5,  6,  7,  8,  9,  25, 26, 27, 28, 29, 45,  46,  47,  48,  49,
           65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
           10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50,  51,  52,  53,  54,
           70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
           15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55,  56,  57,  58,  59,
           75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119}));
}
339
340 } // namespace
341 } // namespace tflite
342
main(int argc,char ** argv)343 int main(int argc, char** argv) {
344 ::tflite::LogToStderr();
345 ::testing::InitGoogleTest(&argc, argv);
346 return RUN_ALL_TESTS();
347 }
348