/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"

#include <unistd.h>

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/tools/verifier.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"

#if TFLITE_USE_C_API
#include "tensorflow/lite/c/c_api_experimental.h"
#else
#include "tensorflow/lite/kernels/register.h"
#endif

namespace tflite {
namespace task {
namespace core {
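
// Typical usage of TfLiteEngine (a minimal sketch, not prescribed by this
// file: the resolver type, model path, and thread count are illustrative, and
// RETURN_IF_ERROR assumes an enclosing function returning absl::Status):
//
//   TfLiteEngine engine(
//       absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
//   RETURN_IF_ERROR(engine.BuildModelFromFile("/path/to/model.tflite"));
//   RETURN_IF_ERROR(engine.InitInterpreter(/*num_threads=*/2));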

#ifdef __ANDROID__
// https://github.com/opencv/opencv/issues/14906
// The "ios_base::Init" object is not part of Android's "iostream" header (in
// the case of the clang toolchain, NDK 20).
//
// Ref1:
// https://en.cppreference.com/w/cpp/io/ios_base/Init
// The header <iostream> behaves as if it defines (directly or indirectly)
// an instance of std::ios_base::Init with static storage duration.
//
// Ref2:
// https://github.com/gcc-mirror/gcc/blob/gcc-8-branch/libstdc%2B%2B-v3/include/std/iostream#L73-L74
static std::ios_base::Init s_iostream_initializer;
#endif

using ::absl::StatusCode;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::TfLiteSupportStatus;

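// Writes the formatted message to `error_message`, overwriting any previously
// reported error, and returns the result of std::vsnprintf.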
int TfLiteEngine::ErrorReporter::Report(const char* format, va_list args) {
  return std::vsnprintf(error_message, sizeof(error_message), format, args);
}

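// Verifies the model buffer, in particular checking that it only uses
// operators supported by `op_resolver_`.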
bool TfLiteEngine::Verifier::Verify(const char* data, int length,
                                    tflite::ErrorReporter* reporter) {
  return tflite::Verify(data, length, *op_resolver_, reporter);
}

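// In the TFLITE_USE_C_API build, `model_` wraps a C API `TfLiteModel*` that is
// destroyed with `TfLiteModelDelete`; in the C++ build it holds a
// `tflite::FlatBufferModel`.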
#if TFLITE_USE_C_API
TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)
    : model_(nullptr, TfLiteModelDelete),
      resolver_(std::move(resolver)),
      verifier_(resolver_.get()) {}
#else
TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)
    : model_(), resolver_(std::move(resolver)), verifier_(resolver_.get()) {}
#endif

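// Returns non-owning pointers to all of the interpreter's input tensors.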
std::vector<TfLiteTensor*> TfLiteEngine::GetInputs() {
  Interpreter* interpreter = this->interpreter();
  std::vector<TfLiteTensor*> tensors;
  int input_count = InputCount(interpreter);
  tensors.reserve(input_count);
  for (int index = 0; index < input_count; index++) {
    tensors.push_back(GetInput(interpreter, index));
  }
  return tensors;
}

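// Same as GetInputs(), but returns const pointers to the output tensors.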
std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() {
  Interpreter* interpreter = this->interpreter();
  std::vector<const TfLiteTensor*> tensors;
  int output_count = OutputCount(interpreter);
  tensors.reserve(output_count);
  for (int index = 0; index < output_count; index++) {
    tensors.push_back(GetOutput(interpreter, index));
  }
  return tensors;
}

// The following function is adapted from the code in
// tflite::FlatBufferModel::VerifyAndBuildFromBuffer.
void TfLiteEngine::VerifyAndBuildModelFromBuffer(const char* buffer_data,
                                                 size_t buffer_size) {
#if TFLITE_USE_C_API
  // First verify with the base flatbuffers verifier.
  // This verifies that the model is a valid flatbuffer model.
  flatbuffers::Verifier base_verifier(
      reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
  if (!VerifyModelBuffer(base_verifier)) {
    TF_LITE_REPORT_ERROR(&error_reporter_,
                         "The model is not a valid Flatbuffer buffer");
    model_ = nullptr;
    return;
  }
  // Next verify with the extra verifier. This verifies that the model only
  // uses operators supported by the OpResolver.
  if (!verifier_.Verify(buffer_data, buffer_size, &error_reporter_)) {
    model_ = nullptr;
    return;
  }
  // Build the model.
  model_.reset(TfLiteModelCreate(buffer_data, buffer_size));
#else
  model_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
      buffer_data, buffer_size, &verifier_, &error_reporter_);
#endif
}

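// Builds `model_` from the content exposed by `model_file_handler_` and, on
// success, extracts the model metadata into `model_metadata_extractor_`.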
absl::Status TfLiteEngine::InitializeFromModelFileHandler() {
  const char* buffer_data = model_file_handler_->GetFileContent().data();
  size_t buffer_size = model_file_handler_->GetFileContent().size();
  VerifyAndBuildModelFromBuffer(buffer_data, buffer_size);
  if (model_ == nullptr) {
    // To be replaced with a proper switch-case when the TF Lite model builder
    // returns a `TfLiteStatus` code capturing this type of error.
    if (absl::StrContains(error_reporter_.error_message,
                          "The model is not a valid Flatbuffer")) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument, error_reporter_.error_message,
          TfLiteSupportStatus::kInvalidFlatBufferError);
    } else {
      // TODO(b/154917059): augment status with another `TfLiteStatus` code
      // when ready, and use a new `TfLiteSupportStatus::kCoreTfLiteError` for
      // the TFLS code, instead of the unspecified `kError`.
      return CreateStatusWithPayload(
          StatusCode::kUnknown,
          absl::StrCat(
              "Could not build model from the provided pre-loaded flatbuffer: ",
              error_reporter_.error_message));
    }
  }

  ASSIGN_OR_RETURN(
      model_metadata_extractor_,
      tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer(
          buffer_data, buffer_size));

  return absl::OkStatus();
}

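// Note: the provided buffer is copied into `external_file_`, so it does not
// need to outlive the engine.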
absl::Status TfLiteEngine::BuildModelFromFlatBuffer(const char* buffer_data,
                                                    size_t buffer_size) {
  if (model_) {
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "Model already built");
  }
  external_file_.set_file_content(std::string(buffer_data, buffer_size));
  ASSIGN_OR_RETURN(
      model_file_handler_,
      ExternalFileHandler::CreateFromExternalFile(&external_file_));
  return InitializeFromModelFileHandler();
}

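// Loads the model from a file path; `ExternalFileHandler` memory-maps the
// file rather than copying its content.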
absl::Status TfLiteEngine::BuildModelFromFile(const std::string& file_name) {
  if (model_) {
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "Model already built");
  }
  external_file_.set_file_name(file_name);
  ASSIGN_OR_RETURN(
      model_file_handler_,
      ExternalFileHandler::CreateFromExternalFile(&external_file_));
  return InitializeFromModelFileHandler();
}

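// Loads the model from an already-open file descriptor; ownership of the
// descriptor is expected to remain with the caller.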
absl::Status TfLiteEngine::BuildModelFromFileDescriptor(int file_descriptor) {
  if (model_) {
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "Model already built");
  }
  external_file_.mutable_file_descriptor_meta()->set_fd(file_descriptor);
  ASSIGN_OR_RETURN(
      model_file_handler_,
      ExternalFileHandler::CreateFromExternalFile(&external_file_));
  return InitializeFromModelFileHandler();
}

absl::Status TfLiteEngine::BuildModelFromExternalFileProto(
    const ExternalFile* external_file) {
  if (model_) {
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "Model already built");
  }
  ASSIGN_OR_RETURN(model_file_handler_,
                   ExternalFileHandler::CreateFromExternalFile(external_file));
  return InitializeFromModelFileHandler();
}

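// Convenience overload that uses default (i.e. empty) compute settings.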
absl::Status TfLiteEngine::InitInterpreter(int num_threads) {
  tflite::proto::ComputeSettings compute_settings;
  return InitInterpreter(compute_settings, num_threads);
}

#if TFLITE_USE_C_API
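// Adapter callbacks bridging the C API op-resolver hooks to the C++
// `OpResolver` passed through `user_data`.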
const TfLiteRegistration* FindBuiltinOp(void* user_data,
                                        TfLiteBuiltinOperator builtin_op,
                                        int version) {
  OpResolver* op_resolver = reinterpret_cast<OpResolver*>(user_data);
  tflite::BuiltinOperator op = static_cast<tflite::BuiltinOperator>(builtin_op);
  return op_resolver->FindOp(op, version);
}

const TfLiteRegistration* FindCustomOp(void* user_data, const char* custom_op,
                                       int version) {
  OpResolver* op_resolver = reinterpret_cast<OpResolver*>(user_data);
  return op_resolver->FindOp(custom_op, version);
}
#endif

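// Initializes the interpreter with the acceleration configuration in
// `compute_settings`; `InitializeWithFallback` can retry on CPU if
// delegate-backed initialization fails.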
absl::Status TfLiteEngine::InitInterpreter(
    const tflite::proto::ComputeSettings& compute_settings, int num_threads) {
  if (model_ == nullptr) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        "TF Lite FlatBufferModel is null. Please make sure to call one of the "
        "BuildModelFrom methods before calling InitInterpreter.");
  }
#if TFLITE_USE_C_API
  std::function<absl::Status(TfLiteDelegate*,
                             std::unique_ptr<Interpreter, InterpreterDeleter>*)>
      initializer = [this, num_threads](
                        TfLiteDelegate* optional_delegate,
                        std::unique_ptr<Interpreter, InterpreterDeleter>*
                            interpreter_out) -> absl::Status {
    std::unique_ptr<TfLiteInterpreterOptions,
                    void (*)(TfLiteInterpreterOptions*)>
        options{TfLiteInterpreterOptionsCreate(),
                TfLiteInterpreterOptionsDelete};
    TfLiteInterpreterOptionsSetOpResolver(options.get(), FindBuiltinOp,
                                          FindCustomOp, resolver_.get());
    TfLiteInterpreterOptionsSetNumThreads(options.get(), num_threads);
    if (optional_delegate != nullptr) {
      TfLiteInterpreterOptionsAddDelegate(options.get(), optional_delegate);
    }
    interpreter_out->reset(
        TfLiteInterpreterCreateWithSelectedOps(model_.get(), options.get()));
    if (*interpreter_out == nullptr) {
      return CreateStatusWithPayload(
          StatusCode::kAborted,
          absl::StrCat("Could not build the TF Lite interpreter: "
                       "TfLiteInterpreterCreateWithSelectedOps failed: ",
                       error_reporter_.error_message));
    }
    return absl::OkStatus();
  };
#else
  auto initializer =
      [this, num_threads](
          std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out)
      -> absl::Status {
    if (tflite::InterpreterBuilder(*model_, *resolver_)(
            interpreter_out, num_threads) != kTfLiteOk) {
      return CreateStatusWithPayload(
          StatusCode::kUnknown,
          absl::StrCat("Could not build the TF Lite interpreter: ",
                       error_reporter_.error_message));
    }
    if (*interpreter_out == nullptr) {
      return CreateStatusWithPayload(StatusCode::kInternal,
                                     "TF Lite interpreter is null.");
    }
    return absl::OkStatus();
  };
#endif

  absl::Status status =
      interpreter_.InitializeWithFallback(initializer, compute_settings);

  // Ensure errors are surfaced with a TF Lite Support payload, for consistency
  // with the other statuses produced by this class.
  if (!status.ok() &&
      !status.GetPayload(tflite::support::kTfLiteSupportPayload).has_value()) {
    status = CreateStatusWithPayload(status.code(), status.message());
  }
  return status;
}

}  // namespace core
}  // namespace task
}  // namespace tflite