/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <fcntl.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>

#include "tensorflow/lite/model.h"

#include <gtest/gtest.h>
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/testing/util.h"

// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
// we must declare this in the global namespace so that argument-dependent
// lookup of operator== works.
inline bool operator==(const TfLiteRegistration& a,
                       const TfLiteRegistration& b) {
  return a.invoke == b.invoke && a.init == b.init && a.prepare == b.prepare &&
         a.free == b.free;
}

namespace tflite {

// Provide a dummy operation that does nothing.
namespace {
void* dummy_init(TfLiteContext*, const char*, size_t) { return nullptr; }
void dummy_free(TfLiteContext*, void*) {}
TfLiteStatus dummy_resize(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
TfLiteStatus dummy_invoke(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
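// The first four fields of TfLiteRegistration are init, free, prepare and
// invoke, so the aggregate initializer below installs dummy_resize as the
// prepare callback.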
TfLiteRegistration dummy_reg = {dummy_init, dummy_free, dummy_resize,
                                dummy_invoke};
}  // namespace

// Provide a trivial resolver that returns a constant value no matter what
// op is asked for.
class TrivialResolver : public OpResolver {
 public:
  explicit TrivialResolver(TfLiteRegistration* constant_return = nullptr)
      : constant_return_(constant_return) {}
  // Find the op registration of a builtin operator by its enum code.
  const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
                                   int version) const override {
    return constant_return_;
  }
  // Find the op registration of a custom operator by op name.
  const TfLiteRegistration* FindOp(const char* op, int version) const override {
    return constant_return_;
  }

 private:
  TfLiteRegistration* constant_return_;
};

TEST(BasicFlatBufferModel, TestNonExistentFiles) {
  ASSERT_TRUE(!FlatBufferModel::BuildFromFile("/tmp/tflite_model_1234"));
}

// Make sure a model with nothing in it loads properly.
TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) {
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/lite/testdata/empty_model.bin");
  ASSERT_TRUE(model);
  // Now try to build it into an interpreter.
  std::unique_ptr<Interpreter> interpreter;
  ASSERT_EQ(InterpreterBuilder(*model, TrivialResolver())(&interpreter),
            kTfLiteOk);
  ASSERT_NE(interpreter, nullptr);
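  // Passing a null destination pointer for the interpreter must fail.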
  ASSERT_NE(InterpreterBuilder(*model, TrivialResolver())(nullptr), kTfLiteOk);
}

// Make sure a currently unsupported number of subgraphs (zero) is rejected.
// TODO(aselle): Replace this test when multiple subgraphs are supported.
TEST(BasicFlatBufferModel, TestZeroSubgraphs) {
  auto m = FlatBufferModel::BuildFromFile(
      "tensorflow/lite/testdata/0_subgraphs.bin");
  ASSERT_TRUE(m);
  std::unique_ptr<Interpreter> interpreter;
  ASSERT_NE(InterpreterBuilder(*m, TrivialResolver())(&interpreter), kTfLiteOk);
}

TEST(BasicFlatBufferModel, TestMultipleSubgraphs) {
  auto m = FlatBufferModel::BuildFromFile(
      "tensorflow/lite/testdata/2_subgraphs.bin");
  ASSERT_TRUE(m);
  std::unique_ptr<Interpreter> interpreter;
  ASSERT_EQ(InterpreterBuilder(*m, TrivialResolver())(&interpreter), kTfLiteOk);
  EXPECT_EQ(interpreter->subgraphs_size(), 2);
}

// Test what happens if we cannot bind any of the ops.
TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) {
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/lite/testdata/test_model.bin");
  ASSERT_TRUE(model);
  // Check that we get an error code and that the interpreter pointer is reset.
  std::unique_ptr<Interpreter> interpreter(new Interpreter);
  ASSERT_NE(InterpreterBuilder(*model, TrivialResolver(nullptr))(&interpreter),
            kTfLiteOk);
  ASSERT_EQ(interpreter, nullptr);
}

// Make sure the model is read into the interpreter properly.
TEST(BasicFlatBufferModel, TestModelInInterpreter) {
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/lite/testdata/test_model.bin");
  ASSERT_TRUE(model);
  // Build the model into an interpreter and verify its contents.
  std::unique_ptr<Interpreter> interpreter(new Interpreter);
  ASSERT_EQ(
      InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter),
      kTfLiteOk);
  ASSERT_NE(interpreter, nullptr);
  ASSERT_EQ(interpreter->tensors_size(), 4);
  ASSERT_EQ(interpreter->nodes_size(), 2);
  std::vector<int> inputs = {0, 1};
  std::vector<int> outputs = {2, 3};
  ASSERT_EQ(interpreter->inputs(), inputs);
  ASSERT_EQ(interpreter->outputs(), outputs);

  EXPECT_EQ(std::string(interpreter->GetInputName(0)), "input0");
  EXPECT_EQ(std::string(interpreter->GetInputName(1)), "input1");
  EXPECT_EQ(std::string(interpreter->GetOutputName(0)), "out1");
  EXPECT_EQ(std::string(interpreter->GetOutputName(1)), "out2");

  // Make sure all input tensors are correct.
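  // kTfLiteMmapRo tensors are read-only views into the model buffer, while
  // kTfLiteArenaRw tensors only get backing memory from the interpreter's
  // arena once AllocateTensors() is called, so their data is still null here.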
  TfLiteTensor* i0 = interpreter->tensor(0);
  ASSERT_EQ(i0->type, kTfLiteFloat32);
  ASSERT_NE(i0->data.raw, nullptr);  // mmapped
  ASSERT_EQ(i0->allocation_type, kTfLiteMmapRo);
  TfLiteTensor* i1 = interpreter->tensor(1);
  ASSERT_EQ(i1->type, kTfLiteFloat32);
  ASSERT_EQ(i1->data.raw, nullptr);
  ASSERT_EQ(i1->allocation_type, kTfLiteArenaRw);
  TfLiteTensor* o0 = interpreter->tensor(2);
  ASSERT_EQ(o0->type, kTfLiteFloat32);
  ASSERT_EQ(o0->data.raw, nullptr);
  ASSERT_EQ(o0->allocation_type, kTfLiteArenaRw);
  TfLiteTensor* o1 = interpreter->tensor(3);
  ASSERT_EQ(o1->type, kTfLiteFloat32);
  ASSERT_EQ(o1->data.raw, nullptr);
  ASSERT_EQ(o1->allocation_type, kTfLiteArenaRw);

  // Check op 0, which has inputs {0, 1} and outputs {2}.
  {
    const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg0 =
        interpreter->node_and_registration(0);
    ASSERT_NE(node_and_reg0, nullptr);
    const TfLiteNode& node0 = node_and_reg0->first;
    const TfLiteRegistration& reg0 = node_and_reg0->second;
    TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(2);
    desired_inputs->data[0] = 0;
    desired_inputs->data[1] = 1;
    TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1);
    desired_outputs->data[0] = 2;
    ASSERT_TRUE(TfLiteIntArrayEqual(node0.inputs, desired_inputs));
    ASSERT_TRUE(TfLiteIntArrayEqual(node0.outputs, desired_outputs));
    TfLiteIntArrayFree(desired_inputs);
    TfLiteIntArrayFree(desired_outputs);
    ASSERT_EQ(reg0, dummy_reg);
  }

  // Check op 1, which has inputs {2} and outputs {3}.
  {
    const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg1 =
        interpreter->node_and_registration(1);
    ASSERT_NE(node_and_reg1, nullptr);
    const TfLiteNode& node1 = node_and_reg1->first;
    const TfLiteRegistration& reg1 = node_and_reg1->second;
    TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(1);
    TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1);
    desired_inputs->data[0] = 2;
    desired_outputs->data[0] = 3;
    ASSERT_TRUE(TfLiteIntArrayEqual(node1.inputs, desired_inputs));
    ASSERT_TRUE(TfLiteIntArrayEqual(node1.outputs, desired_outputs));
    TfLiteIntArrayFree(desired_inputs);
    TfLiteIntArrayFree(desired_outputs);
    ASSERT_EQ(reg1, dummy_reg);
  }
}

// Test that loading a model with TensorFlow ops fails when the flex delegate
// is not linked into the target.
TEST(FlexModel, FailureWithoutFlexDelegate) {
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/lite/testdata/multi_add_flex.bin");
  ASSERT_TRUE(model);

  // Note that creation will succeed when using the BuiltinOpResolver, but
  // unless the appropriate delegate is linked into the target or the client
  // explicitly installs the delegate, execution will fail.
  std::unique_ptr<Interpreter> interpreter;
  ASSERT_EQ(InterpreterBuilder(*model,
                               ops::builtin::BuiltinOpResolver{})(&interpreter),
            kTfLiteOk);
  ASSERT_TRUE(interpreter);

  // As the flex ops weren't resolved implicitly by the flex delegate, runtime
  // allocation and execution will fail.
  ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteError);
}

// This tests a flatbuffer that declares a memory-mapped buffer for a tensor
// of shape {2}, but supplies a buffer holding only one element, so loading
// must fail.
TEST(BasicFlatBufferModel, TestBrokenMmap) {
  ASSERT_FALSE(FlatBufferModel::BuildFromFile(
      "tensorflow/lite/testdata/test_model_broken.bin"));
}

TEST(BasicFlatBufferModel, TestNullModel) {
  // Check that we get an error code and that the interpreter pointer is reset.
  std::unique_ptr<Interpreter> interpreter(new Interpreter);
  ASSERT_NE(
      InterpreterBuilder(nullptr, TrivialResolver(&dummy_reg))(&interpreter),
      kTfLiteOk);
  ASSERT_EQ(interpreter.get(), nullptr);
}

// A fake verifier whose result is fixed at construction time.
class FakeVerifier : public tflite::TfLiteVerifier {
 public:
  explicit FakeVerifier(bool result) : result_(result) {}
  bool Verify(const char* data, int length,
              tflite::ErrorReporter* reporter) override {
    return result_;
  }

 private:
  bool result_;
};

TEST(BasicFlatBufferModel, TestWithTrueVerifier) {
  FakeVerifier verifier(true);
  ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
      "tensorflow/lite/testdata/test_model.bin", &verifier));
}

TEST(BasicFlatBufferModel, TestWithFalseVerifier) {
  FakeVerifier verifier(false);
  ASSERT_FALSE(FlatBufferModel::VerifyAndBuildFromFile(
      "tensorflow/lite/testdata/test_model.bin", &verifier));
}

TEST(BasicFlatBufferModel, TestWithNullVerifier) {
  ASSERT_TRUE(FlatBufferModel::VerifyAndBuildFromFile(
      "tensorflow/lite/testdata/test_model.bin", nullptr));
}

// This makes sure the ErrorReporter is marshalled from FlatBufferModel to
// the Interpreter.
TEST(BasicFlatBufferModel, TestCustomErrorReporter) {
  TestErrorReporter reporter;
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/lite/testdata/empty_model.bin", &reporter);
  ASSERT_TRUE(model);

  std::unique_ptr<Interpreter> interpreter;
  TrivialResolver resolver;
  InterpreterBuilder(*model, resolver)(&interpreter);
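  // Invoke() without a prior successful AllocateTensors() fails, and the
  // resulting error should be reported exactly once through our reporter.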
  ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
  ASSERT_EQ(reporter.num_calls(), 1);
}

// This makes sure that passing a null ErrorReporter falls back to the default
// reporter instead of crashing.
TEST(BasicFlatBufferModel, TestNullErrorReporter) {
  auto model = FlatBufferModel::BuildFromFile(
      "tensorflow/lite/testdata/empty_model.bin", nullptr);
  ASSERT_TRUE(model);

  std::unique_ptr<Interpreter> interpreter;
  TrivialResolver resolver;
  InterpreterBuilder(*model, resolver)(&interpreter);
  ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
}

// Test that loading a model directly from a Model flatbuffer works.
TEST(BasicFlatBufferModel, TestBuildFromModel) {
  TestErrorReporter reporter;
  FileCopyAllocation model_allocation(
      "tensorflow/lite/testdata/test_model.bin", &reporter);
  ASSERT_TRUE(model_allocation.valid());
  ::flatbuffers::Verifier verifier(
      reinterpret_cast<const uint8_t*>(model_allocation.base()),
      model_allocation.bytes());
  ASSERT_TRUE(VerifyModelBuffer(verifier));
  const Model* model_fb = ::tflite::GetModel(model_allocation.base());

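  // Note: BuildFromModel does not copy the buffer; the caller keeps ownership,
  // so model_allocation must outlive the FlatBufferModel built from it.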
  auto model = FlatBufferModel::BuildFromModel(model_fb);
  ASSERT_TRUE(model);

  std::unique_ptr<Interpreter> interpreter;
  ASSERT_EQ(
      InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter),
      kTfLiteOk);
  ASSERT_NE(interpreter, nullptr);
}

// TODO(aselle): Add tests for serialization of builtin op data types.
// These tests will occur with the evaluation tests of individual operators,
// not here.

}  // namespace tflite

int main(int argc, char** argv) {
  ::tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}