/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_

#include <sys/mman.h>

#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow_lite_support/cc/port/tflite_wrapper.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"

// If compiled with -DTFLITE_USE_C_API, this file will use the TF Lite C API
// rather than the TF Lite C++ API.
// TODO(b/168025296): eliminate the '#if TFLITE_USE_C_API' directives here and
// elsewhere, and instead use the C API unconditionally, once we have a
// suitable replacement for the features of
// tflite::support::TfLiteInterpreterWrapper.
#if TFLITE_USE_C_API
#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/core/api/verifier.h"
#include "tensorflow/lite/tools/verifier.h"
#else
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
#endif

namespace tflite {
namespace task {
namespace core {

// TfLiteEngine encapsulates logic for TFLite model initialization, inference,
// and error reporting.
class TfLiteEngine {
 public:
  // Types.
  using InterpreterWrapper = tflite::support::TfLiteInterpreterWrapper;
#if TFLITE_USE_C_API
  using Model = struct TfLiteModel;
  using Interpreter = struct TfLiteInterpreter;
  using ModelDeleter = void (*)(Model*);
  using InterpreterDeleter = InterpreterWrapper::InterpreterDeleter;
#else
  using Model = tflite::FlatBufferModel;
  using Interpreter = tflite::Interpreter;
  using ModelDeleter = std::default_delete<Model>;
  using InterpreterDeleter = std::default_delete<Interpreter>;
#endif

  // Constructors.
  explicit TfLiteEngine(
      std::unique_ptr<tflite::OpResolver> resolver =
          absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
  // TfLiteEngine is neither copyable nor movable.
  TfLiteEngine(const TfLiteEngine&) = delete;
  TfLiteEngine& operator=(const TfLiteEngine&) = delete;
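
  // Typical usage, as a minimal sketch (error handling on the returned
  // absl::Status is elided, and `kModelPath` is a hypothetical path to a
  // valid TF Lite model file):
  //
  //   TfLiteEngine engine;
  //   engine.BuildModelFromFile(kModelPath).IgnoreError();
  //   engine.InitInterpreter().IgnoreError();
  //   TfLiteTensor* input = TfLiteEngine::GetInput(engine.interpreter(), 0);
  //   // Fill `input`, run inference through the wrapper returned by
  //   // engine.interpreter_wrapper(), then read results via GetOutput().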
  // Accessors.
  static int32_t InputCount(const Interpreter* interpreter) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetInputTensorCount(interpreter);
#else
    return interpreter->inputs().size();
#endif
  }
  static int32_t OutputCount(const Interpreter* interpreter) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetOutputTensorCount(interpreter);
#else
    return interpreter->outputs().size();
#endif
  }
  static TfLiteTensor* GetInput(Interpreter* interpreter, int index) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetInputTensor(interpreter, index);
#else
    return interpreter->tensor(interpreter->inputs()[index]);
#endif
  }
  // Same as above, but const.
  static const TfLiteTensor* GetInput(const Interpreter* interpreter,
                                      int index) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetInputTensor(interpreter, index);
#else
    return interpreter->tensor(interpreter->inputs()[index]);
#endif
  }
  static TfLiteTensor* GetOutput(Interpreter* interpreter, int index) {
#if TFLITE_USE_C_API
    // We need a const_cast here, because the TF Lite C API only has a
    // non-const version of GetOutputTensor (in part because C doesn't
    // support overloading on const).
    return const_cast<TfLiteTensor*>(
        TfLiteInterpreterGetOutputTensor(interpreter, index));
#else
    return interpreter->tensor(interpreter->outputs()[index]);
#endif
  }
  // Same as above, but const.
  static const TfLiteTensor* GetOutput(const Interpreter* interpreter,
                                       int index) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetOutputTensor(interpreter, index);
#else
    return interpreter->tensor(interpreter->outputs()[index]);
#endif
  }

  std::vector<TfLiteTensor*> GetInputs();
  std::vector<const TfLiteTensor*> GetOutputs();

  const Model* model() const { return model_.get(); }
  Interpreter* interpreter() { return interpreter_.get(); }
  const Interpreter* interpreter() const { return interpreter_.get(); }
  InterpreterWrapper* interpreter_wrapper() { return &interpreter_; }
  const tflite::metadata::ModelMetadataExtractor* metadata_extractor() const {
    return model_metadata_extractor_.get();
  }

  // Builds the TF Lite FlatBufferModel (model_) from the raw FlatBuffer data
  // whose ownership remains with the caller, and which must outlive the
  // current object. This performs extra verification on the input data using
  // tflite::Verify.
  absl::Status BuildModelFromFlatBuffer(const char* buffer_data,
                                        size_t buffer_size);

  // Builds the TF Lite model from a given file.
  absl::Status BuildModelFromFile(const std::string& file_name);

  // Builds the TF Lite model from a given file descriptor using mmap(2).
  absl::Status BuildModelFromFileDescriptor(int file_descriptor);

  // Builds the TFLite model from the provided ExternalFile proto, which must
  // outlive the current object.
  absl::Status BuildModelFromExternalFileProto(
      const ExternalFile* external_file);
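
  // Example (a sketch only): building from caller-owned, in-memory FlatBuffer
  // data; `buffer` is a hypothetical std::string holding valid model data,
  // which must outlive the engine:
  //
  //   TfLiteEngine engine;
  //   absl::Status status =
  //       engine.BuildModelFromFlatBuffer(buffer.data(), buffer.size());
  //   if (!status.ok()) { /* handle the error */ }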
  // Initializes the interpreter with the encapsulated model.
  // Note: setting num_threads to -1 lets the TF Lite runtime choose the
  // value.
  absl::Status InitInterpreter(int num_threads = 1);

  // Same as above, but allows specifying `compute_settings` for acceleration.
  absl::Status InitInterpreter(
      const tflite::proto::ComputeSettings& compute_settings,
      int num_threads = 1);

  // Cancels the ongoing `Invoke()` call, if any and if possible. This method
  // can be called from a different thread than the one where `Invoke()` is
  // running.
  void Cancel() {
#if TFLITE_USE_C_API
    // NOP.
#else
    interpreter_.Cancel();
#endif
  }

 protected:
  // TF Lite's DefaultErrorReporter() outputs to stderr. This one captures the
  // error into a string so that it can be used to complement
  // tensorflow::Status error messages.
  struct ErrorReporter : public tflite::ErrorReporter {
    // Last error message captured by this error reporter.
    char error_message[256];
    int Report(const char* format, va_list args) override;
  };
  // Custom error reporter capturing low-level TF Lite error messages.
  ErrorReporter error_reporter_;

 private:
  // Direct wrapper around tflite::TfLiteVerifier which checks the integrity
  // of the FlatBuffer data provided as input.
  class Verifier : public tflite::TfLiteVerifier {
   public:
    explicit Verifier(const tflite::OpResolver* op_resolver)
        : op_resolver_(op_resolver) {}
    bool Verify(const char* data, int length,
                tflite::ErrorReporter* reporter) override;
    // The OpResolver to be used to build the TF Lite interpreter.
    const tflite::OpResolver* op_resolver_;
  };

  // Verifies that the supplied buffer refers to a valid FlatBuffer model,
  // and that it uses only operators that are supported by the OpResolver
  // that was passed to the TfLiteEngine constructor, and then builds the
  // model from the buffer and stores it in 'model_'.
  void VerifyAndBuildModelFromBuffer(const char* buffer_data,
                                     size_t buffer_size);

  // Gets the buffer from the file handler; verifies and builds the model
  // from the buffer; if successful, sets 'model_metadata_extractor_' to be
  // a TF Lite Metadata extractor for the model; and returns an appropriate
  // Status.
  absl::Status InitializeFromModelFileHandler();

  // TF Lite model used for actual inference.
  std::unique_ptr<Model, ModelDeleter> model_;

  // Interpreter wrapper built from the model.
  InterpreterWrapper interpreter_;

  // TFLite Metadata extractor built from the model.
  std::unique_ptr<tflite::metadata::ModelMetadataExtractor>
      model_metadata_extractor_;

  // Mechanism used by TF Lite to map Ops referenced in the FlatBuffer model
  // to their actual implementations. Defaults to the TF Lite
  // BuiltinOpResolver.
  std::unique_ptr<tflite::OpResolver> resolver_;

  // Extra verifier for FlatBuffer input data.
  Verifier verifier_;

  // ExternalFile and corresponding ExternalFileHandler for models loaded from
  // disk or file descriptor.
  ExternalFile external_file_;
  std::unique_ptr<ExternalFileHandler> model_file_handler_;
};

}  // namespace core
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_