/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Deserialization infrastructure for tflite. Provides functionality
// to go from a serialized tflite model in flatbuffer format to an
// interpreter.
//
// using namespace tflite;
// StderrReporter error_reporter;
// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite",
//                                             &error_reporter);
// MyOpResolver resolver;  // You need to subclass OpResolver to provide
//                         // implementations.
// InterpreterBuilder builder(*model, resolver);
// std::unique_ptr<Interpreter> interpreter;
// if(builder(&interpreter) == kTfLiteOk) {
//   .. run model inference with interpreter
// }
//
// OpResolver must be defined to provide your kernel implementations to the
// interpreter. This is environment specific and may consist of just the builtin
// ops, or some custom operators you defined to extend tflite.
34 #ifndef TENSORFLOW_LITE_MODEL_H_ 35 #define TENSORFLOW_LITE_MODEL_H_ 36 37 #include <memory> 38 #include "tensorflow/lite/c/c_api_internal.h" 39 #include "tensorflow/lite/core/api/error_reporter.h" 40 #include "tensorflow/lite/core/api/op_resolver.h" 41 #include "tensorflow/lite/interpreter.h" 42 #include "tensorflow/lite/mutable_op_resolver.h" 43 #include "tensorflow/lite/schema/schema_generated.h" 44 45 namespace tflite { 46 47 // Abstract interface that verifies whether a given model is legit. 48 // It facilitates the use-case to verify and build a model without loading it 49 // twice. 50 class TfLiteVerifier { 51 public: 52 // Returns true if the model is legit. 53 virtual bool Verify(const char* data, int length, 54 ErrorReporter* reporter) = 0; ~TfLiteVerifier()55 virtual ~TfLiteVerifier() {} 56 }; 57 58 // An RAII object that represents a read-only tflite model, copied from disk, 59 // or mmapped. This uses flatbuffers as the serialization format. 60 // 61 // NOTE: The current API requires that a FlatBufferModel instance be kept alive 62 // by the client as long as it is in use by any dependent Interpreter instances. 63 class FlatBufferModel { 64 public: 65 // Builds a model based on a file. 66 // Caller retains ownership of `error_reporter` and must ensure its lifetime 67 // is longer than the FlatBufferModel instance. 68 // Returns a nullptr in case of failure. 69 static std::unique_ptr<FlatBufferModel> BuildFromFile( 70 const char* filename, 71 ErrorReporter* error_reporter = DefaultErrorReporter()); 72 73 // Verifies whether the content of the file is legit, then builds a model 74 // based on the file. 75 // The extra_verifier argument is an additional optional verifier for the file 76 // contents. By default, we always check with tflite::VerifyModelBuffer. If 77 // extra_verifier is supplied, the file contents is also checked against the 78 // extra_verifier after the check against tflite::VerifyModelBuilder. 
79 // Caller retains ownership of `error_reporter` and must ensure its lifetime 80 // is longer than the FlatBufferModel instance. 81 // Returns a nullptr in case of failure. 82 static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile( 83 const char* filename, TfLiteVerifier* extra_verifier = nullptr, 84 ErrorReporter* error_reporter = DefaultErrorReporter()); 85 86 // Builds a model based on a pre-loaded flatbuffer. 87 // Caller retains ownership of the buffer and should keep it alive until 88 // the returned object is destroyed. Caller also retains ownership of 89 // `error_reporter` and must ensure its lifetime is longer than the 90 // FlatBufferModel instance. 91 // Returns a nullptr in case of failure. 92 // NOTE: this does NOT validate the buffer so it should NOT be called on 93 // invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case 94 static std::unique_ptr<FlatBufferModel> BuildFromBuffer( 95 const char* caller_owned_buffer, size_t buffer_size, 96 ErrorReporter* error_reporter = DefaultErrorReporter()); 97 98 // Verifies whether the content of the buffer is legit, then builds a model 99 // based on the pre-loaded flatbuffer. 100 // The extra_verifier argument is an additional optional verifier for the 101 // buffer. By default, we always check with tflite::VerifyModelBuffer. If 102 // extra_verifier is supplied, the buffer is checked against the 103 // extra_verifier after the check against tflite::VerifyModelBuilder. The 104 // caller retains ownership of the buffer and should keep it alive until the 105 // returned object is destroyed. Caller retains ownership of `error_reporter` 106 // and must ensure its lifetime is longer than the FlatBufferModel instance. 107 // Returns a nullptr in case of failure. 
108 static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromBuffer( 109 const char* buffer, size_t buffer_size, 110 TfLiteVerifier* extra_verifier = nullptr, 111 ErrorReporter* error_reporter = DefaultErrorReporter()); 112 113 // Builds a model directly from a flatbuffer pointer 114 // Caller retains ownership of the buffer and should keep it alive until the 115 // returned object is destroyed. Caller retains ownership of `error_reporter` 116 // and must ensure its lifetime is longer than the FlatBufferModel instance. 117 // Returns a nullptr in case of failure. 118 static std::unique_ptr<FlatBufferModel> BuildFromModel( 119 const tflite::Model* caller_owned_model_spec, 120 ErrorReporter* error_reporter = DefaultErrorReporter()); 121 122 // Releases memory or unmaps mmaped memory. 123 ~FlatBufferModel(); 124 125 // Copying or assignment is disallowed to simplify ownership semantics. 126 FlatBufferModel(const FlatBufferModel&) = delete; 127 FlatBufferModel& operator=(const FlatBufferModel&) = delete; 128 initialized()129 bool initialized() const { return model_ != nullptr; } 130 const tflite::Model* operator->() const { return model_; } GetModel()131 const tflite::Model* GetModel() const { return model_; } error_reporter()132 ErrorReporter* error_reporter() const { return error_reporter_; } allocation()133 const Allocation* allocation() const { return allocation_.get(); } 134 135 // Returns true if the model identifier is correct (otherwise false and 136 // reports an error). 137 bool CheckModelIdentifier() const; 138 139 private: 140 // Loads a model from a given allocation. FlatBufferModel will take over the 141 // ownership of `allocation`, and delete it in destructor. The ownership of 142 // `error_reporter`remains with the caller and must have lifetime at least 143 // as much as FlatBufferModel. This is to allow multiple models to use the 144 // same ErrorReporter instance. 
145 FlatBufferModel(std::unique_ptr<Allocation> allocation, 146 ErrorReporter* error_reporter = DefaultErrorReporter()); 147 148 // Loads a model from Model flatbuffer. The `model` has to remain alive and 149 // unchanged until the end of this flatbuffermodel's lifetime. 150 FlatBufferModel(const Model* model, ErrorReporter* error_reporter); 151 152 // Flatbuffer traverser pointer. (Model* is a pointer that is within the 153 // allocated memory of the data allocated by allocation's internals. 154 const tflite::Model* model_ = nullptr; 155 // The error reporter to use for model errors and subsequent errors when 156 // the interpreter is created 157 ErrorReporter* error_reporter_; 158 // The allocator used for holding memory of the model. Note that this will 159 // be null if the client provides a tflite::Model directly. 160 std::unique_ptr<Allocation> allocation_; 161 }; 162 163 // Build an interpreter capable of interpreting `model`. 164 // 165 // model: A model whose lifetime must be at least as long as any 166 // interpreter(s) created by the builder. In principle multiple interpreters 167 // can be made from a single model. 168 // op_resolver: An instance that implements the OpResolver interface, which maps 169 // custom op names and builtin op codes to op registrations. The lifetime 170 // of the provided `op_resolver` object must be at least as long as the 171 // InterpreterBuilder; unlike `model` and `error_reporter`, the `op_resolver` 172 // does not need to exist for the duration of any created Interpreter objects. 173 // error_reporter: a functor that is called to report errors that handles 174 // printf var arg semantics. The lifetime of the `error_reporter` object must 175 // be greater than or equal to the Interpreter created by operator(). 176 // 177 // Returns a kTfLiteOk when successful and sets interpreter to a valid 178 // Interpreter. 
Note: The user must ensure the model lifetime (and error 179 // reporter, if provided) is at least as long as interpreter's lifetime. 180 class InterpreterBuilder { 181 public: 182 InterpreterBuilder(const FlatBufferModel& model, 183 const OpResolver& op_resolver); 184 // Builds an interpreter given only the raw flatbuffer Model object (instead 185 // of a FlatBufferModel). Mostly used for testing. 186 // If `error_reporter` is null, then DefaultErrorReporter() is used. 187 InterpreterBuilder(const ::tflite::Model* model, 188 const OpResolver& op_resolver, 189 ErrorReporter* error_reporter = DefaultErrorReporter()); 190 ~InterpreterBuilder(); 191 InterpreterBuilder(const InterpreterBuilder&) = delete; 192 InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; 193 TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter); 194 TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter, 195 int num_threads); 196 197 private: 198 TfLiteStatus BuildLocalIndexToRegistrationMapping(); 199 TfLiteStatus ParseNodes( 200 const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators, 201 Subgraph* subgraph); 202 TfLiteStatus ParseTensors( 203 const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers, 204 const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors, 205 Subgraph* subgraph); 206 TfLiteStatus ApplyDelegates(Interpreter* interpreter); 207 TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization, 208 TfLiteQuantization* quantization); 209 210 const ::tflite::Model* model_; 211 const OpResolver& op_resolver_; 212 ErrorReporter* error_reporter_; 213 214 std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_; 215 std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_; 216 const Allocation* allocation_ = nullptr; 217 }; 218 219 } // namespace tflite 220 221 #endif // TENSORFLOW_LITE_MODEL_H_ 222