/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Deserialization infrastructure for tflite. Provides functionality
// to go from a serialized tflite model in flatbuffer format to an
// interpreter.
//
// using namespace tflite;
// StderrReporter error_reporter;
// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite",
//                                             &error_reporter);
// MyOpResolver resolver;  // You need to subclass OpResolver to provide
//                         // implementations.
// InterpreterBuilder builder(*model, resolver);
// std::unique_ptr<Interpreter> interpreter;
// if (builder(&interpreter) == kTfLiteOk) {
//   ... run model inference with the interpreter
// }
//
// An OpResolver must be defined to provide your kernel implementations to the
// interpreter. This is environment-specific and may consist of just the
// builtin ops, or custom operators you define to extend tflite.
#ifndef TENSORFLOW_LITE_MODEL_H_
#define TENSORFLOW_LITE_MODEL_H_

#include <memory>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {

// Abstract interface that verifies whether a given model is legit.
// It makes it possible to verify and build a model without loading it twice.
class TfLiteVerifier {
 public:
  // Returns true if the model is legit.
  virtual bool Verify(const char* data, int length,
                      ErrorReporter* reporter) = 0;
  virtual ~TfLiteVerifier() {}
};
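
// A minimal sketch of a custom verifier, assuming only the interface above
// (the size check below is illustrative, not a real validation):
//
//   class MyVerifier : public TfLiteVerifier {
//    public:
//     bool Verify(const char* data, int length,
//                 ErrorReporter* reporter) override {
//       if (length < 8) {  // hypothetical minimum size
//         reporter->Report("Model buffer is too small (%d bytes)", length);
//         return false;
//       }
//       return true;  // accept; a real verifier would inspect `data`
//     }
//   };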

// An RAII object that represents a read-only tflite model, copied from disk
// or mmapped. This uses flatbuffers as the serialization format.
//
// NOTE: The current API requires that a FlatBufferModel instance be kept alive
// by the client as long as it is in use by any dependent Interpreter instances.
class FlatBufferModel {
 public:
  // Builds a model based on a file.
  // Caller retains ownership of `error_reporter` and must ensure its lifetime
  // is longer than the FlatBufferModel instance.
  // Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromFile(
      const char* filename,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Verifies whether the content of the file is legit, then builds a model
  // based on the file.
  // The extra_verifier argument is an additional optional verifier for the
  // file contents. By default, we always check with tflite::VerifyModelBuffer.
  // If extra_verifier is supplied, the file contents are also checked against
  // the extra_verifier after the check against tflite::VerifyModelBuffer.
  // Caller retains ownership of `error_reporter` and must ensure its lifetime
  // is longer than the FlatBufferModel instance.
  // Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile(
      const char* filename, TfLiteVerifier* extra_verifier = nullptr,
      ErrorReporter* error_reporter = DefaultErrorReporter());
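
  // A sketch of usage with an extra verifier (MyVerifier is a hypothetical
  // TfLiteVerifier subclass such as the one sketched above; "model.tflite" is
  // a placeholder path):
  //
  //   MyVerifier verifier;
  //   auto model =
  //       FlatBufferModel::VerifyAndBuildFromFile("model.tflite", &verifier);
  //   if (!model) {
  //     // Verification or loading failed; a nullptr was returned.
  //   }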

  // Builds a model based on a pre-loaded flatbuffer.
  // Caller retains ownership of the buffer and should keep it alive until
  // the returned object is destroyed. Caller also retains ownership of
  // `error_reporter` and must ensure its lifetime is longer than the
  // FlatBufferModel instance.
  // Returns a nullptr in case of failure.
  // NOTE: this does NOT validate the buffer, so it should NOT be called on
  // invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case.
  static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
      const char* caller_owned_buffer, size_t buffer_size,
      ErrorReporter* error_reporter = DefaultErrorReporter());
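
  // A minimal sketch of building from a caller-owned buffer. The <fstream>,
  // <string>, and <iterator> includes are omitted; "model.tflite" is a
  // placeholder path, and `model_data` must outlive the returned model:
  //
  //   std::ifstream file("model.tflite", std::ios::binary);
  //   std::string model_data((std::istreambuf_iterator<char>(file)),
  //                          std::istreambuf_iterator<char>());
  //   auto model = FlatBufferModel::BuildFromBuffer(model_data.data(),
  //                                                 model_data.size());
  //   if (!model) {
  //     // The buffer could not be interpreted as a model.
  //   }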

  // Verifies whether the content of the buffer is legit, then builds a model
  // based on the pre-loaded flatbuffer.
  // The extra_verifier argument is an additional optional verifier for the
  // buffer. By default, we always check with tflite::VerifyModelBuffer. If
  // extra_verifier is supplied, the buffer is also checked against the
  // extra_verifier after the check against tflite::VerifyModelBuffer. The
  // caller retains ownership of the buffer and should keep it alive until the
  // returned object is destroyed. Caller retains ownership of `error_reporter`
  // and must ensure its lifetime is longer than the FlatBufferModel instance.
  // Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromBuffer(
      const char* buffer, size_t buffer_size,
      TfLiteVerifier* extra_verifier = nullptr,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Builds a model directly from a flatbuffer pointer.
  // Caller retains ownership of the buffer and should keep it alive until the
  // returned object is destroyed. Caller retains ownership of `error_reporter`
  // and must ensure its lifetime is longer than the FlatBufferModel instance.
  // Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromModel(
      const tflite::Model* caller_owned_model_spec,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Releases memory or unmaps mmapped memory.
  ~FlatBufferModel();

  // Copying or assignment is disallowed to simplify ownership semantics.
  FlatBufferModel(const FlatBufferModel&) = delete;
  FlatBufferModel& operator=(const FlatBufferModel&) = delete;

  bool initialized() const { return model_ != nullptr; }
  const tflite::Model* operator->() const { return model_; }
  const tflite::Model* GetModel() const { return model_; }
  ErrorReporter* error_reporter() const { return error_reporter_; }
  const Allocation* allocation() const { return allocation_.get(); }

  // Returns true if the model identifier is correct (otherwise false and
  // reports an error).
  bool CheckModelIdentifier() const;

 private:
  // Loads a model from a given allocation. FlatBufferModel will take over the
  // ownership of `allocation` and delete it in the destructor. Ownership of
  // `error_reporter` remains with the caller, and its lifetime must be at
  // least as long as the FlatBufferModel's. This allows multiple models to
  // share the same ErrorReporter instance.
  FlatBufferModel(std::unique_ptr<Allocation> allocation,
                  ErrorReporter* error_reporter = DefaultErrorReporter());

  // Loads a model from a Model flatbuffer. The `model` has to remain alive and
  // unchanged until the end of this FlatBufferModel's lifetime.
  FlatBufferModel(const Model* model, ErrorReporter* error_reporter);

  // Flatbuffer traverser pointer. (Model* is a pointer into the memory
  // allocated by the allocation's internals.)
  const tflite::Model* model_ = nullptr;
  // The error reporter to use for model errors and subsequent errors when
  // the interpreter is created.
  ErrorReporter* error_reporter_;
  // The allocation used for holding the model's memory. Note that this will
  // be null if the client provides a tflite::Model directly.
  std::unique_ptr<Allocation> allocation_;
};

// Builds an interpreter capable of interpreting `model`.
//
// model: A model whose lifetime must be at least as long as any
//   interpreter(s) created by the builder. In principle, multiple interpreters
//   can be made from a single model.
// op_resolver: An instance that implements the OpResolver interface, which
//   maps custom op names and builtin op codes to op registrations. The
//   lifetime of the provided `op_resolver` object must be at least as long as
//   the InterpreterBuilder; unlike `model` and `error_reporter`, the
//   `op_resolver` does not need to exist for the duration of any created
//   Interpreter objects.
// error_reporter: a functor that is called to report errors and that handles
//   printf-style variadic argument semantics. The lifetime of the
//   `error_reporter` object must be at least as long as that of the
//   Interpreter created by operator().
//
// operator() returns kTfLiteOk when successful and sets `interpreter` to a
// valid Interpreter. Note: The user must ensure the model lifetime (and the
// error reporter's, if provided) is at least as long as the interpreter's
// lifetime.
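//
// A minimal sketch of building with an explicit thread count (the model and
// resolver setup mirrors the example at the top of this header; the thread
// count of 2 is illustrative, not a recommendation):
//
//   InterpreterBuilder builder(*model, resolver);
//   std::unique_ptr<Interpreter> interpreter;
//   if (builder(&interpreter, /*num_threads=*/2) == kTfLiteOk) {
//     // run inference with `interpreter`
//   }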
class InterpreterBuilder {
 public:
  InterpreterBuilder(const FlatBufferModel& model,
                     const OpResolver& op_resolver);
  // Builds an interpreter given only the raw flatbuffer Model object (instead
  // of a FlatBufferModel). Mostly used for testing.
  // If `error_reporter` is null, then DefaultErrorReporter() is used.
  InterpreterBuilder(const ::tflite::Model* model,
                     const OpResolver& op_resolver,
                     ErrorReporter* error_reporter = DefaultErrorReporter());
  ~InterpreterBuilder();
  InterpreterBuilder(const InterpreterBuilder&) = delete;
  InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
                          int num_threads);

 private:
  TfLiteStatus BuildLocalIndexToRegistrationMapping();
  TfLiteStatus ParseNodes(
      const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
      Subgraph* subgraph);
  TfLiteStatus ParseTensors(
      const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
      const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
      Subgraph* subgraph);
  TfLiteStatus ApplyDelegates(Interpreter* interpreter);
  TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
                                 TfLiteQuantization* quantization);

  const ::tflite::Model* model_;
  const OpResolver& op_resolver_;
  ErrorReporter* error_reporter_;

  std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_;
  std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
  const Allocation* allocation_ = nullptr;
};

}  // namespace tflite

#endif  // TENSORFLOW_LITE_MODEL_H_