1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/interpreter_builder.h"
16 
17 #include <stddef.h>
18 #include <stdint.h>
19 #include <stdlib.h>
20 #include <string.h>
21 
22 #include <algorithm>
23 #include <map>
24 #include <memory>
25 #include <string>
26 #include <utility>
27 #include <vector>
28 
29 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
30 #include "tensorflow/lite/core/api/error_reporter.h"
31 #include "tensorflow/lite/core/api/flatbuffer_conversions.h"
32 #include "tensorflow/lite/core/api/op_resolver.h"
33 #include "tensorflow/lite/core/macros.h"
34 #include "tensorflow/lite/core/subgraph.h"
35 #include "tensorflow/lite/interpreter.h"
36 #include "tensorflow/lite/kernels/internal/compatibility.h"
37 #include "tensorflow/lite/model_builder.h"
38 #include "tensorflow/lite/profiling/platform_profiler.h"
39 #include "tensorflow/lite/schema/schema_generated.h"
40 #include "tensorflow/lite/schema/schema_utils.h"
41 #include "tensorflow/lite/shared_library.h"
42 #include "tensorflow/lite/stderr_reporter.h"
43 #include "tensorflow/lite/string_type.h"
44 #include "tensorflow/lite/util.h"
45 #include "tensorflow/lite/version.h"
46 
// aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11.
// TFLITE_USE_STD_ALIGNED_ALLOC is defined only when every nested condition
// below holds; MallocDataAllocator uses it to choose aligned_alloc over malloc.
#if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L
// On Android, the aligned_alloc path is only enabled from API level 28.
#if !defined(__ANDROID__) || __ANDROID_API__ >= 28
// Neither Apple nor Windows provide aligned_alloc.
#if !defined(__APPLE__) && !defined(_WIN32)
#define TFLITE_USE_STD_ALIGNED_ALLOC
#endif
#endif
#endif

// TODO(b/139446230): Move to portable platform header.
// TFLITE_IS_MOBILE_PLATFORM is defined for Android and for Apple
// TARGET_OS_IPHONE / simulator targets. AcquireFlexDelegate skips trying to
// load the TensorFlow Python extension library when it is defined.
#if defined(__ANDROID__)
#define TFLITE_IS_MOBILE_PLATFORM
#endif  // defined(__ANDROID__)

