1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <gtest/gtest.h>
16 #include "tensorflow/contrib/lite/interpreter.h"
17 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
18 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
19 #include "tensorflow/contrib/lite/kernels/register.h"
20 #include "tensorflow/contrib/lite/kernels/test_util.h"
21 #include "tensorflow/contrib/lite/model.h"
22 
23 namespace tflite {
24 namespace {
25 
26 using ::testing::ElementsAreArray;
27 
RunTestPermutation(const std::vector<int> & shape,const std::vector<int> & perms,std::vector<float> * input_transposed)28 void RunTestPermutation(const std::vector<int>& shape,
29                         const std::vector<int>& perms,
30                         std::vector<float>* input_transposed) {
31   // Count elements and allocate output.
32   int count = 1;
33   for (auto factor : shape) count *= factor;
34   input_transposed->resize(count);
35 
36   // Create the dummy data
37   std::vector<float> input(count);
38   for (int i = 0; i < input.size(); i++) {
39     input[i] = i;
40   }
41 
42   // Create reversed and padded perms.
43   int reversed_perms[4];
44   for (int output_k = 0, input_k = shape.size() - 1; output_k < shape.size();
45        output_k++, input_k--) {
46     reversed_perms[output_k] = shape.size() - perms[input_k] - 1;
47   }
48   // Unused dimensions should not be permuted so pad with identity transform
49   // subset.
50   for (int k = shape.size(); k < 4; k++) {
51     reversed_perms[k] = k;
52   }
53 
54   // Make input and output dims (i.e. reversed shape and dest_shape).
55   Dims<4> input_dims = GetTensorDims(shape);
56   Dims<4> output_dims;
57   for (int i = 0; i < 4; i++) {
58     output_dims.sizes[i] = input_dims.sizes[reversed_perms[i]];
59   }
60   output_dims.strides[0] = 1;
61   for (int k = 1; k < 4; k++) {
62     output_dims.strides[k] =
63         output_dims.strides[k - 1] * output_dims.sizes[k - 1];
64   }
65 
66   reference_ops::Transpose<float>(input.data(), input_dims,
67                                   input_transposed->data(), output_dims,
68                                   reversed_perms);
69 }
70 
TEST(TransposeTest, TestRefOps1D) {
  // A rank-1 tensor transposed with the identity permutation is unchanged.
  std::vector<float> transposed;
  RunTestPermutation({3}, {0}, &transposed);
  const std::vector<float> expected = {0, 1, 2};
  ASSERT_EQ(transposed, expected);
}
77 
TEST(TransposeTest, TestRefOps2D) {
  std::vector<float> transposed;

  // Swapping the two axes of a 3x2 matrix.
  RunTestPermutation({3, 2}, {1, 0}, &transposed);
  const std::vector<float> swapped = {0, 2, 4, 1, 3, 5};
  ASSERT_EQ(transposed, swapped);

  // The identity permutation leaves the data in place.
  RunTestPermutation({3, 2}, {0, 1}, &transposed);
  const std::vector<float> unchanged = {0, 1, 2, 3, 4, 5};
  ASSERT_EQ(transposed, unchanged);
}
87 
TEST(TransposeTest, TestRefOps3D) {
  std::vector<float> out;
  // Rank-3 transpose of a {2, 3, 4} tensor with perm {2, 0, 1}.
  {
    const std::vector<float> ref({0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,
                                  2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23});
    RunTestPermutation({2, 3, 4}, {2, 0, 1}, &out);
    ASSERT_EQ(out, ref);
  }
  // Rank-3 identity transpose: the output equals the 0..N-1 input sequence.
  {
    RunTestPermutation({2, 3, 4}, {0, 1, 2}, &out);
    std::vector<float> ref(out.size());
    // size_t index avoids a signed/unsigned comparison against size().
    for (size_t k = 0; k < ref.size(); ++k) ref[k] = static_cast<float>(k);
    ASSERT_EQ(out, ref);
  }
}
105 
TEST(TransposeTest, TestRefOps4D) {
  std::vector<float> out;
  // Basic 4D: {2, 3, 4, 5} transposed with perm {2, 0, 1, 3}.
  RunTestPermutation({2, 3, 4, 5}, {2, 0, 1, 3}, &out);
  ASSERT_EQ(
      out,
      std::vector<float>(
          {0,  1,  2,  3,  4,  20, 21, 22, 23, 24, 40,  41,  42,  43,  44,
           60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
           5,  6,  7,  8,  9,  25, 26, 27, 28, 29, 45,  46,  47,  48,  49,
           65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
           10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50,  51,  52,  53,  54,
           70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
           15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55,  56,  57,  58,  59,
           75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119}));
  // Basic identity: the output equals the 0..N-1 input sequence.
  RunTestPermutation({2, 3, 4, 5}, {0, 1, 2, 3}, &out);
  std::vector<float> ref(out.size());
  // size_t index avoids a signed/unsigned comparison against size().
  for (size_t k = 0; k < ref.size(); ++k) ref[k] = static_cast<float>(k);
  ASSERT_EQ(out, ref);
}
127 
128 class TransposeOpModel : public SingleOpModel {
129  public:
SetInput(std::initializer_list<float> data)130   void SetInput(std::initializer_list<float> data) {
131     PopulateTensor<float>(input_, data);
132   }
133 
SetPerm(std::initializer_list<int> data)134   void SetPerm(std::initializer_list<int> data) {
135     PopulateTensor<int>(perm_, data);
136   }
137 
GetOutput()138   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()139   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
140 
141  protected:
142   int input_;
143   int perm_;
144   int output_;
145 };
146 
147 // Tests case where perm is a const tensor.
148 //
149 // Example usage is as follows:
//    TransposeOpConstModel m(input_shape, perm_shape, perm_data);
151 //    m.SetInput(input_data);
152 //    m.Invoke();
153 class TransposeOpConstModel : public TransposeOpModel {
154  public:
TransposeOpConstModel(std::initializer_list<int> input_shape,std::initializer_list<int> perm_shape,std::initializer_list<int> perm)155   TransposeOpConstModel(std::initializer_list<int> input_shape,
156                         std::initializer_list<int> perm_shape,
157                         std::initializer_list<int> perm) {
158     input_ = AddInput(TensorType_FLOAT32);
159     perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
160     output_ = AddOutput(TensorType_FLOAT32);
161     SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
162                  CreateTransposeOptions(builder_).Union());
163     BuildInterpreter({input_shape});
164   }
165 };
166 
167 // Tests case where perm is a non-const tensor.
168 //
169 // Example usage is as follows:
170 //    TransposeOpDynamicModel m(input_shape, perm_shape);
171 //    m.SetInput(input_data);
172 //    m.SetPerm(perm_data);
173 //    m.Invoke();
174 class TransposeOpDynamicModel : public TransposeOpModel {
175  public:
TransposeOpDynamicModel(std::initializer_list<int> input_shape,std::initializer_list<int> perm_shape)176   TransposeOpDynamicModel(std::initializer_list<int> input_shape,
177                           std::initializer_list<int> perm_shape) {
178     input_ = AddInput(TensorType_FLOAT32);
179     perm_ = AddInput(TensorType_INT32);
180     output_ = AddOutput(TensorType_FLOAT32);
181     SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
182                  CreateTransposeOptions(builder_).Union());
183     BuildInterpreter({input_shape, perm_shape});
184   }
185 };
186 
TEST(TransposeTest, TestUnequalPermSize) {
  // A 2-element permutation for a rank-4 input must abort during model build.
  const char* const kSizeMismatch = "2 != 4";
  EXPECT_DEATH(TransposeOpConstModel({1, 3, 3, 1}, {2}, {2, 2}),
               kSizeMismatch);
}
190 
TEST(TransposeTest, TestPermOutOfBounds) {
  const char* const kOutOfBounds =
      "Transpose op permutations array is out of bounds.";
  // Negative axis indices are expected to abort...
  EXPECT_DEATH(TransposeOpConstModel({1, 3, 3, 1}, {4}, {0, -1, -2, -3}),
               kOutOfBounds);
  // ...and so is an axis index equal to the rank.
  EXPECT_DEATH(TransposeOpConstModel({1, 3, 3, 1}, {4}, {0, 1, 2, 4}),
               kOutOfBounds);
}
197 
TEST(TransposeTest, Test1DInputConstTensor) {
  // Identity transpose of a 1D tensor with a constant permutation.
  TransposeOpConstModel model({3}, {1}, {0});
  model.SetInput({1, 2, 3});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3}));
}
205 
TEST(TransposeTest, Test1DInputDynamicTensor) {
  // Identity transpose of a 1D tensor with the permutation fed at runtime.
  TransposeOpDynamicModel model({3}, {1});
  model.SetInput({1, 2, 3});
  model.SetPerm({0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3}));
}
214 
TEST(TransposeTest, Test2DInputConstTensor) {
  // Matrix transpose {3, 2} -> {2, 3} with a constant permutation.
  TransposeOpConstModel model({3, 2}, {2}, {1, 0});
  model.SetInput({0, 1, 2, 3, 4, 5});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5}));
}
222 
TEST(TransposeTest, Test2DInputDynamicTensor) {
  // Matrix transpose {3, 2} -> {2, 3} with the permutation fed at runtime.
  TransposeOpDynamicModel model({3, 2}, {2});
  model.SetInput({0, 1, 2, 3, 4, 5});
  model.SetPerm({1, 0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5}));
}
231 
TEST(TransposeTest, Test3DInputConstTensor) {
  // {2, 3, 4} transposed with constant perm {2, 0, 1} -> {4, 2, 3}.
  TransposeOpConstModel model({2, 3, 4}, {3}, {2, 0, 1});
  model.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  model.Invoke();

  const auto expected_values =
      ElementsAreArray({0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,
                        2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 3}));
  EXPECT_THAT(model.GetOutput(), expected_values);
}
242 
TEST(TransposeTest, Test3DInputDynamicTensor) {
  // {2, 3, 4} transposed with runtime perm {2, 0, 1} -> {4, 2, 3}.
  TransposeOpDynamicModel model({2, 3, 4}, {3});
  model.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
  model.SetPerm({2, 0, 1});
  model.Invoke();

  const auto expected_values =
      ElementsAreArray({0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,
                        2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23});
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 3}));
  EXPECT_THAT(model.GetOutput(), expected_values);
}
254 
TEST(TransposeTest, Test5DInputTensor) {
  // Rank 5 exceeds the supported range, so building the model must abort.
  const char* const kRankUnsupported =
      "Transpose op only supports 1D-4D input arrays.";
  EXPECT_DEATH(TransposeOpConstModel({1, 2, 3, 4, 5}, {5}, {0, 1, 2, 3, 4}),
               kRankUnsupported);
}
259 
TEST(TransposeTest, SimpleTestNoReorderConstTensor) {
  // 4D identity permutation (constant): shape and data are unchanged.
  TransposeOpConstModel model({1, 2, 3, 1}, {4}, {0, 1, 2, 3});
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
267 
TEST(TransposeTest, SimpleTestNoReorderDynamicTensor) {
  // 4D identity permutation (runtime): shape and data are unchanged.
  TransposeOpDynamicModel model({1, 2, 3, 1}, {4});
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.SetPerm({0, 1, 2, 3});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
276 
TEST(TransposeTest, SimpleTestWithReorderConstTensor) {
  // {1, 2, 3, 1} transposed with constant perm {2, 1, 3, 0} -> {3, 2, 1, 1}.
  TransposeOpConstModel model({1, 2, 3, 1}, {4}, {2, 1, 3, 0});
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 2, 1, 1}));
  EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
}
284 
TEST(TransposeTest, ComplexTestWithReorderConstTensor) {
  // {2, 3, 4, 5} transposed with constant perm {2, 0, 1, 3} -> {4, 2, 3, 5}.
  TransposeOpConstModel model({2, 3, 4, 5}, {4}, {2, 0, 1, 3});
  model.SetInput({0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,
                  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,
                  24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,
                  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
                  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
                  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
                  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
                  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,
                  96,  97,  98,  99,  100, 101, 102, 103, 104, 105, 106, 107,
                  108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 3, 5}));
  const auto expected_values = ElementsAreArray(
      {0,  1,  2,  3,  4,  20, 21, 22, 23, 24, 40,  41,  42,  43,  44,
       60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
       5,  6,  7,  8,  9,  25, 26, 27, 28, 29, 45,  46,  47,  48,  49,
       65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
       10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50,  51,  52,  53,  54,
       70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
       15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55,  56,  57,  58,  59,
       75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119});
  EXPECT_THAT(model.GetOutput(), expected_values);
}
311 
TEST(TransposeTest, ComplexTestWithReorderDynamicTensor) {
  // {2, 3, 4, 5} transposed with runtime perm {2, 0, 1, 3} -> {4, 2, 3, 5}.
  TransposeOpDynamicModel model({2, 3, 4, 5}, {4});
  model.SetInput({0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,
                  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,
                  24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,
                  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
                  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
                  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
                  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
                  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,
                  96,  97,  98,  99,  100, 101, 102, 103, 104, 105, 106, 107,
                  108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
  model.SetPerm({2, 0, 1, 3});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 2, 3, 5}));
  const auto expected_values = ElementsAreArray(
      {0,  1,  2,  3,  4,  20, 21, 22, 23, 24, 40,  41,  42,  43,  44,
       60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
       5,  6,  7,  8,  9,  25, 26, 27, 28, 29, 45,  46,  47,  48,  49,
       65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
       10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50,  51,  52,  53,  54,
       70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
       15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55,  56,  57,  58,  59,
       75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119});
  EXPECT_THAT(model.GetOutput(), expected_values);
}
339 
340 }  // namespace
341 }  // namespace tflite
342 
main(int argc,char ** argv)343 int main(int argc, char** argv) {
344   ::tflite::LogToStderr();
345   ::testing::InitGoogleTest(&argc, argv);
346   return RUN_ALL_TESTS();
347 }
348