1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
17 
18 #include <unistd.h>
19 
20 #include "absl/strings/match.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/lite/builtin_ops.h"
23 #include "tensorflow/lite/stderr_reporter.h"
24 #include "tensorflow/lite/tools/verifier.h"
25 #include "tensorflow_lite_support/cc/common.h"
26 #include "tensorflow_lite_support/cc/port/status_macros.h"
27 #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
28 
29 #if TFLITE_USE_C_API
30 #include "tensorflow/lite/c/c_api_experimental.h"
31 #else
32 #include "tensorflow/lite/kernels/register.h"
33 #endif
34 
35 namespace tflite {
36 namespace task {
37 namespace core {
38 
#if defined(__ANDROID__)
// Android's <iostream> (clang toolchain, NDK 20) does not provide the
// std::ios_base::Init object that the standard says the header behaves as if
// it defines with static storage duration. Define one explicitly in this
// translation unit so the standard stream objects are initialized.
//
// References:
//   https://github.com/opencv/opencv/issues/14906
//   https://en.cppreference.com/w/cpp/io/ios_base/Init
//   https://github.com/gcc-mirror/gcc/blob/gcc-8-branch/libstdc%2B%2B-v3/include/std/iostream#L73-L74
static std::ios_base::Init s_iostream_initializer;
#endif  // defined(__ANDROID__)
53 
54 using ::absl::StatusCode;
55 using ::tflite::support::CreateStatusWithPayload;
56 using ::tflite::support::TfLiteSupportStatus;
57 
Report(const char * format,va_list args)58 int TfLiteEngine::ErrorReporter::Report(const char* format, va_list args) {
59   return std::vsnprintf(error_message, sizeof(error_message), format, args);
60 }
61 
Verify(const char * data,int length,tflite::ErrorReporter * reporter)62 bool TfLiteEngine::Verifier::Verify(const char* data, int length,
63                                     tflite::ErrorReporter* reporter) {
64   return tflite::Verify(data, length, *op_resolver_, reporter);
65 }
66 
67 #if TFLITE_USE_C_API
TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)68 TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)
69     : model_(nullptr, TfLiteModelDelete),
70       resolver_(std::move(resolver)),
71       verifier_(resolver_.get()) {}
72 #else
TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)73 TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)
74     : model_(), resolver_(std::move(resolver)), verifier_(resolver_.get()) {}
75 #endif
76 
GetInputs()77 std::vector<TfLiteTensor*> TfLiteEngine::GetInputs() {
78   Interpreter* interpreter = this->interpreter();
79   std::vector<TfLiteTensor*> tensors;
80   int input_count = InputCount(interpreter);
81   tensors.reserve(input_count);
82   for (int index = 0; index < input_count; index++) {
83     tensors.push_back(GetInput(interpreter, index));
84   }
85   return tensors;
86 }
87 
GetOutputs()88 std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() {
89   Interpreter* interpreter = this->interpreter();
90   std::vector<const TfLiteTensor*> tensors;
91   int output_count = OutputCount(interpreter);
92   tensors.reserve(output_count);
93   for (int index = 0; index < output_count; index++) {
94     tensors.push_back(GetOutput(interpreter, index));
95   }
96   return tensors;
97 }
98 
99 // The following function is adapted from the code in
100 // tflite::FlatBufferModel::VerifyAndBuildFromBuffer.
VerifyAndBuildModelFromBuffer(const char * buffer_data,size_t buffer_size)101 void TfLiteEngine::VerifyAndBuildModelFromBuffer(const char* buffer_data,
102                                                  size_t buffer_size) {
103 #if TFLITE_USE_C_API
104   // First verify with the base flatbuffers verifier.
105   // This verifies that the model is a valid flatbuffer model.
106   flatbuffers::Verifier base_verifier(
107       reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
108   if (!VerifyModelBuffer(base_verifier)) {
109     TF_LITE_REPORT_ERROR(&error_reporter_,
110                          "The model is not a valid Flatbuffer buffer");
111     model_ = nullptr;
112     return;
113   }
114   // Next verify with the extra verifier.  This verifies that the model only
115   // uses operators supported by the OpResolver.
116   if (!verifier_.Verify(buffer_data, buffer_size, &error_reporter_)) {
117     model_ = nullptr;
118     return;
119   }
120   // Build the model.
121   model_.reset(TfLiteModelCreate(buffer_data, buffer_size));
122 #else
123   model_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
124       buffer_data, buffer_size, &verifier_, &error_reporter_);
125 #endif
126 }
127 
InitializeFromModelFileHandler()128 absl::Status TfLiteEngine::InitializeFromModelFileHandler() {
129   const char* buffer_data = model_file_handler_->GetFileContent().data();
130   size_t buffer_size = model_file_handler_->GetFileContent().size();
131   VerifyAndBuildModelFromBuffer(buffer_data, buffer_size);
132   if (model_ == nullptr) {
133     // To be replaced with a proper switch-case when TF Lite model builder
134     // returns a `TfLiteStatus` code capturing this type of error.
135     if (absl::StrContains(error_reporter_.error_message,
136                           "The model is not a valid Flatbuffer")) {
137       return CreateStatusWithPayload(
138           StatusCode::kInvalidArgument, error_reporter_.error_message,
139           TfLiteSupportStatus::kInvalidFlatBufferError);
140     } else {
141       // TODO(b/154917059): augment status with another `TfLiteStatus` code when
142       // ready. And use a new `TfLiteStatus::kCoreTfLiteError` for the TFLS
143       // code, instead of the unspecified `kError`.
144       return CreateStatusWithPayload(
145           StatusCode::kUnknown,
146           absl::StrCat(
147               "Could not build model from the provided pre-loaded flatbuffer: ",
148               error_reporter_.error_message));
149     }
150   }
151 
152   ASSIGN_OR_RETURN(
153       model_metadata_extractor_,
154       tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer(
155           buffer_data, buffer_size));
156 
157   return absl::OkStatus();
158 }
159 
BuildModelFromFlatBuffer(const char * buffer_data,size_t buffer_size)160 absl::Status TfLiteEngine::BuildModelFromFlatBuffer(const char* buffer_data,
161                                                     size_t buffer_size) {
162   if (model_) {
163     return CreateStatusWithPayload(StatusCode::kInternal,
164                                    "Model already built");
165   }
166   external_file_.set_file_content(std::string(buffer_data, buffer_size));
167   ASSIGN_OR_RETURN(
168       model_file_handler_,
169       ExternalFileHandler::CreateFromExternalFile(&external_file_));
170   return InitializeFromModelFileHandler();
171 }
172 
BuildModelFromFile(const std::string & file_name)173 absl::Status TfLiteEngine::BuildModelFromFile(const std::string& file_name) {
174   if (model_) {
175     return CreateStatusWithPayload(StatusCode::kInternal,
176                                    "Model already built");
177   }
178   external_file_.set_file_name(file_name);
179   ASSIGN_OR_RETURN(
180       model_file_handler_,
181       ExternalFileHandler::CreateFromExternalFile(&external_file_));
182   return InitializeFromModelFileHandler();
183 }
184 
BuildModelFromFileDescriptor(int file_descriptor)185 absl::Status TfLiteEngine::BuildModelFromFileDescriptor(int file_descriptor) {
186   if (model_) {
187     return CreateStatusWithPayload(StatusCode::kInternal,
188                                    "Model already built");
189   }
190   external_file_.mutable_file_descriptor_meta()->set_fd(file_descriptor);
191   ASSIGN_OR_RETURN(
192       model_file_handler_,
193       ExternalFileHandler::CreateFromExternalFile(&external_file_));
194   return InitializeFromModelFileHandler();
195 }
196 
BuildModelFromExternalFileProto(const ExternalFile * external_file)197 absl::Status TfLiteEngine::BuildModelFromExternalFileProto(
198     const ExternalFile* external_file) {
199   if (model_) {
200     return CreateStatusWithPayload(StatusCode::kInternal,
201                                    "Model already built");
202   }
203   ASSIGN_OR_RETURN(model_file_handler_,
204                    ExternalFileHandler::CreateFromExternalFile(external_file));
205   return InitializeFromModelFileHandler();
206 }
207 
InitInterpreter(int num_threads)208 absl::Status TfLiteEngine::InitInterpreter(int num_threads) {
209   tflite::proto::ComputeSettings compute_settings;
210   return InitInterpreter(compute_settings, num_threads);
211 }
212 
#if TFLITE_USE_C_API
// Op-lookup callbacks handed to TfLiteInterpreterOptionsSetOpResolver. They are
// only needed inside this translation unit, so they get internal linkage to
// avoid exporting generic symbol names from this namespace.
namespace {

// Looks up a builtin op. `user_data` must point to the tflite::OpResolver that
// was registered together with this callback.
const TfLiteRegistration* FindBuiltinOp(void* user_data,
                                        TfLiteBuiltinOperator builtin_op,
                                        int version) {
  const auto* op_resolver = static_cast<const OpResolver*>(user_data);
  const auto op = static_cast<tflite::BuiltinOperator>(builtin_op);
  return op_resolver->FindOp(op, version);
}

// Looks up a custom op by name. `user_data` must point to the
// tflite::OpResolver that was registered together with this callback.
const TfLiteRegistration* FindCustomOp(void* user_data, const char* custom_op,
                                       int version) {
  const auto* op_resolver = static_cast<const OpResolver*>(user_data);
  return op_resolver->FindOp(custom_op, version);
}

}  // namespace
#endif
228 
InitInterpreter(const tflite::proto::ComputeSettings & compute_settings,int num_threads)229 absl::Status TfLiteEngine::InitInterpreter(
230     const tflite::proto::ComputeSettings& compute_settings, int num_threads) {
231   if (model_ == nullptr) {
232     return CreateStatusWithPayload(
233         StatusCode::kInternal,
234         "TF Lite FlatBufferModel is null. Please make sure to call one of the "
235         "BuildModelFrom methods before calling InitInterpreter.");
236   }
237 #if TFLITE_USE_C_API
238   std::function<absl::Status(TfLiteDelegate*,
239                              std::unique_ptr<Interpreter, InterpreterDeleter>*)>
240       initializer = [this, num_threads](
241           TfLiteDelegate* optional_delegate,
242           std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out)
243       -> absl::Status {
244     std::unique_ptr<TfLiteInterpreterOptions,
245                     void (*)(TfLiteInterpreterOptions*)>
246         options{TfLiteInterpreterOptionsCreate(),
247                 TfLiteInterpreterOptionsDelete};
248     TfLiteInterpreterOptionsSetOpResolver(options.get(), FindBuiltinOp,
249                                           FindCustomOp, resolver_.get());
250     TfLiteInterpreterOptionsSetNumThreads(options.get(), num_threads);
251     if (optional_delegate != nullptr) {
252       TfLiteInterpreterOptionsAddDelegate(options.get(), optional_delegate);
253     }
254     interpreter_out->reset(
255         TfLiteInterpreterCreateWithSelectedOps(model_.get(), options.get()));
256     if (*interpreter_out == nullptr) {
257       return CreateStatusWithPayload(
258           StatusCode::kAborted,
259           absl::StrCat("Could not build the TF Lite interpreter: "
260                        "TfLiteInterpreterCreateWithSelectedOps failed: ",
261                        error_reporter_.error_message));
262     }
263     return absl::OkStatus();
264   };
265 #else
266   auto initializer =
267       [this, num_threads](
268           std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out)
269       -> absl::Status {
270     if (tflite::InterpreterBuilder(*model_, *resolver_)(
271             interpreter_out, num_threads) != kTfLiteOk) {
272       return CreateStatusWithPayload(
273           StatusCode::kUnknown,
274           absl::StrCat("Could not build the TF Lite interpreter: ",
275                        error_reporter_.error_message));
276     }
277     if (*interpreter_out == nullptr) {
278       return CreateStatusWithPayload(StatusCode::kInternal,
279                                      "TF Lite interpreter is null.");
280     }
281     return absl::OkStatus();
282   };
283 #endif
284 
285   absl::Status status =
286       interpreter_.InitializeWithFallback(initializer, compute_settings);
287 
288   if (!status.ok() &&
289       !status.GetPayload(tflite::support::kTfLiteSupportPayload).has_value()) {
290     status = CreateStatusWithPayload(status.code(), status.message());
291   }
292   return status;
293 }
294 
295 }  // namespace core
296 }  // namespace task
297 }  // namespace tflite
298