1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/toco/tflite/import.h"

#include <initializer_list>
#include <string>
#include <type_traits>
#include <vector>

#include "flatbuffers/flexbuffers.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
22 
23 namespace toco {
24 
25 namespace tflite {
26 namespace {
27 
28 using ::testing::ElementsAre;
29 
30 using flatbuffers::Offset;
31 using flatbuffers::Vector;
32 class ImportTest : public ::testing::Test {
33  protected:
34   template <typename T>
CreateDataVector(const std::vector<T> & data)35   Offset<Vector<unsigned char>> CreateDataVector(const std::vector<T>& data) {
36     return builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.data()),
37                                  sizeof(T) * data.size());
38   }
39 
BuildBuffers()40   Offset<Vector<Offset<::tflite::Buffer>>> BuildBuffers() {
41     auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector<float>({}));
42     auto buf1 = ::tflite::CreateBuffer(
43         builder_, CreateDataVector<float>({1.0f, 2.0f, 3.0f, 4.0f}));
44     auto buf2 =
45         ::tflite::CreateBuffer(builder_, CreateDataVector<float>({3.0f, 4.0f}));
46     return builder_.CreateVector(
47         std::vector<Offset<::tflite::Buffer>>({buf0, buf1, buf2}));
48   }
49 
BuildTensors()50   Offset<Vector<Offset<::tflite::Tensor>>> BuildTensors() {
51     auto q = ::tflite::CreateQuantizationParameters(
52         builder_,
53         /*min=*/builder_.CreateVector<float>({0.1f}),
54         /*max=*/builder_.CreateVector<float>({0.2f}),
55         /*scale=*/builder_.CreateVector<float>({0.3f}),
56         /*zero_point=*/builder_.CreateVector<int64_t>({100ll}));
57     auto t1 =
58         ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1, 2, 2}),
59                                ::tflite::TensorType_FLOAT32, 1,
60                                builder_.CreateString("tensor_one"), q);
61     auto t2 =
62         ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({2, 1}),
63                                ::tflite::TensorType_FLOAT32, 0,
64                                builder_.CreateString("tensor_two"), q);
65     return builder_.CreateVector(
66         std::vector<Offset<::tflite::Tensor>>({t1, t2}));
67   }
68 
BuildOpCodes(std::initializer_list<::tflite::BuiltinOperator> op_codes)69   Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes(
70       std::initializer_list<::tflite::BuiltinOperator> op_codes) {
71     std::vector<Offset<::tflite::OperatorCode>> op_codes_vector;
72     for (auto op : op_codes) {
73       op_codes_vector.push_back(::tflite::CreateOperatorCode(builder_, op, 0));
74     }
75     return builder_.CreateVector(op_codes_vector);
76   }
77 
BuildOpCodes()78   Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes() {
79     return BuildOpCodes({::tflite::BuiltinOperator_MAX_POOL_2D,
80                          ::tflite::BuiltinOperator_CONV_2D});
81   }
82 
BuildOperators(std::initializer_list<int> inputs,std::initializer_list<int> outputs)83   Offset<Vector<Offset<::tflite::Operator>>> BuildOperators(
84       std::initializer_list<int> inputs, std::initializer_list<int> outputs) {
85     auto is = builder_.CreateVector<int>(inputs);
86     if (inputs.size() == 0) is = 0;
87     auto os = builder_.CreateVector<int>(outputs);
88     if (outputs.size() == 0) os = 0;
89     auto op = ::tflite::CreateOperator(
90         builder_, 0, is, os, ::tflite::BuiltinOptions_Conv2DOptions,
91         ::tflite::CreateConv2DOptions(builder_, ::tflite::Padding_VALID, 1, 1,
92                                       ::tflite::ActivationFunctionType_NONE)
93             .Union(),
94         /*custom_options=*/0, ::tflite::CustomOptionsFormat_FLEXBUFFERS);
95 
96     return builder_.CreateVector(std::vector<Offset<::tflite::Operator>>({op}));
97   }
98 
BuildOperators()99   Offset<Vector<Offset<::tflite::Operator>>> BuildOperators() {
100     return BuildOperators({0}, {1});
101   }
102 
BuildSubGraphs(Offset<Vector<Offset<::tflite::Tensor>>> tensors,Offset<Vector<Offset<::tflite::Operator>>> operators,int num_sub_graphs=1)103   Offset<Vector<Offset<::tflite::SubGraph>>> BuildSubGraphs(
104       Offset<Vector<Offset<::tflite::Tensor>>> tensors,
105       Offset<Vector<Offset<::tflite::Operator>>> operators,
106       int num_sub_graphs = 1) {
107     std::vector<int32_t> inputs = {0};
108     std::vector<int32_t> outputs = {1};
109     std::vector<Offset<::tflite::SubGraph>> v;
110     for (int i = 0; i < num_sub_graphs; ++i) {
111       v.push_back(::tflite::CreateSubGraph(
112           builder_, tensors, builder_.CreateVector(inputs),
113           builder_.CreateVector(outputs), operators,
114           builder_.CreateString("subgraph")));
115     }
116     return builder_.CreateVector(v);
117   }
118 
119   // This is a very simplistic model. We are not interested in testing all the
120   // details here, since tf.mini's testing framework will be exercising all the
121   // conversions multiple times, and the conversion of operators is tested by
122   // separate unittests.
BuildTestModel()123   void BuildTestModel() {
124     auto buffers = BuildBuffers();
125     auto tensors = BuildTensors();
126     auto opcodes = BuildOpCodes();
127     auto operators = BuildOperators();
128     auto subgraphs = BuildSubGraphs(tensors, operators);
129     auto s = builder_.CreateString("");
130 
131     ::tflite::FinishModelBuffer(
132         builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
133                                         opcodes, subgraphs, s, buffers));
134 
135     input_model_ = ::tflite::GetModel(builder_.GetBufferPointer());
136   }
InputModelAsString()137   string InputModelAsString() {
138     return string(reinterpret_cast<char*>(builder_.GetBufferPointer()),
139                   builder_.GetSize());
140   }
141   flatbuffers::FlatBufferBuilder builder_;
142   const ::tflite::Model* input_model_ = nullptr;
143 };
144 
// The tensors table should list the model's tensor names in declaration order.
TEST_F(ImportTest, LoadTensorsTable) {
  BuildTestModel();

  details::TensorsTable tensor_names;
  details::LoadTensorsTable(*input_model_, &tensor_names);

  EXPECT_THAT(tensor_names, ElementsAre("tensor_one", "tensor_two"));
}
152 
// The operators table should hold one name per opcode, in the order that
// BuildOpCodes() emitted them.
TEST_F(ImportTest, LoadOperatorsTable) {
  BuildTestModel();

  details::OperatorsTable op_names;
  details::LoadOperatorsTable(*input_model_, &op_names);

  EXPECT_THAT(op_names, ElementsAre("MAX_POOL_2D", "CONV_2D"));
}
160 
// Importing the test model should produce a "tensor_one" array carrying the
// buffer data, shape, min/max and quantization parameters from the flatbuffer.
TEST_F(ImportTest, Tensors) {
  BuildTestModel();

  auto model = Import(ModelFlags(), InputModelAsString());
  ASSERT_TRUE(model != nullptr);

  // HasArray() is a predicate; assert it directly instead of comparing the
  // boolean against 0 with ASSERT_GT.
  ASSERT_TRUE(model->HasArray("tensor_one"));
  Array& a1 = model->GetArray("tensor_one");
  EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
  // tensor_one references buffer 1, which holds {1, 2, 3, 4}.
  EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
              ElementsAre(1.0f, 2.0f, 3.0f, 4.0f));
  ASSERT_TRUE(a1.has_shape());
  EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 2));

  // min/max/scale were serialized as float, so compare with float tolerance.
  const auto& mm = a1.minmax;
  ASSERT_TRUE(mm != nullptr);
  EXPECT_FLOAT_EQ(0.1, mm->min);
  EXPECT_FLOAT_EQ(0.2, mm->max);

  const auto& q = a1.quantization_params;
  ASSERT_TRUE(q != nullptr);
  EXPECT_FLOAT_EQ(0.3, q->scale);
  EXPECT_EQ(100, q->zero_point);
}
184 
// A model without a 'buffers' section must make the importer abort.
TEST_F(ImportTest, NoBuffers) {
  // A null offset means the 'buffers' field is absent from the model. Spell
  // out the type instead of `auto buffers = 0;` (which deduces int and relies
  // on an implicit int -> Offset conversion at the call site).
  const Offset<Vector<Offset<::tflite::Buffer>>> buffers = 0;
  auto tensors = BuildTensors();
  auto opcodes = BuildOpCodes();
  auto operators = BuildOperators();
  auto subgraphs = BuildSubGraphs(tensors, operators);
  auto comment = builder_.CreateString("");
  ::tflite::FinishModelBuffer(
      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
                                      subgraphs, comment, buffers));
  EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
               "Missing 'buffers' section.");
}
198 
// An operator whose 'inputs' vector is absent is expected to abort the import.
TEST_F(ImportTest, NoInputs) {
  const auto buffer_table = BuildBuffers();
  const auto tensor_table = BuildTensors();
  const auto opcode_table = BuildOpCodes();
  const auto op_table = BuildOperators(/*inputs=*/{}, /*outputs=*/{1});
  const auto subgraph_table = BuildSubGraphs(tensor_table, op_table);
  const auto description = builder_.CreateString("");
  const auto model_offset =
      ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcode_table,
                            subgraph_table, description, buffer_table);
  ::tflite::FinishModelBuffer(builder_, model_offset);

  const auto serialized = InputModelAsString();
  EXPECT_DEATH(Import(ModelFlags(), serialized),
               "Missing 'inputs' for operator.");
}
212 
// An operator whose 'outputs' vector is absent is expected to abort the import.
TEST_F(ImportTest, NoOutputs) {
  const auto buffer_table = BuildBuffers();
  const auto tensor_table = BuildTensors();
  const auto opcode_table = BuildOpCodes();
  const auto op_table = BuildOperators(/*inputs=*/{0}, /*outputs=*/{});
  const auto subgraph_table = BuildSubGraphs(tensor_table, op_table);
  const auto description = builder_.CreateString("");
  const auto model_offset =
      ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcode_table,
                            subgraph_table, description, buffer_table);
  ::tflite::FinishModelBuffer(builder_, model_offset);

  const auto serialized = InputModelAsString();
  EXPECT_DEATH(Import(ModelFlags(), serialized),
               "Missing 'outputs' for operator.");
}
226 
// The first opcode entry is the out-of-range builtin value -1; the importer is
// expected to abort when it encounters it.
TEST_F(ImportTest, InvalidOpCode) {
  const auto buffer_table = BuildBuffers();
  const auto tensor_table = BuildTensors();
  const auto bogus_op = static_cast<::tflite::BuiltinOperator>(-1);
  const auto opcode_table =
      BuildOpCodes({bogus_op, ::tflite::BuiltinOperator_CONV_2D});
  const auto op_table = BuildOperators();
  const auto subgraph_table = BuildSubGraphs(tensor_table, op_table);
  const auto description = builder_.CreateString("");
  const auto model_offset =
      ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcode_table,
                            subgraph_table, description, buffer_table);
  ::tflite::FinishModelBuffer(builder_, model_offset);

  const auto serialized = InputModelAsString();
  EXPECT_DEATH(Import(ModelFlags(), serialized),
               "Operator id '-1' is out of range.");
}
241 
// The importer only supports single-subgraph models; two subgraphs must abort.
TEST_F(ImportTest, MultipleSubGraphs) {
  auto buffers = BuildBuffers();
  auto tensors = BuildTensors();
  auto opcodes = BuildOpCodes();
  auto operators = BuildOperators();
  auto subgraphs = BuildSubGraphs(tensors, operators, /*num_sub_graphs=*/2);
  auto comment = builder_.CreateString("");
  ::tflite::FinishModelBuffer(
      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
                                      subgraphs, comment, buffers));

  // Import() takes the serialized bytes directly, so there is no need to set
  // input_model_ here (the previous assignment was a dead store).
  EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
               "Number of subgraphs in tflite should be exactly 1.");
}
258 
259 // TODO(ahentz): still need tests for Operators and IOTensors.
260 
261 }  // namespace
262 }  // namespace tflite
263 
264 }  // namespace toco
265