• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <gtest/gtest.h>
17 #include "tensorflow/lite/interpreter.h"
18 #include "tensorflow/lite/kernels/register.h"
19 #include "tensorflow/lite/kernels/test_util.h"
20 #include "tensorflow/lite/model.h"
21 
22 namespace tflite {
23 namespace {
24 
25 using ::testing::ElementsAreArray;
26 
27 class BatchToSpaceNDOpModel : public SingleOpModel {
28  public:
29   template <typename T>
SetInput(std::initializer_list<T> data)30   void SetInput(std::initializer_list<T> data) {
31     PopulateTensor<T>(input_, data);
32   }
33 
SetBlockShape(std::initializer_list<int> data)34   void SetBlockShape(std::initializer_list<int> data) {
35     PopulateTensor<int>(block_shape_, data);
36   }
37 
SetCrops(std::initializer_list<int> data)38   void SetCrops(std::initializer_list<int> data) {
39     PopulateTensor<int>(crops_, data);
40   }
41 
42   template <typename T>
GetOutput()43   std::vector<T> GetOutput() {
44     return ExtractVector<T>(output_);
45   }
GetOutputShape()46   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
47 
48  protected:
49   int input_;
50   int block_shape_;
51   int crops_;
52   int output_;
53 };
54 
55 // Tests case where block_shape and crops are const tensors.
56 //
57 // Example usage is as follows:
58 //    BatchToSpaceNDOpConstModel m(input_shape, block_shape, crops);
59 //    m.SetInput(input_data);
60 //    m.Invoke();
61 class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel {
62  public:
BatchToSpaceNDOpConstModel(std::initializer_list<int> input_shape,std::initializer_list<int> block_shape,std::initializer_list<int> crops,const TensorType & type=TensorType_FLOAT32)63   BatchToSpaceNDOpConstModel(std::initializer_list<int> input_shape,
64                              std::initializer_list<int> block_shape,
65                              std::initializer_list<int> crops,
66                              const TensorType& type = TensorType_FLOAT32) {
67     input_ = AddInput(type);
68     block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2});
69     crops_ = AddConstInput(TensorType_INT32, crops, {2, 2});
70     output_ = AddOutput(type);
71 
72     SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
73                  BuiltinOptions_BatchToSpaceNDOptions,
74                  CreateBatchToSpaceNDOptions(builder_).Union());
75     BuildInterpreter({input_shape});
76   }
77 };
78 
79 // Tests case where block_shape and crops are non-const tensors.
80 //
81 // Example usage is as follows:
82 //    BatchToSpaceNDOpDynamicModel m(input_shape);
83 //    m.SetInput(input_data);
84 //    m.SetBlockShape(block_shape);
85 //    m.SetPaddings(crops);
86 //    m.Invoke();
87 class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel {
88  public:
BatchToSpaceNDOpDynamicModel(std::initializer_list<int> input_shape,const TensorType & type=TensorType_FLOAT32)89   BatchToSpaceNDOpDynamicModel(std::initializer_list<int> input_shape,
90                                const TensorType& type = TensorType_FLOAT32) {
91     input_ = AddInput(type);
92     block_shape_ = AddInput(TensorType_INT32);
93     crops_ = AddInput(TensorType_INT32);
94     output_ = AddOutput(type);
95 
96     SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
97                  BuiltinOptions_BatchToSpaceNDOptions,
98                  CreateBatchToSpaceNDOptions(builder_).Union());
99     BuildInterpreter({input_shape, {2}, {2, 2}});
100   }
101 };
102 
// Four 2x2 batch entries are rearranged into a single 4x4 image, using const
// block_shape/crops and no cropping.
TEST(BatchToSpaceNDOpTest, SimpleConstTest) {
  BatchToSpaceNDOpConstModel model(/*input_shape=*/{4, 2, 2, 1},
                                   /*block_shape=*/{2, 2},
                                   /*crops=*/{0, 0, 0, 0});
  model.SetInput<float>(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14,
                                3, 7, 4, 8, 11, 15, 12, 16}));
}
112 
// Same rearrangement as SimpleConstTest, but with int8 input and output.
TEST(BatchToSpaceNDOpTest, SimpleConstTestInt8) {
  BatchToSpaceNDOpConstModel model(/*input_shape=*/{4, 2, 2, 1},
                                   /*block_shape=*/{2, 2},
                                   /*crops=*/{0, 0, 0, 0},
                                   /*type=*/TensorType_INT8);
  model.SetInput<int8_t>(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput<int8_t>(),
              ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14,
                                3, 7, 4, 8, 11, 15, 12, 16}));
}
123 
// Same rearrangement as SimpleConstTest, but block_shape and crops are fed as
// runtime (non-const) inputs.
TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) {
  BatchToSpaceNDOpDynamicModel model(/*input_shape=*/{4, 2, 2, 1});
  model.SetInput<float>(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.SetBlockShape({2, 2});
  model.SetCrops({0, 0, 0, 0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14,
                                3, 7, 4, 8, 11, 15, 12, 16}));
}
135 
// Runtime block_shape/crops variant of SimpleConstTestInt8.
TEST(BatchToSpaceNDOpTest, SimpleDynamicTestInt8) {
  BatchToSpaceNDOpDynamicModel model(/*input_shape=*/{4, 2, 2, 1},
                                     /*type=*/TensorType_INT8);
  model.SetInput<int8_t>(
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  model.SetBlockShape({2, 2});
  model.SetCrops({0, 0, 0, 0});
  model.Invoke();

  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(model.GetOutput<int8_t>(),
              ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14,
                                3, 7, 4, 8, 11, 15, 12, 16}));
}
147 
#ifdef GTEST_HAS_DEATH_TEST
// A batch of 3 with a 2x2 block_shape is expected to make model construction
// die with "Cannot allocate tensors" (presumably because 3 is not a multiple
// of block_shape[0] * block_shape[1] = 4).
TEST(BatchToSpaceNDOpTest, InvalidShapeTest) {
  EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}),
               "Cannot allocate tensors");
}

// A negative const crop must be rejected. The input shape is otherwise valid
// (batch 4 with a 2x2 block_shape, as in SimpleConstTest), so the negative
// crop is the only possible cause of failure. This does not depend on the
// order in which the kernel validates its inputs. In the regex, '.' matches
// the brackets around the index.
TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) {
  EXPECT_DEATH(BatchToSpaceNDOpConstModel({4, 2, 2, 1}, {2, 2}, {0, 0, 0, -1}),
               "crops.3. >= 0 was not true.");
}

// A negative crop supplied at runtime must make Invoke() die.
TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) {
  BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1});
  m.SetInput<float>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  m.SetBlockShape({2, 2});
  m.SetCrops({0, 0, -1, 0});
  EXPECT_DEATH(m.Invoke(), "crops.2. >= 0 was not true.");
}
#endif
167 
168 }  // namespace
169 }  // namespace tflite
170 
main(int argc,char ** argv)171 int main(int argc, char** argv) {
172   ::tflite::LogToStderr();
173   ::testing::InitGoogleTest(&argc, argv);
174   return RUN_ALL_TESTS();
175 }
176