1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <gtest/gtest.h>
17 #include "tensorflow/lite/interpreter.h"
18 #include "tensorflow/lite/kernels/register.h"
19 #include "tensorflow/lite/kernels/test_util.h"
20 #include "tensorflow/lite/model.h"
21
22 namespace tflite {
23 namespace {
24
25 using ::testing::ElementsAreArray;
26
27 class BatchToSpaceNDOpModel : public SingleOpModel {
28 public:
29 template <typename T>
SetInput(std::initializer_list<T> data)30 void SetInput(std::initializer_list<T> data) {
31 PopulateTensor<T>(input_, data);
32 }
33
SetBlockShape(std::initializer_list<int> data)34 void SetBlockShape(std::initializer_list<int> data) {
35 PopulateTensor<int>(block_shape_, data);
36 }
37
SetCrops(std::initializer_list<int> data)38 void SetCrops(std::initializer_list<int> data) {
39 PopulateTensor<int>(crops_, data);
40 }
41
42 template <typename T>
GetOutput()43 std::vector<T> GetOutput() {
44 return ExtractVector<T>(output_);
45 }
GetOutputShape()46 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
47
48 protected:
49 int input_;
50 int block_shape_;
51 int crops_;
52 int output_;
53 };
54
55 // Tests case where block_shape and crops are const tensors.
56 //
57 // Example usage is as follows:
58 // BatchToSpaceNDOpConstModel m(input_shape, block_shape, crops);
59 // m.SetInput(input_data);
60 // m.Invoke();
61 class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel {
62 public:
BatchToSpaceNDOpConstModel(std::initializer_list<int> input_shape,std::initializer_list<int> block_shape,std::initializer_list<int> crops,const TensorType & type=TensorType_FLOAT32)63 BatchToSpaceNDOpConstModel(std::initializer_list<int> input_shape,
64 std::initializer_list<int> block_shape,
65 std::initializer_list<int> crops,
66 const TensorType& type = TensorType_FLOAT32) {
67 input_ = AddInput(type);
68 block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2});
69 crops_ = AddConstInput(TensorType_INT32, crops, {2, 2});
70 output_ = AddOutput(type);
71
72 SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
73 BuiltinOptions_BatchToSpaceNDOptions,
74 CreateBatchToSpaceNDOptions(builder_).Union());
75 BuildInterpreter({input_shape});
76 }
77 };
78
79 // Tests case where block_shape and crops are non-const tensors.
80 //
81 // Example usage is as follows:
82 // BatchToSpaceNDOpDynamicModel m(input_shape);
83 // m.SetInput(input_data);
84 // m.SetBlockShape(block_shape);
85 // m.SetPaddings(crops);
86 // m.Invoke();
87 class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel {
88 public:
BatchToSpaceNDOpDynamicModel(std::initializer_list<int> input_shape,const TensorType & type=TensorType_FLOAT32)89 BatchToSpaceNDOpDynamicModel(std::initializer_list<int> input_shape,
90 const TensorType& type = TensorType_FLOAT32) {
91 input_ = AddInput(type);
92 block_shape_ = AddInput(TensorType_INT32);
93 crops_ = AddInput(TensorType_INT32);
94 output_ = AddOutput(type);
95
96 SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
97 BuiltinOptions_BatchToSpaceNDOptions,
98 CreateBatchToSpaceNDOptions(builder_).Union());
99 BuildInterpreter({input_shape, {2}, {2, 2}});
100 }
101 };
102
TEST(BatchToSpaceNDOpTest, SimpleConstTest) {
  // Four 2x2x1 batches, 2x2 block, no cropping: the result is one 4x4x1 image
  // whose pixels interleave the four batches.
  BatchToSpaceNDOpConstModel model({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0});
  model.SetInput<float>(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(
                  {1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}));
}
112
TEST(BatchToSpaceNDOpTest, SimpleConstTestInt8) {
  // Same scenario as SimpleConstTest, but with int8 input and output tensors.
  BatchToSpaceNDOpConstModel model({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0},
                                   TensorType_INT8);
  model.SetInput<int8_t>(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput<int8_t>(),
              ElementsAreArray(
                  {1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}));
}
123
TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) {
  // Like SimpleConstTest, except block_shape and crops are fed at run time.
  BatchToSpaceNDOpDynamicModel model({4, 2, 2, 1});
  model.SetInput<float>(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.SetBlockShape({2, 2});
  model.SetCrops({0, 0, 0, 0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(
                  {1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}));
}
135
TEST(BatchToSpaceNDOpTest, SimpleDynamicTestInt8) {
  // Run-time block_shape/crops with int8 data.
  BatchToSpaceNDOpDynamicModel model({4, 2, 2, 1}, TensorType_INT8);
  model.SetInput<int8_t>(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.SetBlockShape({2, 2});
  model.SetCrops({0, 0, 0, 0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput<int8_t>(),
              ElementsAreArray(
                  {1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}));
}
147
148 #ifdef GTEST_HAS_DEATH_TEST
TEST(BatchToSpaceNDOpTest, InvalidShapeTest) {
  // A batch of 3 cannot be redistributed over a 2x2 block, so building the
  // interpreter is expected to abort while allocating tensors.
  EXPECT_DEATH(
      {
        BatchToSpaceNDOpConstModel model({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 0});
      },
      "Cannot allocate tensors");
}
153
TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) {
  // Use a batch of 4 (as the valid tests do) so that the negative crop is the
  // only invalid parameter. InvalidShapeTest shows that a batch of 3 with a
  // 2x2 block is rejected on its own, which would otherwise muddy this test.
  EXPECT_DEATH(BatchToSpaceNDOpConstModel({4, 2, 2, 1}, {2, 2}, {0, 0, 0, -1}),
               "crops.3. >= 0 was not true.");
}
158
TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) {
  // With run-time crops, the negative value is only seen during Invoke(), so
  // that call is the one expected to abort.
  BatchToSpaceNDOpDynamicModel model({4, 2, 2, 1});
  model.SetInput<float>(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.SetBlockShape({2, 2});
  model.SetCrops({0, 0, -1, 0});

  EXPECT_DEATH(model.Invoke(), "crops.2. >= 0 was not true.");
}
166 #endif
167
168 } // namespace
169 } // namespace tflite
170
main(int argc,char ** argv)171 int main(int argc, char** argv) {
172 ::tflite::LogToStderr();
173 ::testing::InitGoogleTest(&argc, argv);
174 return RUN_ALL_TESTS();
175 }
176