1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/experimental/micro/micro_interpreter.h"

#include <cstddef>

#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/experimental/micro/compatibility.h"
19
20 namespace tflite {
21 namespace {
22 const int kStackDataAllocatorSize = 128;
23 class StackDataAllocator : public BuiltinDataAllocator {
24 public:
Allocate(size_t size)25 void* Allocate(size_t size) override {
26 if (size > kStackDataAllocatorSize) {
27 return nullptr;
28 } else {
29 return data_;
30 }
31 }
Deallocate(void * data)32 void Deallocate(void* data) override {
33 // Do nothing.
34 }
35
36 private:
37 uint8_t data_[kStackDataAllocatorSize];
38
39 TF_LITE_REMOVE_VIRTUAL_DELETE
40 };
41
OpNameFromRegistration(const TfLiteRegistration * registration)42 const char* OpNameFromRegistration(const TfLiteRegistration* registration) {
43 if (registration->builtin_code == BuiltinOperator_CUSTOM) {
44 return registration->custom_name;
45 } else {
46 return EnumNameBuiltinOperator(BuiltinOperator(registration->builtin_code));
47 }
48 }
49
ReportOpError(struct TfLiteContext * context,const char * format,...)50 void ReportOpError(struct TfLiteContext* context, const char* format, ...) {
51 MicroInterpreter* interpreter =
52 static_cast<MicroInterpreter*>(context->impl_);
53 va_list args;
54 va_start(args, format);
55 interpreter->error_reporter()->Report(format, args);
56 va_end(args);
57 }
58
59 } // namespace
60
MicroInterpreter(const Model * model,const OpResolver & op_resolver,SimpleTensorAllocator * tensor_allocator,ErrorReporter * error_reporter)61 MicroInterpreter::MicroInterpreter(const Model* model,
62 const OpResolver& op_resolver,
63 SimpleTensorAllocator* tensor_allocator,
64 ErrorReporter* error_reporter)
65 : model_(model),
66 op_resolver_(op_resolver),
67 tensor_allocator_(tensor_allocator),
68 error_reporter_(error_reporter),
69 initialization_status_(kTfLiteOk) {
70 const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers =
71 model->buffers();
72 auto* subgraphs = model->subgraphs();
73 if (subgraphs->size() != 1) {
74 error_reporter->Report("Only 1 subgraph is currently supported.\n");
75 initialization_status_ = kTfLiteError;
76 return;
77 }
78 subgraph_ = (*subgraphs)[0];
79 tensors_ = subgraph_->tensors();
80 operators_ = subgraph_->operators();
81
82 context_.tensors_size = tensors_->Length();
83 context_.tensors =
84 reinterpret_cast<TfLiteTensor*>(tensor_allocator_->AllocateMemory(
85 sizeof(TfLiteTensor) * context_.tensors_size, 4));
86 for (int i = 0; i < subgraph_->inputs()->Length(); ++i) {
87 const int tensor_index = subgraph_->inputs()->Get(i);
88 const auto* tensor = tensors_->Get(tensor_index);
89 initialization_status_ = tensor_allocator_->AllocateTensor(
90 *tensor, 0, operators_->Length(), buffers, error_reporter,
91 &context_.tensors[tensor_index]);
92 if (initialization_status_ != kTfLiteOk) {
93 return;
94 }
95 }
96
97 int* first_created = reinterpret_cast<int*>(tensor_allocator_->AllocateMemory(
98 sizeof(int) * tensors_->Length(), sizeof(int)));
99 int* last_used = reinterpret_cast<int*>(tensor_allocator_->AllocateMemory(
100 sizeof(int) * tensors_->Length(), sizeof(int)));
101 for (int i = 0; i < tensors_->Length(); ++i) {
102 first_created[i] = -1;
103 last_used[i] = -1;
104 }
105
106 for (int i = (operators_->Length() - 1); i >= 0; --i) {
107 const auto* op = operators_->Get(i);
108 for (int n = 0; n < op->inputs()->Length(); ++n) {
109 const int tensor_index = op->inputs()->Get(n);
110 if ((last_used[tensor_index] == -1) || (last_used[tensor_index] < i)) {
111 last_used[tensor_index] = i;
112 }
113 }
114 for (int n = 0; n < op->outputs()->Length(); ++n) {
115 const int tensor_index = op->outputs()->Get(n);
116 const int create_before = i;
117 int destroy_after = last_used[tensor_index];
118 if (destroy_after == -1) {
119 destroy_after = operators_->Length();
120 }
121 const auto* tensor = tensors_->Get(tensor_index);
122 if (!tensor->is_variable()) {
123 initialization_status_ = tensor_allocator_->AllocateTensor(
124 *tensor, create_before, destroy_after, buffers, error_reporter,
125 &context_.tensors[tensor_index]);
126 if (initialization_status_ != kTfLiteOk) {
127 return;
128 }
129 first_created[tensor_index] = i;
130 }
131 }
132 }
133
134 for (int i = 0; i < tensors_->Length(); ++i) {
135 const auto* tensor = tensors_->Get(i);
136 const bool is_read_only = (first_created[i] == -1) && (last_used[i] != -1);
137 if (tensor->is_variable() || is_read_only) {
138 initialization_status_ = tensor_allocator_->AllocateTensor(
139 *tensor, 0, operators_->Length(), buffers, error_reporter,
140 &context_.tensors[i]);
141 if (initialization_status_ != kTfLiteOk) {
142 return;
143 }
144 }
145 }
146 context_.impl_ = static_cast<void*>(this);
147 context_.GetExecutionPlan = nullptr;
148 context_.ResizeTensor = nullptr;
149 context_.ReportError = ReportOpError;
150 context_.AddTensors = nullptr;
151 context_.GetNodeAndRegistration = nullptr;
152 context_.ReplaceNodeSubsetsWithDelegateKernels = nullptr;
153 context_.recommended_num_threads = 1;
154 context_.GetExternalContext = nullptr;
155 context_.SetExternalContext = nullptr;
156 }
157
Invoke()158 TfLiteStatus MicroInterpreter::Invoke() {
159 if (initialization_status_ != kTfLiteOk) {
160 error_reporter_->Report("Invoke() called after initialization failed\n");
161 return kTfLiteError;
162 }
163 TfLiteStatus status = kTfLiteOk;
164 auto opcodes = model_->operator_codes();
165 for (int i = 0; i < operators_->Length(); ++i) {
166 const auto* op = operators_->Get(i);
167 int index = op->opcode_index();
168 if (index < 0 || index >= opcodes->size()) {
169 error_reporter_->Report("Missing registration for opcode_index %d\n",
170 index);
171 return kTfLiteError;
172 }
173 auto opcode = (*opcodes)[index];
174 const TfLiteRegistration* registration = nullptr;
175 status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
176 ®istration);
177 if (status != kTfLiteOk) {
178 return status;
179 }
180 if (registration == nullptr) {
181 error_reporter_->Report("Skipping op for opcode_index %d\n", index);
182 return kTfLiteError;
183 }
184 BuiltinOperator op_type =
185 static_cast<BuiltinOperator>(registration->builtin_code);
186
187 if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
188 error_reporter_->Report(
189 "Found builtin operator %s with custom options.\n",
190 EnumNameBuiltinOperator(op_type));
191 }
192 StackDataAllocator stack_data_allocator;
193 const char* custom_data = nullptr;
194 size_t custom_data_size = 0;
195 unsigned char* builtin_data = nullptr;
196 if (op->custom_options()) {
197 custom_data = reinterpret_cast<const char*>(op->custom_options()->data());
198 custom_data_size = op->custom_options()->size();
199 } else {
200 TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
201 &stack_data_allocator,
202 (void**)(&builtin_data)));
203 }
204
205 const char* init_data;
206 size_t init_data_size;
207 if (registration->builtin_code == BuiltinOperator_CUSTOM) {
208 init_data = custom_data;
209 init_data_size = custom_data_size;
210 } else {
211 init_data = reinterpret_cast<const char*>(builtin_data);
212 init_data_size = 0;
213 }
214 void* user_data = nullptr;
215 if (registration->init) {
216 user_data = registration->init(&context_, init_data, init_data_size);
217 }
218
219 const int kMaxInputs = 16;
220 int inputs_data[kMaxInputs + 1];
221 TfLiteIntArray* inputs_array =
222 reinterpret_cast<TfLiteIntArray*>(inputs_data);
223 if (op->inputs()->Length() >= kMaxInputs) {
224 error_reporter_->Report("Too many inputs (%d)\n", op->inputs()->Length());
225 return kTfLiteError;
226 }
227 inputs_array->size = op->inputs()->Length();
228 for (int n = 0; n < op->inputs()->Length(); ++n) {
229 inputs_array->data[n] = op->inputs()->Get(n);
230 }
231
232 const int kMaxOutputs = 16;
233 int outputs_data[kMaxOutputs + 1];
234 TfLiteIntArray* outputs_array =
235 reinterpret_cast<TfLiteIntArray*>(outputs_data);
236 if (op->outputs()->Length() >= kMaxOutputs) {
237 error_reporter_->Report("Too many outputs (%d)\n",
238 op->outputs()->Length());
239 return kTfLiteError;
240 }
241 outputs_array->size = op->outputs()->Length();
242 for (int n = 0; n < op->outputs()->Length(); ++n) {
243 outputs_array->data[n] = op->outputs()->Get(n);
244 }
245
246 const int kMaxTemporaries = 16;
247 int temporaries_data[kMaxTemporaries + 1];
248 TfLiteIntArray* temporaries_array =
249 reinterpret_cast<TfLiteIntArray*>(temporaries_data);
250 temporaries_array->size = 0;
251
252 TfLiteNode node;
253 node.inputs = inputs_array;
254 node.outputs = outputs_array;
255 node.temporaries = temporaries_array;
256 node.user_data = user_data;
257 node.builtin_data = reinterpret_cast<void*>(builtin_data);
258 node.custom_initial_data = custom_data;
259 node.custom_initial_data_size = custom_data_size;
260 node.delegate = nullptr;
261 if (registration->prepare) {
262 TfLiteStatus prepare_status = registration->prepare(&context_, &node);
263 if (prepare_status != kTfLiteOk) {
264 error_reporter_->Report(
265 "Node %s (number %d) failed to prepare with status %d",
266 OpNameFromRegistration(registration), i, prepare_status);
267 return kTfLiteError;
268 }
269 }
270
271 if (registration->invoke) {
272 TfLiteStatus invoke_status = registration->invoke(&context_, &node);
273 if (invoke_status != kTfLiteOk) {
274 error_reporter_->Report(
275 "Node %s (number %d) failed to invoke with status %d",
276 OpNameFromRegistration(registration), i, invoke_status);
277 return kTfLiteError;
278 }
279 }
280
281 if (registration->free) {
282 registration->free(&context_, user_data);
283 }
284 }
285 return status;
286 }
287
input(int index)288 TfLiteTensor* MicroInterpreter::input(int index) {
289 const flatbuffers::Vector<int32_t>* inputs = subgraph_->inputs();
290 const size_t length = inputs->Length();
291 if ((index < 0) || (index >= length)) {
292 error_reporter_->Report("Input index %d out of range (length is %d)", index,
293 length);
294 return nullptr;
295 }
296 return &(context_.tensors[inputs->Get(index)]);
297 }
298
output(int index)299 TfLiteTensor* MicroInterpreter::output(int index) {
300 const flatbuffers::Vector<int32_t>* outputs = subgraph_->outputs();
301 const size_t length = outputs->Length();
302 if ((index < 0) || (index >= outputs->Length())) {
303 error_reporter_->Report("Output index %d out of range (length is %d)",
304 index, length);
305 return nullptr;
306 }
307 return &(context_.tensors[outputs->Get(index)]);
308 }
309
310 } // namespace tflite
311