/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/interpreter_builder.h"

#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/macros.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/profiling/platform_profiler.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/shared_library.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/util.h"
#include "tensorflow/lite/version.h"

// aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11.
#if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L
#if !defined(__ANDROID__) || __ANDROID_API__ >= 28
// Neither Apple nor Windows provides aligned_alloc.
#if !defined(__APPLE__) && !defined(_WIN32)
#define TFLITE_USE_STD_ALIGNED_ALLOC
#endif
#endif
#endif

// TODO(b/139446230): Move to portable platform header.
#if defined(__ANDROID__)
#define TFLITE_IS_MOBILE_PLATFORM
#endif  // defined(__ANDROID__)

#if defined(__APPLE__)
#include "TargetConditionals.h"
#if TARGET_IPHONE_SIMULATOR
#define TFLITE_IS_MOBILE_PLATFORM
#elif TARGET_OS_IPHONE
#define TFLITE_IS_MOBILE_PLATFORM
#endif
#endif  // defined(__APPLE__)

namespace tflite {

namespace {

// Ensure that ErrorReporter is non-null.
ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
  return e ? e : DefaultErrorReporter();
}

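// Copies the values of a flatbuffer sparse-index vector (Int32Vector,
// Uint16Vector, or Uint8Vector) into a freshly allocated TfLiteIntArray.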
template <typename T>
TfLiteStatus Copy(const T* data_ptr, TfLiteIntArray** arr) {
  if (data_ptr->values() == nullptr) {
    return kTfLiteError;
  }

  int size = data_ptr->values()->size();
  *arr = TfLiteIntArrayCreate(size);
  for (int i = 0; i < size; i++) {
    (*arr)->data[i] = static_cast<int>(data_ptr->values()->Get(i));
  }
  return kTfLiteOk;
}

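// Converts the array_segments and array_indices of one sparse dimension from
// the flatbuffer representation into the TfLiteDimensionMetadata arrays.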
TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src,
                                    TfLiteDimensionMetadata* tgt) {
  if (src->array_segments() == nullptr || src->array_indices() == nullptr) {
    return kTfLiteError;
  }
  TfLiteStatus status = kTfLiteOk;
  switch (src->array_segments_type()) {
    case SparseIndexVector_Int32Vector:
      status = Copy(src->array_segments_as_Int32Vector(), &tgt->array_segments);
      break;
    case SparseIndexVector_Uint16Vector:
      status =
          Copy(src->array_segments_as_Uint16Vector(), &tgt->array_segments);
      break;
    case SparseIndexVector_Uint8Vector:
      status = Copy(src->array_segments_as_Uint8Vector(), &tgt->array_segments);
      break;
    default:
      status = kTfLiteError;
      break;
  }
  if (status != kTfLiteOk) return status;

  switch (src->array_indices_type()) {
    case SparseIndexVector_Int32Vector:
      return Copy(src->array_indices_as_Int32Vector(), &tgt->array_indices);
    case SparseIndexVector_Uint16Vector:
      return Copy(src->array_indices_as_Uint16Vector(), &tgt->array_indices);
    case SparseIndexVector_Uint8Vector:
      return Copy(src->array_indices_as_Uint8Vector(), &tgt->array_indices);
    default:
      break;
  }
  return kTfLiteError;
}

// Helper that returns std::map that corresponds to vector of TensorMap.
std::map<std::string, uint32_t> GetMapFromTensorMap(
    const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMap>>*
        tensor_map) {
  if (!tensor_map) return {};
  std::map<std::string, uint32_t> result;
  for (const auto tensor : *tensor_map) {
    if (tensor != nullptr && tensor->name() != nullptr) {
      result[tensor->name()->c_str()] = tensor->tensor_index();
    }
  }
  return result;
}

}  // namespace

const char* kEmptyTensorName = "";

// Using weak symbols to create a delegate allows automatic injection of the
// delegate simply by adding it as a dependency.
// For flex delegate, see also the strong override in
// lite/delegates/flex/delegate.cc.
TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
  auto acquire_flex_delegate_func =
      reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
          SharedLibrary::GetSymbol("TF_AcquireFlexDelegate"));
  if (acquire_flex_delegate_func) {
    return acquire_flex_delegate_func();
  }

#if !defined(TFLITE_IS_MOBILE_PLATFORM)
  // Load TF_AcquireFlexDelegate() from _pywrap_tensorflow_internal.so if it is
  // available.
  const char* filename_pywrap_tensorflow_internal =
#if defined(_WIN32)
      "_pywrap_tensorflow_internal.pyd";
#elif defined(__APPLE__)
      "python/_pywrap_tensorflow_internal.so";
#else
      "_pywrap_tensorflow_internal.so";
#endif
  void* lib_tf_internal =
      SharedLibrary::LoadLibrary(filename_pywrap_tensorflow_internal);
#if defined(_WIN32)
  if (lib_tf_internal == nullptr) {
    lib_tf_internal = SharedLibrary::LoadLibrary(
        "_pywrap_tensorflow_interpreter_wrapper.pyd");
  }
#endif
  if (lib_tf_internal) {
    acquire_flex_delegate_func =
        reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
            SharedLibrary::GetLibrarySymbol(lib_tf_internal,
                                            "TF_AcquireFlexDelegate"));
    if (acquire_flex_delegate_func) {
      return acquire_flex_delegate_func();
    }
  }
#endif  // !defined(TFLITE_IS_MOBILE_PLATFORM)

  return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
}

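// A typical usage sketch for this builder (assuming the standard
// FlatBufferModel and ops::builtin::BuiltinOpResolver helpers from the public
// TFLite headers):
//
//   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) !=
//       kTfLiteOk) {
//     // Handle the failure.
//   }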
InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
                                       const OpResolver& op_resolver)
    : model_(model.GetModel()),
      op_resolver_(op_resolver),
      error_reporter_(ValidateErrorReporter(model.error_reporter())),
      allocation_(model.allocation()) {}

InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
                                       const OpResolver& op_resolver,
                                       ErrorReporter* error_reporter)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(ValidateErrorReporter(error_reporter)) {}

InterpreterBuilder::~InterpreterBuilder() {}

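// Maps each OperatorCode in the flatbuffer to a TfLiteRegistration from the
// OpResolver. Unresolved custom ops are recorded rather than rejected, since
// a delegate (e.g. the flex delegate) may resolve them later.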
TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
  TfLiteStatus status = kTfLiteOk;
  // Reset state.
  flatbuffer_op_index_to_registration_.clear();
  unresolved_custom_ops_.clear();

  auto opcodes = model_->operator_codes();
  if (!opcodes) {
    return status;
  }
  int num_custom_ops = 0;
  for (const OperatorCode* opcode : *opcodes) {
    if (GetBuiltinCode(opcode) == BuiltinOperator_CUSTOM) {
      num_custom_ops++;
    }
  }
  unresolved_custom_ops_.reserve(num_custom_ops);
  for (const OperatorCode* opcode : *opcodes) {
    const TfLiteRegistration* registration = nullptr;
    status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
                                       &registration);
    if (status != kTfLiteOk) {
      if (GetBuiltinCode(opcode) != BuiltinOperator_CUSTOM) {
        return status;
      }
      // If it's an unresolved custom op, allow it for now. It might be resolved
      // by a delegate later.
      if (!opcode->custom_code()) {
        error_reporter_->Report(
            "Operator with CUSTOM builtin_code has no custom_code.\n");
        return status;
      }
      const auto* op_name = opcode->custom_code()->c_str();
      unresolved_custom_ops_.push_back(CreateUnresolvedCustomOp(op_name));
      registration = &unresolved_custom_ops_.back();
      has_flex_op_ |= IsFlexOp(op_name);
      status = kTfLiteOk;
    }
    flatbuffer_op_index_to_registration_.push_back(registration);
  }
  return status;
}

namespace {
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
  // Initialize shape of tensors with null shape. Empty vectors are converted
  // to nullptr for models that are constructed via flatbuffers::Pack.
  if (flat_array == nullptr) {
    return {};
  }
  std::vector<int> ret(flat_array->size());
  for (int i = 0; i < flat_array->size(); i++) {
    ret[i] = flat_array->Get(i);
  }
  return ret;
}

// Used to determine how the op data parsing function creates its working space.
class MallocDataAllocator : public BuiltinDataAllocator {
 public:
  void* Allocate(size_t size, size_t alignment_hint) override {
#ifdef TFLITE_USE_STD_ALIGNED_ALLOC
    // Ensure that alignment is a power of two and a multiple of sizeof(void *)
    // and that size is an integral multiple of alignment.
    size_t used_alignment = std::max(alignment_hint, sizeof(void*));
    size_t used_size =
        ((size + used_alignment - 1) / used_alignment) * used_alignment;
    TFLITE_DCHECK(
        (used_alignment != 0) &&
        ((used_alignment & (used_alignment - 1)) == 0));  // is power-of-two
    return aligned_alloc(used_alignment, used_size);
#else
    return malloc(size);
#endif
  }
  void Deallocate(void* data) override { free(data); }
};

}  // namespace

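// Creates one TFLite node per flatbuffer Operator: custom ops carry their raw
// custom_options buffer, while builtin ops get their options parsed into
// builtin_data via ParseOpData() using a malloc-backed allocator.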
TfLiteStatus InterpreterBuilder::ParseNodes(
    const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
    Subgraph* subgraph) {
  TfLiteStatus status = kTfLiteOk;

  // Reduce the number of redundant allocations
  subgraph->ReserveNodes(operators->size());

  for (int i = 0; i < operators->size(); ++i) {
    const auto* op = operators->Get(i);
    int index = op->opcode_index();
    if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
      error_reporter_->Report("Missing registration for opcode_index %d\n",
                              index);
      status = kTfLiteError;
      continue;
    }

    const TfLiteRegistration* registration =
        flatbuffer_op_index_to_registration_[index];
    if (registration == nullptr) {
      error_reporter_->Report("Skipping op for opcode_index %d\n", index);
      status = kTfLiteError;
      continue;
    }

    BuiltinOperator op_type =
        static_cast<BuiltinOperator>(registration->builtin_code);

    if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
      error_reporter_->Report(
          "Found builtin operator %s with custom options.\n",
          EnumNameBuiltinOperator(op_type));
    }

    if (op_type == BuiltinOperator_CUSTOM) {
      if (op->custom_options()) {
        subgraph->AddNodeWithParameters(
            FlatBufferIntArrayToVector(op->inputs()),
            FlatBufferIntArrayToVector(op->outputs()),
            FlatBufferIntArrayToVector(op->intermediates()),
            reinterpret_cast<const char*>(op->custom_options()->data()),
            op->custom_options()->size(), nullptr, registration);
      } else {
        subgraph->AddNodeWithParameters(
            FlatBufferIntArrayToVector(op->inputs()),
            FlatBufferIntArrayToVector(op->outputs()),
            FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
            nullptr, registration);
      }
    } else {
      void* builtin_data = nullptr;
      MallocDataAllocator malloc_allocator;
      TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
                                        &malloc_allocator, &builtin_data));
      subgraph->AddNodeWithParameters(
          FlatBufferIntArrayToVector(op->inputs()),
          FlatBufferIntArrayToVector(op->outputs()),
          FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
          builtin_data, registration);
    }
  }

  return status;
}

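// Converts flatbuffer QuantizationParameters into TfLiteQuantization. Only
// affine (scale/zero_point) quantization is populated; per-axis quantization
// requires one scale/zero_point pair per element of the quantized dimension.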
TfLiteStatus InterpreterBuilder::ParseQuantization(
    const QuantizationParameters* src_quantization,
    TfLiteQuantization* quantization, const std::vector<int>& dims) {
  quantization->type = kTfLiteNoQuantization;
  if (!src_quantization || !src_quantization->scale() ||
      src_quantization->scale()->size() == 0) {
    return kTfLiteOk;
  }
  if (!src_quantization->zero_point()) {
    error_reporter_->Report(
        "Quantization parameters have a non-null scale but a null zero_point.");
    return kTfLiteError;
  }

  // Ensure that the number of scales matches the number of zero_points.
  if (src_quantization->scale()->size() !=
      src_quantization->zero_point()->size()) {
    error_reporter_->Report(
        "QuantizationParam has %d zero_point values and %d scale values. Must "
        "have same number.",
        src_quantization->zero_point()->size(),
        src_quantization->scale()->size());
    return kTfLiteError;
  }

  const size_t num_scales = src_quantization->scale()->size();

  // Ensure that the quantization dimension is valid.
  if (src_quantization->quantized_dimension() < 0 ||
      (!dims.empty() &&
       src_quantization->quantized_dimension() >= dims.size())) {
    error_reporter_->Report(
        "quantized_dimension must be in range [0, %d). Was %d.", dims.size(),
        src_quantization->quantized_dimension());
    return kTfLiteError;
  }

  // Ensure that the number of scales is 1 for per-layer quantization, and
  // matches number of quantization dimensions for per-axis quantization.
  if (num_scales != 1 &&
      (!dims.empty() &&
       num_scales != dims[src_quantization->quantized_dimension()])) {
    error_reporter_->Report(
        "num_scales must be 1 for per-layer quantization, or %d for per-axis "
        "quantization, but got %d.",
        dims[src_quantization->quantized_dimension()], num_scales);
    return kTfLiteError;
  }

  // Affine-quantization.
  quantization->type = kTfLiteAffineQuantization;
  auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
      malloc(sizeof(TfLiteAffineQuantization)));
  affine_quantization->scale = TfLiteFloatArrayCreate(num_scales);
  affine_quantization->zero_point = TfLiteIntArrayCreate(num_scales);
  for (size_t i = 0; i < num_scales; ++i) {
    affine_quantization->scale->data[i] = src_quantization->scale()->Get(i);
    affine_quantization->zero_point->data[i] =
        src_quantization->zero_point()->Get(i);
  }
  affine_quantization->quantized_dimension =
      src_quantization->quantized_dimension();
  quantization->params = reinterpret_cast<void*>(affine_quantization);
  return kTfLiteOk;
}

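// Converts flatbuffer SparsityParameters into a malloc-allocated
// TfLiteSparsity, copying the traversal order, block map, and per-dimension
// metadata (dense size, or CSR segments/indices).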
TfLiteStatus InterpreterBuilder::ParseSparsity(
    const SparsityParameters* src_sparsity, TfLiteSparsity** sparsity_ptr) {
  if (!src_sparsity) {
    return kTfLiteOk;
  }

  if (src_sparsity->traversal_order() == nullptr ||
      src_sparsity->dim_metadata() == nullptr) {
    error_reporter_->Report("Invalid sparsity parameter.");
    return kTfLiteError;
  }

  auto* sparsity =
      reinterpret_cast<TfLiteSparsity*>(malloc(sizeof(TfLiteSparsity)));
  memset(sparsity, 0, sizeof(TfLiteSparsity));
  *sparsity_ptr = sparsity;

  const size_t traversal_order_size = src_sparsity->traversal_order()->size();
  sparsity->traversal_order = TfLiteIntArrayCreate(traversal_order_size);
  for (int i = 0; i < traversal_order_size; i++) {
    sparsity->traversal_order->data[i] =
        src_sparsity->traversal_order()->Get(i);
  }

  if (src_sparsity->block_map()) {
    const size_t block_map_size = src_sparsity->block_map()->size();
    sparsity->block_map = TfLiteIntArrayCreate(block_map_size);
    for (int i = 0; i < block_map_size; i++) {
      sparsity->block_map->data[i] = src_sparsity->block_map()->Get(i);
    }
  }

  const size_t dim_metadata_size = src_sparsity->dim_metadata()->size();
  sparsity->dim_metadata_size = dim_metadata_size;
  sparsity->dim_metadata = reinterpret_cast<TfLiteDimensionMetadata*>(
      malloc(dim_metadata_size * sizeof(TfLiteDimensionMetadata)));
  memset(sparsity->dim_metadata, 0,
         dim_metadata_size * sizeof(TfLiteDimensionMetadata));

  for (int i = 0; i < dim_metadata_size; i++) {
    const auto* src_metadata = src_sparsity->dim_metadata()->Get(i);
    if (src_metadata->format() != DimensionType_DENSE &&
        src_metadata->format() != DimensionType_SPARSE_CSR) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "The %dth dimension has unknown type: %d.", i,
                           src_metadata->format());
      return kTfLiteError;
    }
    auto* tgt_metadata = &sparsity->dim_metadata[i];

    tgt_metadata->format =
        static_cast<TfLiteDimensionType>(src_metadata->format());

    if (tgt_metadata->format == kTfLiteDimDense) {
      tgt_metadata->dense_size = src_metadata->dense_size();
    } else {
      if (ParseSparseIndexVector(src_metadata, tgt_metadata) != kTfLiteOk) {
        TF_LITE_REPORT_ERROR(
            error_reporter_,
            "The %dth sparse dimension has invalid parameters.", i);
        return kTfLiteError;
      }
    }
  }

  return kTfLiteOk;
}

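// Copies the model's SignatureDef entries (exported method name, key, and
// input/output tensor maps) into the interpreter. A malformed SignatureDef is
// an error; an absent or empty list is not.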
TfLiteStatus InterpreterBuilder::ParseSignatureDefs(
    const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
        signature_def_list,
    Interpreter* interpreter) {
  if (signature_def_list == nullptr || signature_def_list->size() == 0) {
    return kTfLiteOk;
  }
  std::vector<Interpreter::SignatureDef> signature_defs;
  signature_defs.reserve(signature_def_list->size());
  for (const auto fb_signature_def : *signature_def_list) {
    if (fb_signature_def == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_, "NULL SignatureDef in the model.");
      return kTfLiteError;
    }
    if (fb_signature_def->method_name() == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "Missing exported method name for SignatureDef");
      return kTfLiteError;
    }
    if (fb_signature_def->inputs() == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "NULL SignatureDef inputs for exported method %s",
                           fb_signature_def->method_name()->c_str());
      return kTfLiteError;
    }
    if (fb_signature_def->outputs() == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "NULL SignatureDef outputs for exported method %s",
                           fb_signature_def->method_name()->c_str());
      return kTfLiteError;
    }
    signature_defs.resize(signature_defs.size() + 1);
    auto& signature_def = signature_defs.back();
    signature_def.inputs = GetMapFromTensorMap(fb_signature_def->inputs());
    signature_def.outputs = GetMapFromTensorMap(fb_signature_def->outputs());
    signature_def.method_name = fb_signature_def->method_name()->c_str();
    if (fb_signature_def->key() != nullptr) {
      signature_def.signature_def_key = fb_signature_def->key()->c_str();
    }
  }
  interpreter->SetSignatureDef(std::move(signature_defs));
  return kTfLiteOk;
}

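// Registers every tensor of the subgraph: tensors backed by a non-empty model
// buffer become read-only (optionally sparse) tensors referencing the model
// allocation; all others become read-write tensors with an optional dynamic
// shape signature.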
TfLiteStatus InterpreterBuilder::ParseTensors(
    const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
    const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
    Subgraph* subgraph) {
  TfLiteStatus status = kTfLiteOk;

  // A little helper to get the names of inputs and outputs. Note that they
  // must outlive the subgraph.
  auto get_name = [](const tflite::Tensor* t) -> const char* {
    auto name = t->name();
    if (name) return name->c_str();
    return kEmptyTensorName;
  };

  num_fp32_tensors_ = 0;
  for (int i = 0; i < tensors->size(); ++i) {
    const auto* tensor = tensors->Get(i);
    std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());

    TfLiteType type;
    if (ConvertTensorType(tensor->type(), &type, error_reporter_) !=
        kTfLiteOk) {
      status = kTfLiteError;
      continue;
    }
    if (type == kTfLiteFloat32) {
      ++num_fp32_tensors_;
    }
    auto get_readonly_data = [&](const char** buffer_data,
                                 size_t* buffer_size) {
      // TODO(aselle): Check what happens if we have an unspecified size
      // constant.
      *buffer_data = nullptr;
      if (tensor->buffer() == 0) return kTfLiteOk;
      if (tensor->buffer() >= buffers->size()) {
        error_reporter_->Report(
            "Tensor %d specifies out of range buffer %d (only %d buffers).\n",
            i, tensor->buffer(), buffers->size());
        return kTfLiteError;
      }
      if (auto* buffer = (*buffers)[tensor->buffer()]) {
        if (auto* array = buffer->data()) {
          if (size_t size = array->size()) {
            *buffer_size = size;
            *buffer_data = reinterpret_cast<const char*>(array->data());
            return kTfLiteOk;
          }
        }
      }
      return kTfLiteOk;
    };
    size_t buffer_size = 0;
    const char* buffer_ptr;
    TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));

    const auto* src_quantization = tensor->quantization();
    TfLiteQuantization quantization;
    if (ParseQuantization(src_quantization, &quantization, dims) != kTfLiteOk) {
      error_reporter_->Report("Tensor %d has invalid quantization parameters.",
                              i);
      status = kTfLiteError;
    }

    std::vector<int> dims_signature = {};
    if (tensor->shape_signature()) {
      dims_signature = FlatBufferIntArrayToVector(tensor->shape_signature());
    }

    bool is_variable = tensor->is_variable();
    if (buffer_ptr) {
      if (is_variable) {
        error_reporter_->Report(
            "Tensor %d is a variable tensor with a buffer. "
            "This is not yet supported.\n",
            i);
        status = kTfLiteError;
      }

      // TODO(b/144999664): Only constant sparse tensor is supported now.
      const auto* src_sparsity = tensor->sparsity();
      TfLiteSparsity* sparsity = nullptr;
      if (ParseSparsity(src_sparsity, &sparsity) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d has invalid sparsity parameters.",
                                i);
        status = kTfLiteError;
      }

      if (subgraph->SetTensorParametersReadOnly(
              i, type, get_name(tensor), dims, quantization, buffer_ptr,
              buffer_size, allocation_, sparsity) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                i);
        status = kTfLiteError;
      }
    } else {
      if (subgraph->SetTensorParametersReadWrite(
              i, type, get_name(tensor), dims, quantization, is_variable,
              dims_signature) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                i);
        status = kTfLiteError;
      }
    }
  }

  return status;
}

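// Applies delegates to the fully constructed interpreter: the flex delegate
// first (if any flex op was seen and a delegate can be acquired), then every
// delegate previously registered via AddDelegate(), in order.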
TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter,
                                                int num_threads) {
  // Apply Flex delegate if applicable.
  if (has_flex_op_) {
    if (Interpreter::TfLiteDelegatePtr flex_delegate = AcquireFlexDelegate()) {
      TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegate(
          // Transfers ownership of flex_delegate to the interpreter.
          std::move(flex_delegate)));
    }
  }
  for (TfLiteDelegate* delegate : delegates_) {
    // Note that we DON'T transfer ownership of the delegate to the interpreter.
    // (Doing that would cause problems if operator() was invoked twice.)
    TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegate(delegate));
  }
  return kTfLiteOk;
}

TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter) {
  return operator()(interpreter, /*num_threads=*/-1);
}

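// Builds the interpreter: validates the schema version, maps opcodes to
// registrations, constructs each subgraph's tensors and nodes, attaches
// signature defs, and finally applies delegates. On failure the partially
// built interpreter is reset to nullptr.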
TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter, int num_threads) {
  if (!interpreter) {
    error_reporter_->Report(
        "Null output pointer passed to InterpreterBuilder.");
    return kTfLiteError;
  }

  if (num_threads < -1) {
    error_reporter_->Report(
        "num_threads should be >=0 or just -1 to let TFLite runtime set the "
        "value.");
    return kTfLiteError;
  }

  // Safe exit by deleting partially created interpreter, to reduce verbosity
  // on error conditions. Use as: return cleanup_and_error();
  auto cleanup_and_error = [&interpreter]() {
    interpreter->reset();
    return kTfLiteError;
  };

  if (!model_) {
    error_reporter_->Report("Null pointer passed in as model.");
    return cleanup_and_error();
  }

  if (model_->version() != TFLITE_SCHEMA_VERSION) {
    error_reporter_->Report(
        "Model provided is schema version %d not equal "
        "to supported version %d.\n",
        model_->version(), TFLITE_SCHEMA_VERSION);
    return cleanup_and_error();
  }

  if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
    error_reporter_->Report("Registration failed.\n");
    return cleanup_and_error();
  }

  // Flatbuffer model schemas define a list of opcodes independent of the graph.
  // We first map those to registrations. This reduces string lookups for custom
  // ops since we only do it once per custom op rather than once per custom op
  // invocation in the model graph.
  // Construct interpreter with correct number of tensors and operators.
  auto* subgraphs = model_->subgraphs();
  auto* buffers = model_->buffers();

  if (!subgraphs || subgraphs->size() == 0) {
    TF_LITE_REPORT_ERROR(error_reporter_, "No subgraph in the model.\n");
    return cleanup_and_error();
  }

  if (!buffers) {
    TF_LITE_REPORT_ERROR(error_reporter_, "No buffers in the model.\n");
    return cleanup_and_error();
  }

  interpreter->reset(new Interpreter(error_reporter_));
  (*interpreter)->SetNumThreads(num_threads);
  if (subgraphs->size() > 1) {
    (*interpreter)->AddSubgraphs(subgraphs->size() - 1);
  }

  (*interpreter)->SetProfiler(tflite::profiling::MaybeCreatePlatformProfiler());

  for (int subgraph_index = 0; subgraph_index < subgraphs->size();
       ++subgraph_index) {
    const tflite::SubGraph* subgraph = (*subgraphs)[subgraph_index];
    tflite::Subgraph* modified_subgraph =
        (*interpreter)->subgraph(subgraph_index);
    auto operators = subgraph->operators();
    auto tensors = subgraph->tensors();
    if (!operators || !tensors) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "Did not get operators or tensors in subgraph %d.\n",
                           subgraph_index);
      return cleanup_and_error();
    }
    if (modified_subgraph->AddTensors(tensors->size()) != kTfLiteOk) {
      return cleanup_and_error();
    }
    // Set num threads
    // Parse inputs/outputs
    modified_subgraph->SetInputs(
        FlatBufferIntArrayToVector(subgraph->inputs()));
    modified_subgraph->SetOutputs(
        FlatBufferIntArrayToVector(subgraph->outputs()));

    // Finally setup nodes and tensors
    if (ParseNodes(operators, modified_subgraph) != kTfLiteOk)
      return cleanup_and_error();
    if (ParseTensors(buffers, tensors, modified_subgraph) != kTfLiteOk)
      return cleanup_and_error();

    std::vector<int> variables;
    for (int i = 0; i < modified_subgraph->tensors_size(); ++i) {
      auto* tensor = modified_subgraph->tensor(i);
      if (tensor->is_variable) {
        variables.push_back(i);
      }
    }
    modified_subgraph->SetVariables(std::move(variables));
    if (subgraph->name()) {
      modified_subgraph->SetName(subgraph->name()->c_str());
    }
  }

  if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) !=
      kTfLiteOk) {
    return cleanup_and_error();
  }

  if (num_fp32_tensors_ > 0) {
    (*interpreter)->lazy_delegate_providers_ =
        op_resolver_.GetDelegates(num_threads);
  }

  TfLiteStatus status = ApplyDelegates(interpreter->get(), num_threads);
  if (status != kTfLiteOk) {
    interpreter->reset();
  }
  return status;
}

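// Registers a delegate to be applied in operator(). The caller retains
// ownership of the delegate and is expected to keep it alive for the lifetime
// of any interpreter built with it.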
void InterpreterBuilder::AddDelegate(TfLiteDelegate* delegate) {
  if (delegate == nullptr) {
    TF_LITE_REPORT_ERROR(error_reporter_, "Null delegate.");
  } else {
    delegates_.push_back(delegate);
  }
}

}  // namespace tflite