/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
17 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
18 
#include <sys/mman.h>

#include <cstdarg>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow_lite_support/cc/port/tflite_wrapper.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
33 
// If compiled with -DTFLITE_USE_C_API, this file will use the TF Lite C API
// rather than the TF Lite C++ API.
// TODO(b/168025296): eliminate the '#if TFLITE_USE_C_API' directives here and
// elsewhere and instead use the C API unconditionally, once we have a suitable
// replacement for the features of tflite::support::TfLiteInterpreterWrapper.
39 #if TFLITE_USE_C_API
40 #include "tensorflow/lite/c/c_api.h"
41 #include "tensorflow/lite/core/api/verifier.h"
42 #include "tensorflow/lite/tools/verifier.h"
43 #else
44 #include "tensorflow/lite/interpreter.h"
45 #include "tensorflow/lite/model.h"
46 #endif
47 
48 namespace tflite {
49 namespace task {
50 namespace core {
51 
52 // TfLiteEngine encapsulates logic for TFLite model initialization, inference
53 // and error reporting.
54 class TfLiteEngine {
55  public:
56   // Types.
57   using InterpreterWrapper = tflite::support::TfLiteInterpreterWrapper;
58 #if TFLITE_USE_C_API
59   using Model = struct TfLiteModel;
60   using Interpreter = struct TfLiteInterpreter;
61   using ModelDeleter = void (*)(Model*);
62   using InterpreterDeleter = InterpreterWrapper::InterpreterDeleter;
63 #else
64   using Model = tflite::FlatBufferModel;
65   using Interpreter = tflite::Interpreter;
66   using ModelDeleter = std::default_delete<Model>;
67   using InterpreterDeleter = std::default_delete<Interpreter>;
68 #endif
69 
70   // Constructors.
71   explicit TfLiteEngine(
72       std::unique_ptr<tflite::OpResolver> resolver =
73           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
74   // Model is neither copyable nor movable.
75   TfLiteEngine(const TfLiteEngine&) = delete;
76   TfLiteEngine& operator=(const TfLiteEngine&) = delete;
77 
78   // Accessors.
InputCount(const Interpreter * interpreter)79   static int32_t InputCount(const Interpreter* interpreter) {
80 #if TFLITE_USE_C_API
81     return TfLiteInterpreterGetInputTensorCount(interpreter);
82 #else
83     return interpreter->inputs().size();
84 #endif
85   }
OutputCount(const Interpreter * interpreter)86   static int32_t OutputCount(const Interpreter* interpreter) {
87 #if TFLITE_USE_C_API
88     return TfLiteInterpreterGetOutputTensorCount(interpreter);
89 #else
90     return interpreter->outputs().size();
91 #endif
92   }
GetInput(Interpreter * interpreter,int index)93   static TfLiteTensor* GetInput(Interpreter* interpreter, int index) {
94 #if TFLITE_USE_C_API
95     return TfLiteInterpreterGetInputTensor(interpreter, index);
96 #else
97     return interpreter->tensor(interpreter->inputs()[index]);
98 #endif
99   }
100   // Same as above, but const.
GetInput(const Interpreter * interpreter,int index)101   static const TfLiteTensor* GetInput(const Interpreter* interpreter,
102                                       int index) {
103 #if TFLITE_USE_C_API
104     return TfLiteInterpreterGetInputTensor(interpreter, index);
105 #else
106     return interpreter->tensor(interpreter->inputs()[index]);
107 #endif
108   }
GetOutput(Interpreter * interpreter,int index)109   static TfLiteTensor* GetOutput(Interpreter* interpreter, int index) {
110 #if TFLITE_USE_C_API
111     // We need a const_cast here, because the TF Lite C API only has a non-const
112     // version of GetOutputTensor (in part because C doesn't support overloading
113     // on const).
114     return const_cast<TfLiteTensor*>(
115         TfLiteInterpreterGetOutputTensor(interpreter, index));
116 #else
117     return interpreter->tensor(interpreter->outputs()[index]);
118 #endif
119   }
120   // Same as above, but const.
GetOutput(const Interpreter * interpreter,int index)121   static const TfLiteTensor* GetOutput(const Interpreter* interpreter,
122                                        int index) {
123 #if TFLITE_USE_C_API
124     return TfLiteInterpreterGetOutputTensor(interpreter, index);
125 #else
126     return interpreter->tensor(interpreter->outputs()[index]);
127 #endif
128   }
129 
130   std::vector<TfLiteTensor*> GetInputs();
131   std::vector<const TfLiteTensor*> GetOutputs();
132 
model()133   const Model* model() const { return model_.get(); }
interpreter()134   Interpreter* interpreter() { return interpreter_.get(); }
interpreter()135   const Interpreter* interpreter() const { return interpreter_.get(); }
interpreter_wrapper()136   InterpreterWrapper* interpreter_wrapper() { return &interpreter_; }
metadata_extractor()137   const tflite::metadata::ModelMetadataExtractor* metadata_extractor() const {
138     return model_metadata_extractor_.get();
139   }
140 
141   // Builds the TF Lite FlatBufferModel (model_) from the raw FlatBuffer data
142   // whose ownership remains with the caller, and which must outlive the current
143   // object. This performs extra verification on the input data using
144   // tflite::Verify.
145   absl::Status BuildModelFromFlatBuffer(const char* buffer_data,
146                                         size_t buffer_size);
147 
148   // Builds the TF Lite model from a given file.
149   absl::Status BuildModelFromFile(const std::string& file_name);
150 
151   // Builds the TF Lite model from a given file descriptor using mmap(2).
152   absl::Status BuildModelFromFileDescriptor(int file_descriptor);
153 
154   // Builds the TFLite model from the provided ExternalFile proto, which must
155   // outlive the current object.
156   absl::Status BuildModelFromExternalFileProto(
157       const ExternalFile* external_file);
158 
159   // Initializes interpreter with encapsulated model.
160   // Note: setting num_threads to -1 has for effect to let TFLite runtime set
161   // the value.
162   absl::Status InitInterpreter(int num_threads = 1);
163 
164   // Same as above, but allows specifying `compute_settings` for acceleration.
165   absl::Status InitInterpreter(
166       const tflite::proto::ComputeSettings& compute_settings,
167       int num_threads = 1);
168 
169   // Cancels the on-going `Invoke()` call if any and if possible. This method
170   // can be called from a different thread than the one where `Invoke()` is
171   // running.
Cancel()172   void Cancel() {
173 #if TFLITE_USE_C_API
174     // NOP.
175 #else
176     interpreter_.Cancel();
177 #endif
178   }
179 
180  protected:
181   // TF Lite's DefaultErrorReporter() outputs to stderr. This one captures the
182   // error into a string so that it can be used to complement tensorflow::Status
183   // error messages.
184   struct ErrorReporter : public tflite::ErrorReporter {
185     // Last error message captured by this error reporter.
186     char error_message[256];
187     int Report(const char* format, va_list args) override;
188   };
189   // Custom error reporter capturing low-level TF Lite error messages.
190   ErrorReporter error_reporter_;
191 
192  private:
193   // Direct wrapper around tflite::TfLiteVerifier which checks the integrity of
194   // the FlatBuffer data provided as input.
195   class Verifier : public tflite::TfLiteVerifier {
196    public:
Verifier(const tflite::OpResolver * op_resolver)197     explicit Verifier(const tflite::OpResolver* op_resolver)
198         : op_resolver_(op_resolver) {}
199     bool Verify(const char* data, int length,
200                 tflite::ErrorReporter* reporter) override;
201     // The OpResolver to be used to build the TF Lite interpreter.
202     const tflite::OpResolver* op_resolver_;
203   };
204 
205   // Verifies that the supplied buffer refers to a valid flatbuffer model,
206   // and that it uses only operators that are supported by the OpResolver
207   // that was passed to the TfLiteEngine constructor, and then builds
208   // the model from the buffer and stores it in 'model_'.
209   void VerifyAndBuildModelFromBuffer(const char* buffer_data,
210                                      size_t buffer_size);
211 
212   // Gets the buffer from the file handler; verifies and builds the model
213   // from the buffer; if successful, sets 'model_metadata_extractor_' to be
214   // a TF Lite Metadata extractor for the model; and calculates an appropriate
215   // return Status,
216   absl::Status InitializeFromModelFileHandler();
217 
218   // TF Lite model and interpreter for actual inference.
219   std::unique_ptr<Model, ModelDeleter> model_;
220 
221   // Interpreter wrapper built from the model.
222   InterpreterWrapper interpreter_;
223 
224   // TFLite Metadata extractor built from the model.
225   std::unique_ptr<tflite::metadata::ModelMetadataExtractor>
226       model_metadata_extractor_;
227 
228   // Mechanism used by TF Lite to map Ops referenced in the FlatBuffer model to
229   // actual implementation. Defaults to TF Lite BuiltinOpResolver.
230   std::unique_ptr<tflite::OpResolver> resolver_;
231 
232   // Extra verifier for FlatBuffer input data.
233   Verifier verifier_;
234 
235   // ExternalFile and corresponding ExternalFileHandler for models loaded from
236   // disk or file descriptor.
237   ExternalFile external_file_;
238   std::unique_ptr<ExternalFileHandler> model_file_handler_;
239 };
240 
241 }  // namespace core
242 }  // namespace task
243 }  // namespace tflite
244 
245 #endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
246