#if defined(__APPLE__)
#include "TargetConditionals.h"
#if TARGET_IPHONE_SIMULATOR
#define TFLITE_IS_MOBILE_PLATFORM
#elif TARGET_OS_IPHONE
#define TFLITE_IS_MOBILE_PLATFORM
#endif
#endif  // defined(__APPLE__)
70 
71 namespace tflite {
72 
73 namespace {
74 
75 // Ensure that ErrorReporter is non-null.
ValidateErrorReporter(ErrorReporter * e)76 ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
77   return e ? e : DefaultErrorReporter();
78 }
79 
80 template <typename T>
Copy(const T * data_ptr,TfLiteIntArray ** arr)81 TfLiteStatus Copy(const T* data_ptr, TfLiteIntArray** arr) {
82   if (data_ptr->values() == nullptr) {
83     return kTfLiteError;
84   }
85 
86   int size = data_ptr->values()->size();
87   *arr = TfLiteIntArrayCreate(size);
88   for (int i = 0; i < size; i++) {
89     (*arr)->data[i] = static_cast<int>(data_ptr->values()->Get(i));
90   }
91   return kTfLiteOk;
92 }
93 
ParseSparseIndexVector(const DimensionMetadata * src,TfLiteDimensionMetadata * tgt)94 TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src,
95                                     TfLiteDimensionMetadata* tgt) {
96   if (src->array_segments() == nullptr || src->array_indices() == nullptr) {
97     return kTfLiteError;
98   }
99   TfLiteStatus status = kTfLiteOk;
100   switch (src->array_segments_type()) {
101     case SparseIndexVector_Int32Vector:
102       status = Copy(src->array_segments_as_Int32Vector(), &tgt->array_segments);
103       break;
104     case SparseIndexVector_Uint16Vector:
105       status =
106           Copy(src->array_segments_as_Uint16Vector(), &tgt->array_segments);
107       break;
108     case SparseIndexVector_Uint8Vector:
109       status = Copy(src->array_segments_as_Uint8Vector(), &tgt->array_segments);
110       break;
111     default:
112       status = kTfLiteError;
113       break;
114   }
115   if (status != kTfLiteOk) return status;
116 
117   switch (src->array_indices_type()) {
118     case SparseIndexVector_Int32Vector:
119       return Copy(src->array_indices_as_Int32Vector(), &tgt->array_indices);
120     case SparseIndexVector_Uint16Vector:
121       return Copy(src->array_indices_as_Uint16Vector(), &tgt->array_indices);
122     case SparseIndexVector_Uint8Vector:
123       return Copy(src->array_indices_as_Uint8Vector(), &tgt->array_indices);
124     default:
125       break;
126   }
127   return kTfLiteError;
128 }
129 
130 // Helper that returns std::map that corresponds to vector of TensorMap.
GetMapFromTensorMap(const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMap>> * tensor_map)131 std::map<std::string, uint32_t> GetMapFromTensorMap(
132     const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMap>>*
133         tensor_map) {
134   if (!tensor_map) return {};
135   std::map<std::string, uint32_t> result;
136   for (const auto tensor : *tensor_map) {
137     if (tensor != nullptr && tensor->name() != nullptr) {
138       result[tensor->name()->c_str()] = tensor->tensor_index();
139     }
140   }
141   return result;
142 }
143 
144 }  // namespace
145 
// Name used for tensors whose flatbuffer entry carries no name (see the
// get_name helper in ParseTensors). Being a string literal, it outlives every
// subgraph that refers to it.
const char* kEmptyTensorName = "";
147 
148 // Using weak symbols to create a delegate allows automatic injection of the
149 // delegate simply by adding it as a dependency.
150 // For flex delegate, see also the strong override in
151 // lite/delegates/flex/delegate.cc.
AcquireFlexDelegate()152 TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
153   auto acquire_flex_delegate_func =
154       reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
155           SharedLibrary::GetSymbol("TF_AcquireFlexDelegate"));
156   if (acquire_flex_delegate_func) {
157     return acquire_flex_delegate_func();
158   }
159 
160 #if !defined(TFLITE_IS_MOBILE_PLATFORM)
161   // Load TF_AcquireFlexDelegate() from _pywrap_tensorflow_internal.so if it is
162   // available.
163   const char* filename_pywrap_tensorflow_internal =
164 #if defined(_WIN32)
165       "_pywrap_tensorflow_internal.pyd";
166 #elif defined(__APPLE__)
167       "python/_pywrap_tensorflow_internal.so";
168 #else
169       "_pywrap_tensorflow_internal.so";
170 #endif
171   void* lib_tf_internal =
172       SharedLibrary::LoadLibrary(filename_pywrap_tensorflow_internal);
173 #if defined(_WIN32)
174   if (lib_tf_internal == nullptr) {
175     lib_tf_internal = SharedLibrary::LoadLibrary(
176         "_pywrap_tensorflow_interpreter_wrapper.pyd");
177   }
178 #endif
179   if (lib_tf_internal) {
180     acquire_flex_delegate_func =
181         reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
182             SharedLibrary::GetLibrarySymbol(lib_tf_internal,
183                                             "TF_AcquireFlexDelegate"));
184     if (acquire_flex_delegate_func) {
185       return acquire_flex_delegate_func();
186     }
187   }
188 #endif  // !defined(TFLITE_IS_MOBILE_PLATFORM)
189 
190   return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
191 }
192 
InterpreterBuilder(const FlatBufferModel & model,const OpResolver & op_resolver)193 InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
194                                        const OpResolver& op_resolver)
195     : model_(model.GetModel()),
196       op_resolver_(op_resolver),
197       error_reporter_(ValidateErrorReporter(model.error_reporter())),
198       allocation_(model.allocation()) {}
199 
InterpreterBuilder(const::tflite::Model * model,const OpResolver & op_resolver,ErrorReporter * error_reporter)200 InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
201                                        const OpResolver& op_resolver,
202                                        ErrorReporter* error_reporter)
203     : model_(model),
204       op_resolver_(op_resolver),
205       error_reporter_(ValidateErrorReporter(error_reporter)) {}
206 
~InterpreterBuilder()207 InterpreterBuilder::~InterpreterBuilder() {}
208 
BuildLocalIndexToRegistrationMapping()209 TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
210   TfLiteStatus status = kTfLiteOk;
211   // Reset state.
212   flatbuffer_op_index_to_registration_.clear();
213   unresolved_custom_ops_.clear();
214 
215   auto opcodes = model_->operator_codes();
216   if (!opcodes) {
217     return status;
218   }
219   int num_custom_ops = 0;
220   for (const OperatorCode* opcode : *opcodes) {
221     if (GetBuiltinCode(opcode) == BuiltinOperator_CUSTOM) {
222       num_custom_ops++;
223     }
224   }
225   unresolved_custom_ops_.reserve(num_custom_ops);
226   for (const OperatorCode* opcode : *opcodes) {
227     const TfLiteRegistration* registration = nullptr;
228     status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
229                                        &registration);
230     if (status != kTfLiteOk) {
231       if (GetBuiltinCode(opcode) != BuiltinOperator_CUSTOM) {
232         return status;
233       }
234       // If it's an unresolved custom op, allow it for now. It might be resolved
235       // by a delegate later.
236       if (!opcode->custom_code()) {
237         error_reporter_->Report(
238             "Operator with CUSTOM builtin_code has no custom_code.\n");
239         return status;
240       }
241       const auto* op_name = opcode->custom_code()->c_str();
242       unresolved_custom_ops_.push_back(CreateUnresolvedCustomOp(op_name));
243       registration = &unresolved_custom_ops_.back();
244       has_flex_op_ |= IsFlexOp(op_name);
245       status = kTfLiteOk;
246     }
247     flatbuffer_op_index_to_registration_.push_back(registration);
248   }
249   return status;
250 }
251 
252 namespace {
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
  // A null array maps to an empty vector: models constructed via
  // flatbuffers::Pack store empty vectors as nullptr, and such tensors are
  // treated as having a null shape.
  std::vector<int> result;
  if (flat_array != nullptr) {
    const auto count = flat_array->size();
    result.reserve(count);
    for (decltype(flat_array->size()) idx = 0; idx < count; ++idx) {
      result.push_back(flat_array->Get(idx));
    }
  }
  return result;
}
266 
267 // Used to determine how the op data parsing function creates its working space.
268 class MallocDataAllocator : public BuiltinDataAllocator {
269  public:
Allocate(size_t size,size_t alignment_hint)270   void* Allocate(size_t size, size_t alignment_hint) override {
271 #ifdef TFLITE_USE_STD_ALIGNED_ALLOC
272     // Ensure that alignment is a power of two and a multiple of sizeof(void *)
273     // and that size is an integral multiple of alignment.
274     size_t used_alignment = std::max(alignment_hint, sizeof(void*));
275     size_t used_size =
276         ((size + used_alignment - 1) / used_alignment) * used_alignment;
277     TFLITE_DCHECK(
278         (used_alignment != 0) &&
279         ((used_alignment & (used_alignment - 1)) == 0));  // is power-of-two
280     return aligned_alloc(used_alignment, used_size);
281 #else
282     return malloc(size);
283 #endif
284   }
Deallocate(void * data)285   void Deallocate(void* data) override { free(data); }
286 };
287 
288 }  // namespace
289 
ParseNodes(const flatbuffers::Vector<flatbuffers::Offset<Operator>> * operators,Subgraph * subgraph)290 TfLiteStatus InterpreterBuilder::ParseNodes(
291     const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
292     Subgraph* subgraph) {
293   TfLiteStatus status = kTfLiteOk;
294 
295   // Reduce the number of redundant allocations
296   subgraph->ReserveNodes(operators->size());
297 
298   for (int i = 0; i < operators->size(); ++i) {
299     const auto* op = operators->Get(i);
300     int index = op->opcode_index();
301     if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
302       error_reporter_->Report("Missing registration for opcode_index %d\n",
303                               index);
304       status = kTfLiteError;
305       continue;
306     }
307 
308     const TfLiteRegistration* registration =
309         flatbuffer_op_index_to_registration_[index];
310     if (registration == nullptr) {
311       error_reporter_->Report("Skipping op for opcode_index %d\n", index);
312       status = kTfLiteError;
313       continue;
314     }
315 
316     BuiltinOperator op_type =
317         static_cast<BuiltinOperator>(registration->builtin_code);
318 
319     if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
320       error_reporter_->Report(
321           "Found builtin operator %s with custom options.\n",
322           EnumNameBuiltinOperator(op_type));
323     }
324 
325     if (op_type == BuiltinOperator_CUSTOM) {
326       if (op->custom_options()) {
327         subgraph->AddNodeWithParameters(
328             FlatBufferIntArrayToVector(op->inputs()),
329             FlatBufferIntArrayToVector(op->outputs()),
330             FlatBufferIntArrayToVector(op->intermediates()),
331             reinterpret_cast<const char*>(op->custom_options()->data()),
332             op->custom_options()->size(), nullptr, registration);
333       } else {
334         subgraph->AddNodeWithParameters(
335             FlatBufferIntArrayToVector(op->inputs()),
336             FlatBufferIntArrayToVector(op->outputs()),
337             FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
338             nullptr, registration);
339       }
340     } else {
341       void* builtin_data = nullptr;
342       MallocDataAllocator malloc_allocator;
343       TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
344                                         &malloc_allocator, &builtin_data));
345       subgraph->AddNodeWithParameters(
346           FlatBufferIntArrayToVector(op->inputs()),
347           FlatBufferIntArrayToVector(op->outputs()),
348           FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
349           builtin_data, registration);
350     }
351   }
352 
353   return status;
354 }
355 
ParseQuantization(const QuantizationParameters * src_quantization,TfLiteQuantization * quantization,const std::vector<int> & dims)356 TfLiteStatus InterpreterBuilder::ParseQuantization(
357     const QuantizationParameters* src_quantization,
358     TfLiteQuantization* quantization, const std::vector<int>& dims) {
359   quantization->type = kTfLiteNoQuantization;
360   if (!src_quantization || !src_quantization->scale() ||
361       src_quantization->scale()->size() == 0) {
362     return kTfLiteOk;
363   }
364   if (!src_quantization->zero_point()) {
365     error_reporter_->Report(
366         "Quantization parameters has non-null scale but null zero_point.");
367     return kTfLiteError;
368   }
369 
370   // Ensure that the number of scales matches the number of zero_points.
371   if (src_quantization->scale()->size() !=
372       src_quantization->zero_point()->size()) {
373     error_reporter_->Report(
374         "QuantizationParam has %d zero_point values and %d scale values. Must "
375         "have same number.",
376         src_quantization->zero_point()->size(),
377         src_quantization->scale()->size());
378     return kTfLiteError;
379   }
380 
381   const size_t num_scales = src_quantization->scale()->size();
382 
383   // Ensure that the quantization dimension is valid.
384   if (src_quantization->quantized_dimension() < 0 ||
385       (!dims.empty() &&
386        src_quantization->quantized_dimension() >= dims.size())) {
387     error_reporter_->Report(
388         "quantized_dimension must be in range [0, %d). Was %d.", dims.size(),
389         src_quantization->quantized_dimension());
390     return kTfLiteError;
391   }
392 
393   // Ensure that the number of scales is 1 for per-layer quantization, and
394   // matches number of quantization dimensions for per-axis quantization.
395   if (num_scales != 1 &&
396       (!dims.empty() &&
397        num_scales != dims[src_quantization->quantized_dimension()])) {
398     error_reporter_->Report(
399         "num_scales must be 1 for per-layer quantization, or %d for per-axis "
400         "quantization, but got %d.",
401         dims[src_quantization->quantized_dimension()], num_scales);
402     return kTfLiteError;
403   }
404 
405   // Affine-quantization.
406   quantization->type = kTfLiteAffineQuantization;
407   auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
408       malloc(sizeof(TfLiteAffineQuantization)));
409   affine_quantization->scale = TfLiteFloatArrayCreate(num_scales);
410   affine_quantization->zero_point = TfLiteIntArrayCreate(num_scales);
411   for (size_t i = 0; i < num_scales; ++i) {
412     affine_quantization->scale->data[i] = src_quantization->scale()->Get(i);
413     affine_quantization->zero_point->data[i] =
414         src_quantization->zero_point()->Get(i);
415   }
416   affine_quantization->quantized_dimension =
417       src_quantization->quantized_dimension();
418   quantization->params = reinterpret_cast<void*>(affine_quantization);
419   return kTfLiteOk;
420 }
421 
ParseSparsity(const SparsityParameters * src_sparsity,TfLiteSparsity ** sparsity_ptr)422 TfLiteStatus InterpreterBuilder::ParseSparsity(
423     const SparsityParameters* src_sparsity, TfLiteSparsity** sparsity_ptr) {
424   if (!src_sparsity) {
425     return kTfLiteOk;
426   }
427 
428   if (src_sparsity->traversal_order() == nullptr ||
429       src_sparsity->dim_metadata() == nullptr) {
430     error_reporter_->Report("Invalid sparsity parameter.");
431     return kTfLiteError;
432   }
433 
434   auto* sparsity =
435       reinterpret_cast<TfLiteSparsity*>(malloc(sizeof(TfLiteSparsity)));
436   memset(sparsity, 0, sizeof(TfLiteSparsity));
437   *sparsity_ptr = sparsity;
438 
439   const size_t traversal_order_size = src_sparsity->traversal_order()->size();
440   sparsity->traversal_order = TfLiteIntArrayCreate(traversal_order_size);
441   for (int i = 0; i < traversal_order_size; i++) {
442     sparsity->traversal_order->data[i] =
443         src_sparsity->traversal_order()->Get(i);
444   }
445 
446   if (src_sparsity->block_map()) {
447     const size_t block_map_size = src_sparsity->block_map()->size();
448     sparsity->block_map = TfLiteIntArrayCreate(block_map_size);
449     for (int i = 0; i < block_map_size; i++) {
450       sparsity->block_map->data[i] = src_sparsity->block_map()->Get(i);
451     }
452   }
453 
454   const size_t dim_metadata_size = src_sparsity->dim_metadata()->size();
455   sparsity->dim_metadata_size = dim_metadata_size;
456   sparsity->dim_metadata = reinterpret_cast<TfLiteDimensionMetadata*>(
457       malloc(dim_metadata_size * sizeof(TfLiteDimensionMetadata)));
458   memset(sparsity->dim_metadata, 0,
459          dim_metadata_size * sizeof(TfLiteDimensionMetadata));
460 
461   for (int i = 0; i < dim_metadata_size; i++) {
462     const auto* src_metadata = src_sparsity->dim_metadata()->Get(i);
463     if (src_metadata->format() != DimensionType_DENSE &&
464         src_metadata->format() != DimensionType_SPARSE_CSR) {
465       TF_LITE_REPORT_ERROR(error_reporter_,
466                            "The %dth dimension has unknown type: %d.", i,
467                            src_metadata->format());
468       return kTfLiteError;
469     }
470     auto* tgt_metadata = &sparsity->dim_metadata[i];
471 
472     tgt_metadata->format =
473         static_cast<TfLiteDimensionType>(src_metadata->format());
474 
475     if (tgt_metadata->format == kTfLiteDimDense) {
476       tgt_metadata->dense_size = src_metadata->dense_size();
477     } else {
478       if (ParseSparseIndexVector(src_metadata, tgt_metadata) != kTfLiteOk) {
479         TF_LITE_REPORT_ERROR(
480             error_reporter_,
481             "The %dth sparse dimension has invalid parameters.", i);
482         return kTfLiteError;
483       }
484     }
485   }
486 
487   return kTfLiteOk;
488 }
489 
ParseSignatureDefs(const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>> * signature_def_list,Interpreter * interpreter)490 TfLiteStatus InterpreterBuilder::ParseSignatureDefs(
491     const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
492         signature_def_list,
493     Interpreter* interpreter) {
494   if (signature_def_list == nullptr || signature_def_list->size() == 0) {
495     return kTfLiteOk;
496   }
497   std::vector<Interpreter::SignatureDef> signature_defs;
498   signature_defs.reserve(signature_def_list->size());
499   for (const auto fb_signature_def : *signature_def_list) {
500     if (fb_signature_def == nullptr) {
501       TF_LITE_REPORT_ERROR(error_reporter_, "NULL SignatureDef in the model.");
502       return kTfLiteError;
503     }
504     if (fb_signature_def->method_name() == nullptr) {
505       TF_LITE_REPORT_ERROR(error_reporter_,
506                            "Missing exported method name for SignatureDef");
507       return kTfLiteError;
508     }
509     if (fb_signature_def->inputs() == nullptr) {
510       TF_LITE_REPORT_ERROR(error_reporter_,
511                            "NULL SignatureDef inputs for exported method %s",
512                            fb_signature_def->method_name()->c_str());
513       return kTfLiteError;
514     }
515     if (fb_signature_def->outputs() == nullptr) {
516       TF_LITE_REPORT_ERROR(error_reporter_,
517                            "NULL SignatureDef outputs for exported method %s",
518                            fb_signature_def->method_name()->c_str());
519       return kTfLiteError;
520     }
521     signature_defs.resize(signature_defs.size() + 1);
522     auto& signature_def = signature_defs.back();
523     signature_def.inputs = GetMapFromTensorMap(fb_signature_def->inputs());
524     signature_def.outputs = GetMapFromTensorMap(fb_signature_def->outputs());
525     signature_def.method_name = fb_signature_def->method_name()->c_str();
526     if (fb_signature_def->key() != nullptr) {
527       signature_def.signature_def_key = fb_signature_def->key()->c_str();
528     }
529   }
530   interpreter->SetSignatureDef(std::move(signature_defs));
531   return kTfLiteOk;
532 }
533 
ParseTensors(const flatbuffers::Vector<flatbuffers::Offset<Buffer>> * buffers,const flatbuffers::Vector<flatbuffers::Offset<Tensor>> * tensors,Subgraph * subgraph)534 TfLiteStatus InterpreterBuilder::ParseTensors(
535     const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
536     const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
537     Subgraph* subgraph) {
538   TfLiteStatus status = kTfLiteOk;
539 
540   // A little helper to get the names of inputs and outputs. Note that they
541   // must outlive the subgraph.
542   auto get_name = [](const tflite::Tensor* t) -> const char* {
543     auto name = t->name();
544     if (name) return name->c_str();
545     return kEmptyTensorName;
546   };
547 
548   num_fp32_tensors_ = 0;
549   for (int i = 0; i < tensors->size(); ++i) {
550     const auto* tensor = tensors->Get(i);
551     std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());
552 
553     TfLiteType type;
554     if (ConvertTensorType(tensor->type(), &type, error_reporter_) !=
555         kTfLiteOk) {
556       status = kTfLiteError;
557       continue;
558     }
559     if (type == kTfLiteFloat32) {
560       ++num_fp32_tensors_;
561     }
562     auto get_readonly_data = [&](const char** buffer_data,
563                                  size_t* buffer_size) {
564       // TODO(aselle): Check what happens if we have an unspecified size
565       // constant.
566       *buffer_data = nullptr;
567       if (tensor->buffer() == 0) return kTfLiteOk;
568       if (tensor->buffer() >= buffers->size()) {
569         error_reporter_->Report(
570             "Tensor %d specifies out of range buffer %d (only %d buffers).\n",
571             i, tensor->buffer(), buffers->size());
572         return kTfLiteError;
573       }
574       if (auto* buffer = (*buffers)[tensor->buffer()]) {
575         if (auto* array = buffer->data()) {
576           if (size_t size = array->size()) {
577             *buffer_size = size;
578             *buffer_data = reinterpret_cast<const char*>(array->data());
579             return kTfLiteOk;
580           }
581         }
582       }
583       return kTfLiteOk;
584     };
585     size_t buffer_size = 0;
586     const char* buffer_ptr;
587     TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));
588 
589     const auto* src_quantization = tensor->quantization();
590     TfLiteQuantization quantization;
591     if (ParseQuantization(src_quantization, &quantization, dims) != kTfLiteOk) {
592       error_reporter_->Report("Tensor %d has invalid quantization parameters.",
593                               i);
594       status = kTfLiteError;
595     }
596 
597     std::vector<int> dims_signature = {};
598     if (tensor->shape_signature()) {
599       dims_signature = FlatBufferIntArrayToVector(tensor->shape_signature());
600     }
601 
602     bool is_variable = tensor->is_variable();
603     if (buffer_ptr) {
604       if (is_variable) {
605         error_reporter_->Report(
606             "Tensor %d is a variable tensor with buffer. "
607             "It's not supported now.\n",
608             i);
609         status = kTfLiteError;
610       }
611 
612       // TODO(b/144999664): Only constant sparse tensor is supported now.
613       const auto* src_sparsity = tensor->sparsity();
614       TfLiteSparsity* sparsity = nullptr;
615       if (ParseSparsity(src_sparsity, &sparsity) != kTfLiteOk) {
616         error_reporter_->Report("Tensor %d has invalid sparsity parameters.",
617                                 i);
618         status = kTfLiteError;
619       }
620 
621       if (subgraph->SetTensorParametersReadOnly(
622               i, type, get_name(tensor), dims, quantization, buffer_ptr,
623               buffer_size, allocation_, sparsity) != kTfLiteOk) {
624         error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
625                                 i);
626         status = kTfLiteError;
627       }
628     } else {
629       if (subgraph->SetTensorParametersReadWrite(
630               i, type, get_name(tensor), dims, quantization, is_variable,
631               dims_signature) != kTfLiteOk) {
632         error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
633                                 i);
634         status = kTfLiteError;
635       }
636     }
637   }
638 
639   return status;
640 }
641 
ApplyDelegates(Interpreter * interpreter,int num_threads)642 TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter,
643                                                 int num_threads) {
644   // Apply Flex delegate if applicable.
645   if (has_flex_op_) {
646     if (Interpreter::TfLiteDelegatePtr flex_delegate = AcquireFlexDelegate()) {
647       TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegate(
648           // Transfers ownership of flex_delegate to the interpreter.
649           std::move(flex_delegate)));
650     }
651   }
652   for (TfLiteDelegate* delegate : delegates_) {
653     // Note that we DON'T transfer ownership of the delegate to the interpreter.
654     // (Doing that would cause problems if operator() was invoked twice.)
655     TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegate(delegate));
656   }
657   return kTfLiteOk;
658 }
659 
operator ()(std::unique_ptr<Interpreter> * interpreter)660 TfLiteStatus InterpreterBuilder::operator()(
661     std::unique_ptr<Interpreter>* interpreter) {
662   return operator()(interpreter, /*num_threads=*/-1);
663 }
664 
operator ()(std::unique_ptr<Interpreter> * interpreter,int num_threads)665 TfLiteStatus InterpreterBuilder::operator()(
666     std::unique_ptr<Interpreter>* interpreter, int num_threads) {
667   if (!interpreter) {
668     error_reporter_->Report(
669         "Null output pointer passed to InterpreterBuilder.");
670     return kTfLiteError;
671   }
672 
673   if (num_threads < -1) {
674     error_reporter_->Report(
675         "num_threads should be >=0 or just -1 to let TFLite runtime set the "
676         "value.");
677     return kTfLiteError;
678   }
679 
680   // Safe exit by deleting partially created interpreter, to reduce verbosity
681   // on error conditions. Use by return cleanup_on_error();
682   auto cleanup_and_error = [&interpreter]() {
683     interpreter->reset();
684     return kTfLiteError;
685   };
686 
687   if (!model_) {
688     error_reporter_->Report("Null pointer passed in as model.");
689     return cleanup_and_error();
690   }
691 
692   if (model_->version() != TFLITE_SCHEMA_VERSION) {
693     error_reporter_->Report(
694         "Model provided is schema version %d not equal "
695         "to supported version %d.\n",
696         model_->version(), TFLITE_SCHEMA_VERSION);
697     return cleanup_and_error();
698   }
699 
700   if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
701     error_reporter_->Report("Registration failed.\n");
702     return cleanup_and_error();
703   }
704 
705   // Flatbuffer model schemas define a list of opcodes independent of the graph.
706   // We first map those to registrations. This reduces string lookups for custom
707   // ops since we only do it once per custom op rather than once per custom op
708   // invocation in the model graph.
709   // Construct interpreter with correct number of tensors and operators.
710   auto* subgraphs = model_->subgraphs();
711   auto* buffers = model_->buffers();
712 
713   if (subgraphs->size() == 0) {
714     TF_LITE_REPORT_ERROR(error_reporter_, "No subgraph in the model.\n");
715     return cleanup_and_error();
716   }
717 
718   if (!buffers) {
719     TF_LITE_REPORT_ERROR(error_reporter_, "No buffers in the model.\n");
720     return cleanup_and_error();
721   }
722 
723   interpreter->reset(new Interpreter(error_reporter_));
724   (*interpreter)->SetNumThreads(num_threads);
725   if (subgraphs->size() > 1) {
726     (*interpreter)->AddSubgraphs(subgraphs->size() - 1);
727   }
728 
729   (*interpreter)->SetProfiler(tflite::profiling::MaybeCreatePlatformProfiler());
730 
731   for (int subgraph_index = 0; subgraph_index < subgraphs->size();
732        ++subgraph_index) {
733     const tflite::SubGraph* subgraph = (*subgraphs)[subgraph_index];
734     tflite::Subgraph* modified_subgraph =
735         (*interpreter)->subgraph(subgraph_index);
736     auto operators = subgraph->operators();
737     auto tensors = subgraph->tensors();
738     if (!operators || !tensors) {
739       TF_LITE_REPORT_ERROR(error_reporter_,
740                            "Did not get operators or tensors in subgraph %d.\n",
741                            subgraph_index);
742       return cleanup_and_error();
743     }
744     if (modified_subgraph->AddTensors(tensors->size()) != kTfLiteOk) {
745       return cleanup_and_error();
746     }
747     // Set num threads
748     // Parse inputs/outputs
749     modified_subgraph->SetInputs(
750         FlatBufferIntArrayToVector(subgraph->inputs()));
751     modified_subgraph->SetOutputs(
752         FlatBufferIntArrayToVector(subgraph->outputs()));
753 
754     // Finally setup nodes and tensors
755     if (ParseNodes(operators, modified_subgraph) != kTfLiteOk)
756       return cleanup_and_error();
757     if (ParseTensors(buffers, tensors, modified_subgraph) != kTfLiteOk)
758       return cleanup_and_error();
759 
760     std::vector<int> variables;
761     for (int i = 0; i < modified_subgraph->tensors_size(); ++i) {
762       auto* tensor = modified_subgraph->tensor(i);
763       if (tensor->is_variable) {
764         variables.push_back(i);
765       }
766     }
767     modified_subgraph->SetVariables(std::move(variables));
768     if (subgraph->name()) {
769       modified_subgraph->SetName(subgraph->name()->c_str());
770     }
771   }
772 
773   if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) !=
774       kTfLiteOk) {
775     return cleanup_and_error();
776   }
777 
778   if (num_fp32_tensors_ > 0) {
779     (*interpreter)->lazy_delegate_providers_ =
780         op_resolver_.GetDelegates(num_threads);
781   }
782 
783   TfLiteStatus status = ApplyDelegates(interpreter->get(), num_threads);
784   if (status != kTfLiteOk) {
785     interpreter->reset();
786   }
787   return status;
788 }
789 
AddDelegate(TfLiteDelegate * delegate)790 void InterpreterBuilder::AddDelegate(TfLiteDelegate* delegate) {
791   if (delegate == nullptr) {
792     TF_LITE_REPORT_ERROR(error_reporter_, "Null delegate.");
793   } else {
794     delegates_.push_back(delegate);
795   }
796 }
797 
798 }  // namespace tflite
799