1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <fcntl.h>
16 #include <stdint.h>
17 #include <stdio.h>
18 #include <stdlib.h>
19 #include <sys/mman.h>
20 #include <sys/stat.h>
21 #include <sys/types.h>
22 
23 #include "tensorflow/lite/model.h"
24 
25 #include <gtest/gtest.h>
26 #include "tensorflow/lite/core/api/error_reporter.h"
27 #include "tensorflow/lite/kernels/register.h"
28 #include "tensorflow/lite/testing/util.h"
29 
30 // Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
31 // we must declare this in global namespace, so argument-dependent operator
32 // lookup works.
operator ==(const TfLiteRegistration & a,const TfLiteRegistration & b)33 inline bool operator==(const TfLiteRegistration& a,
34                        const TfLiteRegistration& b) {
35   return a.invoke == b.invoke && a.init == b.init && a.prepare == b.prepare &&
36          a.free == b.free;
37 }
38 
39 namespace tflite {
40 
41 // Provide a dummy operation that does nothing.
42 namespace {
dummy_init(TfLiteContext *,const char *,size_t)43 void* dummy_init(TfLiteContext*, const char*, size_t) { return nullptr; }
dummy_free(TfLiteContext *,void *)44 void dummy_free(TfLiteContext*, void*) {}
dummy_resize(TfLiteContext *,TfLiteNode *)45 TfLiteStatus dummy_resize(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
dummy_invoke(TfLiteContext *,TfLiteNode *)46 TfLiteStatus dummy_invoke(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; }
47 TfLiteRegistration dummy_reg = {dummy_init, dummy_free, dummy_resize,
48                                 dummy_invoke};
49 }  // namespace
50 
51 // Provide a trivial resolver that returns a constant value no matter what
52 // op is asked for.
53 class TrivialResolver : public OpResolver {
54  public:
TrivialResolver(TfLiteRegistration * constant_return=nullptr)55   explicit TrivialResolver(TfLiteRegistration* constant_return = nullptr)
56       : constant_return_(constant_return) {}
57   // Find the op registration of a custom operator by op name.
FindOp(tflite::BuiltinOperator op,int version) const58   const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
59                                    int version) const override {
60     return constant_return_;
61   }
62   // Find the op registration of a custom operator by op name.
FindOp(const char * op,int version) const63   const TfLiteRegistration* FindOp(const char* op, int version) const override {
64     return constant_return_;
65   }
66 
67  private:
68   TfLiteRegistration* constant_return_;
69 };
70 
TEST(BasicFlatBufferModel, TestNonExistantFiles) {
  // Building from a path that presumably does not exist must yield no model.
  const char* const missing_path = "/tmp/tflite_model_1234";
  ASSERT_FALSE(FlatBufferModel::BuildFromFile(missing_path));
}
74 
75 // Make sure a model with nothing in it loads properly.
TEST(BasicFlatBufferModel,TestEmptyModelsAndNullDestination)76 TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) {
77   auto model = FlatBufferModel::BuildFromFile(
78       "tensorflow/lite/testdata/empty_model.bin");
79   ASSERT_TRUE(model);
80   // Now try to build it into a model.
81   std::unique_ptr<Interpreter> interpreter;
82   ASSERT_EQ(InterpreterBuilder(*model, TrivialResolver())(&interpreter),
83             kTfLiteOk);
84   ASSERT_NE(interpreter, nullptr);
85   ASSERT_NE(InterpreterBuilder(*model, TrivialResolver())(nullptr), kTfLiteOk);
86 }
87 
88 // Make sure currently unsupported # of subgraphs are checked
89 // TODO(aselle): Replace this test when multiple subgraphs are supported.
TEST(BasicFlatBufferModel,TestZeroSubgraphs)90 TEST(BasicFlatBufferModel, TestZeroSubgraphs) {
91   auto m = FlatBufferModel::BuildFromFile(
92       "tensorflow/lite/testdata/0_subgraphs.bin");
93   ASSERT_TRUE(m);
94   std::unique_ptr<Interpreter> interpreter;
95   ASSERT_NE(InterpreterBuilder(*m, TrivialResolver())(&interpreter), kTfLiteOk);
96 }
97 
// A model with two subgraphs builds successfully and exposes both of them.
TEST(BasicFlatBufferModel, TestMultipleSubgraphs) {
  const char* const kPath = "tensorflow/lite/testdata/2_subgraphs.bin";
  auto two_subgraph_model = FlatBufferModel::BuildFromFile(kPath);
  ASSERT_TRUE(two_subgraph_model);

  std::unique_ptr<Interpreter> interp;
  const TfLiteStatus status =
      InterpreterBuilder(*two_subgraph_model, TrivialResolver())(&interp);
  ASSERT_EQ(status, kTfLiteOk);
  EXPECT_EQ(interp->subgraphs_size(), 2);
}
106 
107 // Test what happens if we cannot bind any of the ops.
TEST(BasicFlatBufferModel,TestModelWithoutNullRegistrations)108 TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) {
109   auto model = FlatBufferModel::BuildFromFile(
110       "tensorflow/lite/testdata/test_model.bin");
111   ASSERT_TRUE(model);
112   // Check that we get an error code and interpreter pointer is reset.
113   std::unique_ptr<Interpreter> interpreter(new Interpreter);
114   ASSERT_NE(InterpreterBuilder(*model, TrivialResolver(nullptr))(&interpreter),
115             kTfLiteOk);
116   ASSERT_EQ(interpreter, nullptr);
117 }
118 
119 // Make sure model is read to interpreter properly
TEST(BasicFlatBufferModel,TestModelInInterpreter)120 TEST(BasicFlatBufferModel, TestModelInInterpreter) {
121   auto model = FlatBufferModel::BuildFromFile(
122       "tensorflow/lite/testdata/test_model.bin");
123   ASSERT_TRUE(model);
124   // Check that we get an error code and interpreter pointer is reset.
125   std::unique_ptr<Interpreter> interpreter(new Interpreter);
126   ASSERT_EQ(
127       InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter),
128       kTfLiteOk);
129   ASSERT_NE(interpreter, nullptr);
130   ASSERT_EQ(interpreter->tensors_size(), 4);
131   ASSERT_EQ(interpreter->nodes_size(), 2);
132   std::vector<int> inputs = {0, 1};
133   std::vector<int> outputs = {2, 3};
134   ASSERT_EQ(interpreter->inputs(), inputs);
135   ASSERT_EQ(interpreter->outputs(), outputs);
136 
137   EXPECT_EQ(std::string(interpreter->GetInputName(0)), "input0");
138   EXPECT_EQ(std::string(interpreter->GetInputName(1)), "input1");
139   EXPECT_EQ(std::string(interpreter->GetOutputName(0)), "out1");
140   EXPECT_EQ(std::string(interpreter->GetOutputName(1)), "out2");
141 
142   // Make sure all input tensors are correct
143   TfLiteTensor* i0 = interpreter->tensor(0);
144   ASSERT_EQ(i0->type, kTfLiteFloat32);
145   ASSERT_NE(i0->data.raw, nullptr);  // mmapped
146   ASSERT_EQ(i0->allocation_type, kTfLiteMmapRo);
147   TfLiteTensor* i1 = interpreter->tensor(1);
148   ASSERT_EQ(i1->type, kTfLiteFloat32);
149   ASSERT_EQ(i1->data.raw, nullptr);
150   ASSERT_EQ(i1->allocation_type, kTfLiteArenaRw);
151   TfLiteTensor* o0 = interpreter->tensor(2);
152   ASSERT_EQ(o0->type, kTfLiteFloat32);
153   ASSERT_EQ(o0->data.raw, nullptr);
154   ASSERT_EQ(o0->allocation_type, kTfLiteArenaRw);
155   TfLiteTensor* o1 = interpreter->tensor(3);
156   ASSERT_EQ(o1->type, kTfLiteFloat32);
157   ASSERT_EQ(o1->data.raw, nullptr);
158   ASSERT_EQ(o1->allocation_type, kTfLiteArenaRw);
159 
160   // Check op 0 which has inputs {0, 1} outputs {2}.
161   {
162     const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg0 =
163         interpreter->node_and_registration(0);
164     ASSERT_NE(node_and_reg0, nullptr);
165     const TfLiteNode& node0 = node_and_reg0->first;
166     const TfLiteRegistration& reg0 = node_and_reg0->second;
167     TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(2);
168     desired_inputs->data[0] = 0;
169     desired_inputs->data[1] = 1;
170     TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1);
171     desired_outputs->data[0] = 2;
172     ASSERT_TRUE(TfLiteIntArrayEqual(node0.inputs, desired_inputs));
173     ASSERT_TRUE(TfLiteIntArrayEqual(node0.outputs, desired_outputs));
174     TfLiteIntArrayFree(desired_inputs);
175     TfLiteIntArrayFree(desired_outputs);
176     ASSERT_EQ(reg0, dummy_reg);
177   }
178 
179   // Check op 1 which has inputs {2} outputs {3}.
180   {
181     const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg1 =
182         interpreter->node_and_registration(1);
183     ASSERT_NE(node_and_reg1, nullptr);
184     const TfLiteNode& node1 = node_and_reg1->first;
185     const TfLiteRegistration& reg1 = node_and_reg1->second;
186     TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(1);
187     TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1);
188     desired_inputs->data[0] = 2;
189     desired_outputs->data[0] = 3;
190     ASSERT_TRUE(TfLiteIntArrayEqual(node1.inputs, desired_inputs));
191     ASSERT_TRUE(TfLiteIntArrayEqual(node1.outputs, desired_outputs));
192     TfLiteIntArrayFree(desired_inputs);
193     TfLiteIntArrayFree(desired_outputs);
194     ASSERT_EQ(reg1, dummy_reg);
195   }
196 }
197 
198 // Test that loading a model with TensorFlow ops fails when the flex delegate is
199 // not linked into the target.
TEST(FlexModel,FailureWithoutFlexDelegate)200 TEST(FlexModel, FailureWithoutFlexDelegate) {
201   auto model = FlatBufferModel::BuildFromFile(
202       "tensorflow/lite/testdata/multi_add_flex.bin");
203   ASSERT_TRUE(model);
204 
205   // Note that creation will succeed when using the BuiltinOpResolver, but
206   // unless the appropriate delegate is linked into the target or the client
207   // explicitly installs the delegate, execution will fail.
208   std::unique_ptr<Interpreter> interpreter;
209   ASSERT_EQ(InterpreterBuilder(*model,
210                                ops::builtin::BuiltinOpResolver{})(&interpreter),
211             kTfLiteOk);
212   ASSERT_TRUE(interpreter);
213 
214   // As the flex ops weren't resolved implicitly by the flex delegate, runtime
215   // allocation and execution will fail.
216   ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteError);
217 }
218 
219 // This tests on a flatbuffer that defines a shape of 2 to be a memory mapped
220 // buffer. But the buffer is provided to be only 1 element.
TEST(BasicFlatBufferModel,TestBrokenMmap)221 TEST(BasicFlatBufferModel, TestBrokenMmap) {
222   ASSERT_FALSE(FlatBufferModel::BuildFromFile(
223       "tensorflow/lite/testdata/test_model_broken.bin"));
224 }
225 
TEST(BasicFlatBufferModel, TestNullModel) {
  // Building from a null model must report an error and reset the destination,
  // even though it started out holding an interpreter.
  std::unique_ptr<Interpreter> interpreter(new Interpreter);
  TrivialResolver resolver(&dummy_reg);
  InterpreterBuilder builder(nullptr, resolver);
  ASSERT_NE(builder(&interpreter), kTfLiteOk);
  ASSERT_EQ(interpreter.get(), nullptr);
}
234 
235 // Mocks the verifier by setting the result in ctor.
236 class FakeVerifier : public tflite::TfLiteVerifier {
237  public:
FakeVerifier(bool result)238   explicit FakeVerifier(bool result) : result_(result) {}
Verify(const char * data,int length,tflite::ErrorReporter * reporter)239   bool Verify(const char* data, int length,
240               tflite::ErrorReporter* reporter) override {
241     return result_;
242   }
243 
244  private:
245   bool result_;
246 };
247 
// When the extra verifier accepts the buffer, the model is built.
TEST(BasicFlatBufferModel, TestWithTrueVerifier) {
  FakeVerifier accepting_verifier(true);
  auto model = FlatBufferModel::VerifyAndBuildFromFile(
      "tensorflow/lite/testdata/test_model.bin", &accepting_verifier);
  ASSERT_TRUE(model);
}
254 
// When the extra verifier rejects the buffer, no model is built.
TEST(BasicFlatBufferModel, TestWithFalseVerifier) {
  FakeVerifier rejecting_verifier(false);
  auto model = FlatBufferModel::VerifyAndBuildFromFile(
      "tensorflow/lite/testdata/test_model.bin", &rejecting_verifier);
  ASSERT_FALSE(model);
}
261 
// Passing no extra verifier is accepted and the model is built.
TEST(BasicFlatBufferModel, TestWithNullVerifier) {
  auto model = FlatBufferModel::VerifyAndBuildFromFile(
      "tensorflow/lite/testdata/test_model.bin", nullptr);
  ASSERT_TRUE(model);
}
266 
267 // This makes sure the ErrorReporter is marshalled from FlatBufferModel to
268 // the Interpreter.
TEST(BasicFlatBufferModel,TestCustomErrorReporter)269 TEST(BasicFlatBufferModel, TestCustomErrorReporter) {
270   TestErrorReporter reporter;
271   auto model = FlatBufferModel::BuildFromFile(
272       "tensorflow/lite/testdata/empty_model.bin",
273       &reporter);
274   ASSERT_TRUE(model);
275 
276   std::unique_ptr<Interpreter> interpreter;
277   TrivialResolver resolver;
278   InterpreterBuilder(*model, resolver)(&interpreter);
279   ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
280   ASSERT_EQ(reporter.num_calls(), 1);
281 }
282 
283 // This makes sure the ErrorReporter is marshalled from FlatBufferModel to
284 // the Interpreter.
TEST(BasicFlatBufferModel,TestNullErrorReporter)285 TEST(BasicFlatBufferModel, TestNullErrorReporter) {
286   auto model = FlatBufferModel::BuildFromFile(
287       "tensorflow/lite/testdata/empty_model.bin", nullptr);
288   ASSERT_TRUE(model);
289 
290   std::unique_ptr<Interpreter> interpreter;
291   TrivialResolver resolver;
292   InterpreterBuilder(*model, resolver)(&interpreter);
293   ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
294 }
295 
296 // Test that loading model directly from a Model flatbuffer works.
TEST(BasicFlatBufferModel,TestBuildFromModel)297 TEST(BasicFlatBufferModel, TestBuildFromModel) {
298   TestErrorReporter reporter;
299   FileCopyAllocation model_allocation(
300       "tensorflow/lite/testdata/test_model.bin", &reporter);
301   ASSERT_TRUE(model_allocation.valid());
302   ::flatbuffers::Verifier verifier(
303       reinterpret_cast<const uint8_t*>(model_allocation.base()),
304       model_allocation.bytes());
305   ASSERT_TRUE(VerifyModelBuffer(verifier));
306   const Model* model_fb = ::tflite::GetModel(model_allocation.base());
307 
308   auto model = FlatBufferModel::BuildFromModel(model_fb);
309   ASSERT_TRUE(model);
310 
311   std::unique_ptr<Interpreter> interpreter;
312   ASSERT_EQ(
313       InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter),
314       kTfLiteOk);
315   ASSERT_NE(interpreter, nullptr);
316 }
317 
318 // TODO(aselle): Add tests for serialization of builtin op data types.
319 // These tests will occur with the evaluation tests of individual operators,
320 // not here.
321 
322 }  // namespace tflite
323 
main(int argc,char ** argv)324 int main(int argc, char** argv) {
325   ::tflite::LogToStderr();
326   ::testing::InitGoogleTest(&argc, argv);
327   return RUN_ALL_TESTS();
328 }
329