1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/c/c_api.h"
16 
17 #include <memory>
18 
19 #include "tensorflow/lite/builtin_ops.h"
20 #include "tensorflow/lite/c/c_api_internal.h"
21 #include "tensorflow/lite/create_op_resolver.h"
22 #include "tensorflow/lite/delegates/interpreter_utils.h"
23 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
24 #include "tensorflow/lite/error_reporter.h"
25 #include "tensorflow/lite/interpreter.h"
26 #include "tensorflow/lite/kernels/internal/compatibility.h"
27 #include "tensorflow/lite/model.h"
28 #include "tensorflow/lite/version.h"
29 
30 namespace {
31 class CallbackErrorReporter : public tflite::ErrorReporter {
32  public:
CallbackErrorReporter(TfLiteErrorReporterCallback callback)33   explicit CallbackErrorReporter(TfLiteErrorReporterCallback callback)
34       : callback_(callback) {}
35 
Report(const char * format,va_list args)36   int Report(const char* format, va_list args) override {
37     callback_.error_reporter(callback_.user_data, format, args);
38     return 0;
39   }
40 
41  private:
42   TfLiteErrorReporterCallback callback_;
43 };
44 
45 /// `CallbackOpResolver` is a (C++) `tflite::OpResolver` that forwards the
46 /// methods to (C ABI) callback functions from a `TfLiteOpResolverCallbacks`
47 /// struct.
48 ///
49 /// The SetCallbacks method must be called before calling any of the FindOp
50 /// methods.
51 class CallbackOpResolver : public ::tflite::OpResolver {
52  public:
CallbackOpResolver()53   CallbackOpResolver() {}
SetCallbacks(const struct TfLiteOpResolverCallbacks & op_resolver_callbacks)54   void SetCallbacks(
55       const struct TfLiteOpResolverCallbacks& op_resolver_callbacks) {
56     op_resolver_callbacks_ = op_resolver_callbacks;
57   }
FindOp(tflite::BuiltinOperator op,int version) const58   const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
59                                    int version) const override {
60     if (op_resolver_callbacks_.find_builtin_op == nullptr) {
61       return nullptr;
62     }
63     return op_resolver_callbacks_.find_builtin_op(
64         op_resolver_callbacks_.user_data,
65         static_cast<TfLiteBuiltinOperator>(op), version);
66   }
FindOp(const char * op,int version) const67   const TfLiteRegistration* FindOp(const char* op, int version) const override {
68     if (op_resolver_callbacks_.find_custom_op == nullptr) {
69       return nullptr;
70     }
71     return op_resolver_callbacks_.find_custom_op(
72         op_resolver_callbacks_.user_data, op, version);
73   }
74 
75  private:
76   CallbackOpResolver(const CallbackOpResolver&) = delete;
77   CallbackOpResolver& operator=(const CallbackOpResolver&) = delete;
78 
79   struct TfLiteOpResolverCallbacks op_resolver_callbacks_ = {};
80 };
81 
82 }  // namespace
83 
84 extern "C" {
85 
86 // LINT.IfChange
87 
TfLiteVersion()88 const char* TfLiteVersion() { return TFLITE_VERSION_STRING; }
89 
TfLiteModelCreate(const void * model_data,size_t model_size)90 TfLiteModel* TfLiteModelCreate(const void* model_data, size_t model_size) {
91   auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
92       static_cast<const char*>(model_data), model_size);
93   std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
94   return shared_model ? new TfLiteModel{std::move(shared_model)} : nullptr;
95 }
96 
TfLiteModelCreateFromFile(const char * model_path)97 TfLiteModel* TfLiteModelCreateFromFile(const char* model_path) {
98   auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(model_path);
99   std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
100   return shared_model ? new TfLiteModel{std::move(shared_model)} : nullptr;
101 }
102 
TfLiteModelDelete(TfLiteModel * model)103 void TfLiteModelDelete(TfLiteModel* model) { delete model; }
104 
TfLiteInterpreterOptionsCreate()105 TfLiteInterpreterOptions* TfLiteInterpreterOptionsCreate() {
106   return new TfLiteInterpreterOptions{};
107 }
108 
TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions * options)109 void TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions* options) {
110   delete options;
111 }
112 
TfLiteInterpreterOptionsSetNumThreads(TfLiteInterpreterOptions * options,int32_t num_threads)113 void TfLiteInterpreterOptionsSetNumThreads(TfLiteInterpreterOptions* options,
114                                            int32_t num_threads) {
115   options->num_threads = num_threads;
116 }
117 
TfLiteInterpreterOptionsAddDelegate(TfLiteInterpreterOptions * options,TfLiteDelegate * delegate)118 void TfLiteInterpreterOptionsAddDelegate(TfLiteInterpreterOptions* options,
119                                          TfLiteDelegate* delegate) {
120   options->delegates.push_back(delegate);
121 }
122 
TfLiteInterpreterOptionsSetErrorReporter(TfLiteInterpreterOptions * options,void (* reporter)(void * user_data,const char * format,va_list args),void * user_data)123 void TfLiteInterpreterOptionsSetErrorReporter(
124     TfLiteInterpreterOptions* options,
125     void (*reporter)(void* user_data, const char* format, va_list args),
126     void* user_data) {
127   options->error_reporter_callback.error_reporter = reporter;
128   options->error_reporter_callback.user_data = user_data;
129 }
130 
TfLiteInterpreterCreate(const TfLiteModel * model,const TfLiteInterpreterOptions * optional_options)131 TfLiteInterpreter* TfLiteInterpreterCreate(
132     const TfLiteModel* model,
133     const TfLiteInterpreterOptions* optional_options) {
134   std::unique_ptr<tflite::MutableOpResolver> resolver =
135       tflite::CreateOpResolver();
136   return tflite::internal::InterpreterCreateWithOpResolver(
137       model, optional_options, resolver.get());
138 }
139 
TfLiteInterpreterDelete(TfLiteInterpreter * interpreter)140 void TfLiteInterpreterDelete(TfLiteInterpreter* interpreter) {
141   delete interpreter;
142 }
143 
TfLiteInterpreterGetInputTensorCount(const TfLiteInterpreter * interpreter)144 int32_t TfLiteInterpreterGetInputTensorCount(
145     const TfLiteInterpreter* interpreter) {
146   return static_cast<int32_t>(interpreter->impl->inputs().size());
147 }
148 
TfLiteInterpreterGetInputTensor(const TfLiteInterpreter * interpreter,int32_t input_index)149 TfLiteTensor* TfLiteInterpreterGetInputTensor(
150     const TfLiteInterpreter* interpreter, int32_t input_index) {
151   return interpreter->impl->tensor(interpreter->impl->inputs()[input_index]);
152 }
153 
TfLiteInterpreterResizeInputTensor(TfLiteInterpreter * interpreter,int32_t input_index,const int * input_dims,int32_t input_dims_size)154 TfLiteStatus TfLiteInterpreterResizeInputTensor(TfLiteInterpreter* interpreter,
155                                                 int32_t input_index,
156                                                 const int* input_dims,
157                                                 int32_t input_dims_size) {
158   std::vector<int> dims{input_dims, input_dims + input_dims_size};
159   return interpreter->impl->ResizeInputTensor(
160       interpreter->impl->inputs()[input_index], dims);
161 }
162 
TfLiteInterpreterAllocateTensors(TfLiteInterpreter * interpreter)163 TfLiteStatus TfLiteInterpreterAllocateTensors(TfLiteInterpreter* interpreter) {
164   return interpreter->impl->AllocateTensors();
165 }
166 
TfLiteInterpreterInvoke(TfLiteInterpreter * interpreter)167 TfLiteStatus TfLiteInterpreterInvoke(TfLiteInterpreter* interpreter) {
168   if (interpreter->enable_delegate_fallback) {
169     return tflite::delegates::InterpreterUtils::InvokeWithCPUFallback(
170         interpreter->impl.get());
171   } else {
172     return interpreter->impl->Invoke();
173   }
174 }
175 
TfLiteInterpreterGetOutputTensorCount(const TfLiteInterpreter * interpreter)176 int32_t TfLiteInterpreterGetOutputTensorCount(
177     const TfLiteInterpreter* interpreter) {
178   return static_cast<int32_t>(interpreter->impl->outputs().size());
179 }
180 
TfLiteInterpreterGetOutputTensor(const TfLiteInterpreter * interpreter,int32_t output_index)181 const TfLiteTensor* TfLiteInterpreterGetOutputTensor(
182     const TfLiteInterpreter* interpreter, int32_t output_index) {
183   return interpreter->impl->tensor(interpreter->impl->outputs()[output_index]);
184 }
185 
TfLiteTensorType(const TfLiteTensor * tensor)186 TfLiteType TfLiteTensorType(const TfLiteTensor* tensor) { return tensor->type; }
187 
TfLiteTensorNumDims(const TfLiteTensor * tensor)188 int32_t TfLiteTensorNumDims(const TfLiteTensor* tensor) {
189   return tensor->dims->size;
190 }
191 
TfLiteTensorDim(const TfLiteTensor * tensor,int32_t dim_index)192 int32_t TfLiteTensorDim(const TfLiteTensor* tensor, int32_t dim_index) {
193   return tensor->dims->data[dim_index];
194 }
195 
TfLiteTensorByteSize(const TfLiteTensor * tensor)196 size_t TfLiteTensorByteSize(const TfLiteTensor* tensor) {
197   return tensor->bytes;
198 }
199 
TfLiteTensorData(const TfLiteTensor * tensor)200 void* TfLiteTensorData(const TfLiteTensor* tensor) { return tensor->data.raw; }
201 
TfLiteTensorName(const TfLiteTensor * tensor)202 const char* TfLiteTensorName(const TfLiteTensor* tensor) {
203   return tensor->name;
204 }
205 
TfLiteTensorQuantizationParams(const TfLiteTensor * tensor)206 TfLiteQuantizationParams TfLiteTensorQuantizationParams(
207     const TfLiteTensor* tensor) {
208   return tensor->params;
209 }
210 
TfLiteTensorCopyFromBuffer(TfLiteTensor * tensor,const void * input_data,size_t input_data_size)211 TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor* tensor,
212                                         const void* input_data,
213                                         size_t input_data_size) {
214   if (tensor->bytes != input_data_size) {
215     return kTfLiteError;
216   }
217   memcpy(tensor->data.raw, input_data, input_data_size);
218   return kTfLiteOk;
219 }
220 
TfLiteTensorCopyToBuffer(const TfLiteTensor * tensor,void * output_data,size_t output_data_size)221 TfLiteStatus TfLiteTensorCopyToBuffer(const TfLiteTensor* tensor,
222                                       void* output_data,
223                                       size_t output_data_size) {
224   if (tensor->bytes != output_data_size) {
225     return kTfLiteError;
226   }
227   memcpy(output_data, tensor->data.raw, output_data_size);
228   return kTfLiteOk;
229 }
230 
231 // LINT.ThenChange(//tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs)
232 
233 }  // extern "C"
234 
235 namespace tflite {
236 namespace internal {
237 
/// Shared implementation behind `TfLiteInterpreterCreate`. Builds a
/// `TfLiteInterpreter` for `model` from `mutable_resolver`, plus any ops,
/// op-resolver callbacks, error reporter, thread count and delegates
/// configured in `optional_options`.
///
/// @param model            Model to instantiate. If it is nullptr or has no
///                         `impl`, the function returns nullptr.
/// @param optional_options May be nullptr, in which case no options are
///                         applied.
/// @param mutable_resolver Must be non-null. This is only DCHECKed, so it is
///                         not checked when TFLITE_DCHECK is compiled out.
///                         When options are given, their
///                         `mutable_op_resolver` ops are added to it via
///                         `AddAll`, so the caller's resolver is mutated.
/// @return A new interpreter that the caller must release with
///         `TfLiteInterpreterDelete`. Returns nullptr if the model is missing,
///         `InterpreterBuilder` fails, or any delegate (NNAPI or
///         user-supplied) fails to apply.
TfLiteInterpreter* InterpreterCreateWithOpResolver(
    const TfLiteModel* model, const TfLiteInterpreterOptions* optional_options,
    tflite::MutableOpResolver* mutable_resolver) {
  TFLITE_DCHECK_NE(mutable_resolver, nullptr);
  if (!model || !model->impl) {
    return nullptr;
  }

  // Wrap the user's C error callback only when one was supplied. Otherwise
  // the default reporter is chosen below. This is declared before
  // `interpreter`, so on every early-return path the interpreter is destroyed
  // first and never outlives its reporter.
  std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
  if (optional_options &&
      optional_options->error_reporter_callback.error_reporter != nullptr) {
    optional_error_reporter.reset(
        new CallbackErrorReporter(optional_options->error_reporter_callback));
  }

  // By default, we use the provided mutable_op_resolver, adding any builtin or
  // custom ops registered with `TfLiteInterpreterOptionsAddBuiltinOp` and/or
  // `TfLiteInterpreterOptionsAddCustomOp`.
  tflite::OpResolver* op_resolver = mutable_resolver;
  if (optional_options) {
    mutable_resolver->AddAll(optional_options->mutable_op_resolver);
  }
  // However, if `TfLiteInterpreterOptionsSetOpResolver` has been called with
  // a non-null callback parameter, then we instead use a
  // `CallbackOpResolver` that will forward to the callbacks provided there.
  // The AddAll above has already run by then, so `mutable_resolver` is
  // modified even when it ends up unused.
  // NOTE(review): this local (like the caller's resolver) is destroyed before
  // the returned interpreter. Presumably InterpreterBuilder copies what it
  // needs from the resolver during construction — confirm.
  CallbackOpResolver callback_op_resolver;
  if (optional_options &&
      (optional_options->op_resolver_callbacks.find_builtin_op != nullptr ||
       optional_options->op_resolver_callbacks.find_custom_op != nullptr)) {
    callback_op_resolver.SetCallbacks(optional_options->op_resolver_callbacks);
    op_resolver = &callback_op_resolver;
  }

  tflite::ErrorReporter* error_reporter = optional_error_reporter
                                              ? optional_error_reporter.get()
                                              : tflite::DefaultErrorReporter();
  tflite::InterpreterBuilder builder(model->impl->GetModel(), *op_resolver,
                                     error_reporter);

  std::unique_ptr<tflite::Interpreter> interpreter;
  if (builder(&interpreter) != kTfLiteOk) {
    return nullptr;
  }

  if (optional_options) {
    // Only override the thread count when the caller changed it from the
    // default sentinel.
    // NOTE(review): the TfLiteStatus returned by SetNumThreads is ignored.
    if (optional_options->num_threads !=
        TfLiteInterpreterOptions::kDefaultNumThreads) {
      interpreter->SetNumThreads(optional_options->num_threads);
    }

    // The NNAPI delegate, if requested, is applied before any user delegates.
    if (optional_options->use_nnapi) {
      if (interpreter->ModifyGraphWithDelegate(tflite::NnApiDelegate()) !=
          kTfLiteOk) {
        return nullptr;
      }
    }

    // User delegates are applied in the order they were added. The first
    // failure aborts creation.
    for (auto* delegate : optional_options->delegates) {
      if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
        return nullptr;
      }
    }
  }

  // Read later by TfLiteInterpreterInvoke to choose between a plain Invoke
  // and InvokeWithCPUFallback.
  bool enable_delegate_fallback =
      optional_options != nullptr && optional_options->enable_delegate_fallback;

  // The result copies `model->impl` (a shared_ptr), sharing ownership of the
  // FlatBufferModel. It also takes ownership of the error reporter and the
  // interpreter.
  // NOTE(review): the relative destruction order of those members depends on
  // the TfLiteInterpreter definition (not visible here). Confirm that `impl`
  // is destroyed before the reporter it uses.
  return new TfLiteInterpreter{model->impl, std::move(optional_error_reporter),
                               std::move(interpreter),
                               enable_delegate_fallback};
}
309 
310 }  // namespace internal
311 }  // namespace tflite
312