1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tflite/import.h"
16
17 #include "flatbuffers/flexbuffers.h"
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20 #include "tensorflow/lite/schema/schema_generated.h"
21 #include "tensorflow/lite/version.h"
22
23 namespace toco {
24
25 namespace tflite {
26 namespace {
27
28 using ::testing::ElementsAre;
29
30 using flatbuffers::Offset;
31 using flatbuffers::Vector;
32 class ImportTest : public ::testing::Test {
33 protected:
34 template <typename T>
CreateDataVector(const std::vector<T> & data)35 Offset<Vector<unsigned char>> CreateDataVector(const std::vector<T>& data) {
36 return builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.data()),
37 sizeof(T) * data.size());
38 }
39
BuildBuffers()40 Offset<Vector<Offset<::tflite::Buffer>>> BuildBuffers() {
41 auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector<float>({}));
42 auto buf1 = ::tflite::CreateBuffer(
43 builder_, CreateDataVector<float>({1.0f, 2.0f, 3.0f, 4.0f}));
44 auto buf2 =
45 ::tflite::CreateBuffer(builder_, CreateDataVector<float>({3.0f, 4.0f}));
46 return builder_.CreateVector(
47 std::vector<Offset<::tflite::Buffer>>({buf0, buf1, buf2}));
48 }
49
BuildTensors()50 Offset<Vector<Offset<::tflite::Tensor>>> BuildTensors() {
51 auto q = ::tflite::CreateQuantizationParameters(
52 builder_,
53 /*min=*/builder_.CreateVector<float>({0.1f}),
54 /*max=*/builder_.CreateVector<float>({0.2f}),
55 /*scale=*/builder_.CreateVector<float>({0.3f}),
56 /*zero_point=*/builder_.CreateVector<int64_t>({100ll}));
57 auto t1 =
58 ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({1, 2, 2}),
59 ::tflite::TensorType_FLOAT32, 1,
60 builder_.CreateString("tensor_one"), q);
61 auto t2 =
62 ::tflite::CreateTensor(builder_, builder_.CreateVector<int>({2, 1}),
63 ::tflite::TensorType_FLOAT32, 0,
64 builder_.CreateString("tensor_two"), q);
65 return builder_.CreateVector(
66 std::vector<Offset<::tflite::Tensor>>({t1, t2}));
67 }
68
BuildOpCodes(std::initializer_list<::tflite::BuiltinOperator> op_codes)69 Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes(
70 std::initializer_list<::tflite::BuiltinOperator> op_codes) {
71 std::vector<Offset<::tflite::OperatorCode>> op_codes_vector;
72 for (auto op : op_codes) {
73 op_codes_vector.push_back(::tflite::CreateOperatorCode(builder_, op, 0));
74 }
75 return builder_.CreateVector(op_codes_vector);
76 }
77
BuildOpCodes()78 Offset<Vector<Offset<::tflite::OperatorCode>>> BuildOpCodes() {
79 return BuildOpCodes({::tflite::BuiltinOperator_MAX_POOL_2D,
80 ::tflite::BuiltinOperator_CONV_2D});
81 }
82
BuildOperators(std::initializer_list<int> inputs,std::initializer_list<int> outputs)83 Offset<Vector<Offset<::tflite::Operator>>> BuildOperators(
84 std::initializer_list<int> inputs, std::initializer_list<int> outputs) {
85 auto is = builder_.CreateVector<int>(inputs);
86 if (inputs.size() == 0) is = 0;
87 auto os = builder_.CreateVector<int>(outputs);
88 if (outputs.size() == 0) os = 0;
89 auto op = ::tflite::CreateOperator(
90 builder_, 0, is, os, ::tflite::BuiltinOptions_Conv2DOptions,
91 ::tflite::CreateConv2DOptions(builder_, ::tflite::Padding_VALID, 1, 1,
92 ::tflite::ActivationFunctionType_NONE)
93 .Union(),
94 /*custom_options=*/0, ::tflite::CustomOptionsFormat_FLEXBUFFERS);
95
96 return builder_.CreateVector(std::vector<Offset<::tflite::Operator>>({op}));
97 }
98
BuildOperators()99 Offset<Vector<Offset<::tflite::Operator>>> BuildOperators() {
100 return BuildOperators({0}, {1});
101 }
102
BuildSubGraphs(Offset<Vector<Offset<::tflite::Tensor>>> tensors,Offset<Vector<Offset<::tflite::Operator>>> operators,int num_sub_graphs=1)103 Offset<Vector<Offset<::tflite::SubGraph>>> BuildSubGraphs(
104 Offset<Vector<Offset<::tflite::Tensor>>> tensors,
105 Offset<Vector<Offset<::tflite::Operator>>> operators,
106 int num_sub_graphs = 1) {
107 std::vector<int32_t> inputs = {0};
108 std::vector<int32_t> outputs = {1};
109 std::vector<Offset<::tflite::SubGraph>> v;
110 for (int i = 0; i < num_sub_graphs; ++i) {
111 v.push_back(::tflite::CreateSubGraph(
112 builder_, tensors, builder_.CreateVector(inputs),
113 builder_.CreateVector(outputs), operators,
114 builder_.CreateString("subgraph")));
115 }
116 return builder_.CreateVector(v);
117 }
118
119 // This is a very simplistic model. We are not interested in testing all the
120 // details here, since tf.mini's testing framework will be exercising all the
121 // conversions multiple times, and the conversion of operators is tested by
122 // separate unittests.
BuildTestModel()123 void BuildTestModel() {
124 auto buffers = BuildBuffers();
125 auto tensors = BuildTensors();
126 auto opcodes = BuildOpCodes();
127 auto operators = BuildOperators();
128 auto subgraphs = BuildSubGraphs(tensors, operators);
129 auto s = builder_.CreateString("");
130
131 ::tflite::FinishModelBuffer(
132 builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
133 opcodes, subgraphs, s, buffers));
134
135 input_model_ = ::tflite::GetModel(builder_.GetBufferPointer());
136 }
InputModelAsString()137 string InputModelAsString() {
138 return string(reinterpret_cast<char*>(builder_.GetBufferPointer()),
139 builder_.GetSize());
140 }
141 flatbuffers::FlatBufferBuilder builder_;
142 const ::tflite::Model* input_model_ = nullptr;
143 };
144
// details::LoadTensorsTable() should collect the tensor names in the order
// the fixture model declares them.
TEST_F(ImportTest, LoadTensorsTable) {
  BuildTestModel();

  details::TensorsTable names;
  details::LoadTensorsTable(*input_model_, &names);

  const auto expected_names = ElementsAre("tensor_one", "tensor_two");
  EXPECT_THAT(names, expected_names);
}
152
// details::LoadOperatorsTable() should map the fixture's opcode table to the
// builtin operator names, keeping their order.
TEST_F(ImportTest, LoadOperatorsTable) {
  BuildTestModel();

  details::OperatorsTable op_names;
  details::LoadOperatorsTable(*input_model_, &op_names);

  const auto expected_op_names = ElementsAre("MAX_POOL_2D", "CONV_2D");
  EXPECT_THAT(op_names, expected_op_names);
}
160
// Imports the fixture model and checks that "tensor_one" keeps its data type,
// constant buffer contents, shape, min/max and quantization parameters.
TEST_F(ImportTest, Tensors) {
  BuildTestModel();

  auto model = Import(ModelFlags(), InputModelAsString());

  // Assert presence directly rather than via an ordering comparison with 0.
  ASSERT_TRUE(model->HasArray("tensor_one"));
  Array& a1 = model->GetArray("tensor_one");
  EXPECT_EQ(ArrayDataType::kFloat, a1.data_type);
  EXPECT_THAT(a1.GetBuffer<ArrayDataType::kFloat>().data,
              ElementsAre(1.0f, 2.0f, 3.0f, 4.0f));
  ASSERT_TRUE(a1.has_shape());
  EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 2));

  // The fixture serializes min/max/scale as float, so compare at float
  // precision with float literals.
  const auto& mm = a1.minmax;
  ASSERT_TRUE(mm.get());
  EXPECT_FLOAT_EQ(0.1f, mm->min);
  EXPECT_FLOAT_EQ(0.2f, mm->max);

  const auto& q = a1.quantization_params;
  ASSERT_TRUE(q.get());
  EXPECT_FLOAT_EQ(0.3f, q->scale);
  EXPECT_EQ(100, q->zero_point);
}
184
// A model whose 'buffers' section is absent must make Import() abort.
TEST_F(ImportTest, NoBuffers) {
  // A default-constructed (null) offset leaves the model's 'buffers' field
  // unset. Spell out the type instead of relying on `auto x = 0` being
  // implicitly converted to an offset.
  const Offset<Vector<Offset<::tflite::Buffer>>> buffers{};
  auto tensors = BuildTensors();
  auto opcodes = BuildOpCodes();
  auto operators = BuildOperators();
  auto subgraphs = BuildSubGraphs(tensors, operators);
  auto comment = builder_.CreateString("");
  ::tflite::FinishModelBuffer(
      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
                                      subgraphs, comment, buffers));
  EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
               "Missing 'buffers' section.");
}
198
// An operator without an 'inputs' field must make Import() abort.
TEST_F(ImportTest, NoInputs) {
  const auto buffer_list = BuildBuffers();
  const auto tensor_list = BuildTensors();
  const auto opcode_list = BuildOpCodes();
  // An empty input list makes BuildOperators() leave 'inputs' unset.
  const auto operator_list = BuildOperators({}, {1});
  const auto subgraph_list = BuildSubGraphs(tensor_list, operator_list);
  const auto description = builder_.CreateString("");
  const auto root = ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
                                          opcode_list, subgraph_list,
                                          description, buffer_list);
  ::tflite::FinishModelBuffer(builder_, root);

  const string serialized = InputModelAsString();
  EXPECT_DEATH(Import(ModelFlags(), serialized),
               "Missing 'inputs' for operator.");
}
212
// An operator without an 'outputs' field must make Import() abort.
TEST_F(ImportTest, NoOutputs) {
  const auto buffer_list = BuildBuffers();
  const auto tensor_list = BuildTensors();
  const auto opcode_list = BuildOpCodes();
  // An empty output list makes BuildOperators() leave 'outputs' unset.
  const auto operator_list = BuildOperators({0}, {});
  const auto subgraph_list = BuildSubGraphs(tensor_list, operator_list);
  const auto description = builder_.CreateString("");
  const auto root = ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
                                          opcode_list, subgraph_list,
                                          description, buffer_list);
  ::tflite::FinishModelBuffer(builder_, root);

  const string serialized = InputModelAsString();
  EXPECT_DEATH(Import(ModelFlags(), serialized),
               "Missing 'outputs' for operator.");
}
226
// An opcode table entry holding a builtin id outside the known range (-1)
// must make Import() abort.
TEST_F(ImportTest, InvalidOpCode) {
  const auto bogus_op = static_cast<::tflite::BuiltinOperator>(-1);

  const auto buffer_list = BuildBuffers();
  const auto tensor_list = BuildTensors();
  const auto opcode_list =
      BuildOpCodes({bogus_op, ::tflite::BuiltinOperator_CONV_2D});
  const auto operator_list = BuildOperators();
  const auto subgraph_list = BuildSubGraphs(tensor_list, operator_list);
  const auto description = builder_.CreateString("");
  const auto root = ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
                                          opcode_list, subgraph_list,
                                          description, buffer_list);
  ::tflite::FinishModelBuffer(builder_, root);

  const string serialized = InputModelAsString();
  EXPECT_DEATH(Import(ModelFlags(), serialized),
               "Operator id '-1' is out of range.");
}
241
// A model with more than one subgraph must make Import() abort.
TEST_F(ImportTest, MultipleSubGraphs) {
  auto buffers = BuildBuffers();
  auto tensors = BuildTensors();
  auto opcodes = BuildOpCodes();
  auto operators = BuildOperators();
  auto subgraphs = BuildSubGraphs(tensors, operators, 2);
  auto comment = builder_.CreateString("");
  ::tflite::FinishModelBuffer(
      builder_, ::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes,
                                      subgraphs, comment, buffers));

  // Import() works from the serialized bytes alone, so `input_model_` is not
  // needed here (the previous assignment to it was dead code).
  EXPECT_DEATH(Import(ModelFlags(), InputModelAsString()),
               "Number of subgraphs in tflite should be exactly 1.");
}
258
259 // TODO(ahentz): still need tests for Operators and IOTensors.
260
261 } // namespace
262 } // namespace tflite
263
264 } // namespace toco
265