1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/c/c_api.h"

#include <cstring>
#include <memory>

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/create_op_resolver.h"
#include "tensorflow/lite/delegates/interpreter_utils.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/version.h"
29
30 namespace {
// A (C++) `tflite::ErrorReporter` that forwards each reported error to the
// (C ABI) callback stored in a `TfLiteErrorReporterCallback`, as registered
// via `TfLiteInterpreterOptionsSetErrorReporter`.
class CallbackErrorReporter : public tflite::ErrorReporter {
 public:
  explicit CallbackErrorReporter(TfLiteErrorReporterCallback callback)
      : callback_(callback) {}

  // Forwards the printf-style `format` and `args` (together with the
  // user-supplied opaque `user_data`) to the registered callback.
  // Always returns 0.
  int Report(const char* format, va_list args) override {
    callback_.error_reporter(callback_.user_data, format, args);
    return 0;
  }

 private:
  // Callback and opaque user data, copied by value at construction time.
  TfLiteErrorReporterCallback callback_;
};
44
45 /// `CallbackOpResolver` is a (C++) `tflite::OpResolver` that forwards the
46 /// methods to (C ABI) callback functions from a `TfLiteOpResolverCallbacks`
47 /// struct.
48 ///
49 /// The SetCallbacks method must be called before calling any of the FindOp
50 /// methods.
class CallbackOpResolver : public ::tflite::OpResolver {
 public:
  CallbackOpResolver() {}

  // Stores (by value) the callbacks to which the FindOp methods forward.
  // Must be called before any FindOp call.
  void SetCallbacks(
      const struct TfLiteOpResolverCallbacks& op_resolver_callbacks) {
    op_resolver_callbacks_ = op_resolver_callbacks;
  }

  // Resolves a builtin operator by forwarding to `find_builtin_op`.
  // Returns nullptr when no builtin-op callback was registered.
  const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
                                   int version) const override {
    if (op_resolver_callbacks_.find_builtin_op == nullptr) {
      return nullptr;
    }
    return op_resolver_callbacks_.find_builtin_op(
        op_resolver_callbacks_.user_data,
        static_cast<TfLiteBuiltinOperator>(op), version);
  }

  // Resolves a custom operator (by name) by forwarding to `find_custom_op`.
  // Returns nullptr when no custom-op callback was registered.
  const TfLiteRegistration* FindOp(const char* op, int version) const override {
    if (op_resolver_callbacks_.find_custom_op == nullptr) {
      return nullptr;
    }
    return op_resolver_callbacks_.find_custom_op(
        op_resolver_callbacks_.user_data, op, version);
  }

 private:
  // Non-copyable: the resolver is referenced by address during interpreter
  // construction.
  CallbackOpResolver(const CallbackOpResolver&) = delete;
  CallbackOpResolver& operator=(const CallbackOpResolver&) = delete;

  // Zero-initialized by default, i.e. all callbacks start out null.
  struct TfLiteOpResolverCallbacks op_resolver_callbacks_ = {};
};
81
82 } // namespace
83
84 extern "C" {
85
86 // LINT.IfChange
87
TfLiteVersion()88 const char* TfLiteVersion() { return TFLITE_VERSION_STRING; }
89
TfLiteModelCreate(const void * model_data,size_t model_size)90 TfLiteModel* TfLiteModelCreate(const void* model_data, size_t model_size) {
91 auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
92 static_cast<const char*>(model_data), model_size);
93 std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
94 return shared_model ? new TfLiteModel{std::move(shared_model)} : nullptr;
95 }
96
TfLiteModelCreateFromFile(const char * model_path)97 TfLiteModel* TfLiteModelCreateFromFile(const char* model_path) {
98 auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(model_path);
99 std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
100 return shared_model ? new TfLiteModel{std::move(shared_model)} : nullptr;
101 }
102
TfLiteModelDelete(TfLiteModel * model)103 void TfLiteModelDelete(TfLiteModel* model) { delete model; }
104
TfLiteInterpreterOptionsCreate()105 TfLiteInterpreterOptions* TfLiteInterpreterOptionsCreate() {
106 return new TfLiteInterpreterOptions{};
107 }
108
TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions * options)109 void TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions* options) {
110 delete options;
111 }
112
TfLiteInterpreterOptionsSetNumThreads(TfLiteInterpreterOptions * options,int32_t num_threads)113 void TfLiteInterpreterOptionsSetNumThreads(TfLiteInterpreterOptions* options,
114 int32_t num_threads) {
115 options->num_threads = num_threads;
116 }
117
TfLiteInterpreterOptionsAddDelegate(TfLiteInterpreterOptions * options,TfLiteDelegate * delegate)118 void TfLiteInterpreterOptionsAddDelegate(TfLiteInterpreterOptions* options,
119 TfLiteDelegate* delegate) {
120 options->delegates.push_back(delegate);
121 }
122
TfLiteInterpreterOptionsSetErrorReporter(TfLiteInterpreterOptions * options,void (* reporter)(void * user_data,const char * format,va_list args),void * user_data)123 void TfLiteInterpreterOptionsSetErrorReporter(
124 TfLiteInterpreterOptions* options,
125 void (*reporter)(void* user_data, const char* format, va_list args),
126 void* user_data) {
127 options->error_reporter_callback.error_reporter = reporter;
128 options->error_reporter_callback.user_data = user_data;
129 }
130
TfLiteInterpreterCreate(const TfLiteModel * model,const TfLiteInterpreterOptions * optional_options)131 TfLiteInterpreter* TfLiteInterpreterCreate(
132 const TfLiteModel* model,
133 const TfLiteInterpreterOptions* optional_options) {
134 std::unique_ptr<tflite::MutableOpResolver> resolver =
135 tflite::CreateOpResolver();
136 return tflite::internal::InterpreterCreateWithOpResolver(
137 model, optional_options, resolver.get());
138 }
139
TfLiteInterpreterDelete(TfLiteInterpreter * interpreter)140 void TfLiteInterpreterDelete(TfLiteInterpreter* interpreter) {
141 delete interpreter;
142 }
143
TfLiteInterpreterGetInputTensorCount(const TfLiteInterpreter * interpreter)144 int32_t TfLiteInterpreterGetInputTensorCount(
145 const TfLiteInterpreter* interpreter) {
146 return static_cast<int32_t>(interpreter->impl->inputs().size());
147 }
148
TfLiteInterpreterGetInputTensor(const TfLiteInterpreter * interpreter,int32_t input_index)149 TfLiteTensor* TfLiteInterpreterGetInputTensor(
150 const TfLiteInterpreter* interpreter, int32_t input_index) {
151 return interpreter->impl->tensor(interpreter->impl->inputs()[input_index]);
152 }
153
TfLiteInterpreterResizeInputTensor(TfLiteInterpreter * interpreter,int32_t input_index,const int * input_dims,int32_t input_dims_size)154 TfLiteStatus TfLiteInterpreterResizeInputTensor(TfLiteInterpreter* interpreter,
155 int32_t input_index,
156 const int* input_dims,
157 int32_t input_dims_size) {
158 std::vector<int> dims{input_dims, input_dims + input_dims_size};
159 return interpreter->impl->ResizeInputTensor(
160 interpreter->impl->inputs()[input_index], dims);
161 }
162
TfLiteInterpreterAllocateTensors(TfLiteInterpreter * interpreter)163 TfLiteStatus TfLiteInterpreterAllocateTensors(TfLiteInterpreter* interpreter) {
164 return interpreter->impl->AllocateTensors();
165 }
166
TfLiteInterpreterInvoke(TfLiteInterpreter * interpreter)167 TfLiteStatus TfLiteInterpreterInvoke(TfLiteInterpreter* interpreter) {
168 if (interpreter->enable_delegate_fallback) {
169 return tflite::delegates::InterpreterUtils::InvokeWithCPUFallback(
170 interpreter->impl.get());
171 } else {
172 return interpreter->impl->Invoke();
173 }
174 }
175
TfLiteInterpreterGetOutputTensorCount(const TfLiteInterpreter * interpreter)176 int32_t TfLiteInterpreterGetOutputTensorCount(
177 const TfLiteInterpreter* interpreter) {
178 return static_cast<int32_t>(interpreter->impl->outputs().size());
179 }
180
TfLiteInterpreterGetOutputTensor(const TfLiteInterpreter * interpreter,int32_t output_index)181 const TfLiteTensor* TfLiteInterpreterGetOutputTensor(
182 const TfLiteInterpreter* interpreter, int32_t output_index) {
183 return interpreter->impl->tensor(interpreter->impl->outputs()[output_index]);
184 }
185
TfLiteTensorType(const TfLiteTensor * tensor)186 TfLiteType TfLiteTensorType(const TfLiteTensor* tensor) { return tensor->type; }
187
TfLiteTensorNumDims(const TfLiteTensor * tensor)188 int32_t TfLiteTensorNumDims(const TfLiteTensor* tensor) {
189 return tensor->dims->size;
190 }
191
TfLiteTensorDim(const TfLiteTensor * tensor,int32_t dim_index)192 int32_t TfLiteTensorDim(const TfLiteTensor* tensor, int32_t dim_index) {
193 return tensor->dims->data[dim_index];
194 }
195
TfLiteTensorByteSize(const TfLiteTensor * tensor)196 size_t TfLiteTensorByteSize(const TfLiteTensor* tensor) {
197 return tensor->bytes;
198 }
199
TfLiteTensorData(const TfLiteTensor * tensor)200 void* TfLiteTensorData(const TfLiteTensor* tensor) { return tensor->data.raw; }
201
TfLiteTensorName(const TfLiteTensor * tensor)202 const char* TfLiteTensorName(const TfLiteTensor* tensor) {
203 return tensor->name;
204 }
205
TfLiteTensorQuantizationParams(const TfLiteTensor * tensor)206 TfLiteQuantizationParams TfLiteTensorQuantizationParams(
207 const TfLiteTensor* tensor) {
208 return tensor->params;
209 }
210
TfLiteTensorCopyFromBuffer(TfLiteTensor * tensor,const void * input_data,size_t input_data_size)211 TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor* tensor,
212 const void* input_data,
213 size_t input_data_size) {
214 if (tensor->bytes != input_data_size) {
215 return kTfLiteError;
216 }
217 memcpy(tensor->data.raw, input_data, input_data_size);
218 return kTfLiteOk;
219 }
220
TfLiteTensorCopyToBuffer(const TfLiteTensor * tensor,void * output_data,size_t output_data_size)221 TfLiteStatus TfLiteTensorCopyToBuffer(const TfLiteTensor* tensor,
222 void* output_data,
223 size_t output_data_size) {
224 if (tensor->bytes != output_data_size) {
225 return kTfLiteError;
226 }
227 memcpy(output_data, tensor->data.raw, output_data_size);
228 return kTfLiteOk;
229 }
230
231 // LINT.ThenChange(//tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs)
232
233 } // extern "C"
234
235 namespace tflite {
236 namespace internal {
237
// Creates a TfLiteInterpreter for `model`, resolving ops via
// `mutable_resolver` (augmented with any ops/callbacks registered on
// `optional_options`). `optional_options` may be null, in which case
// defaults apply. Returns nullptr when the model is invalid, interpreter
// construction fails, or applying any delegate fails.
TfLiteInterpreter* InterpreterCreateWithOpResolver(
    const TfLiteModel* model, const TfLiteInterpreterOptions* optional_options,
    tflite::MutableOpResolver* mutable_resolver) {
  TFLITE_DCHECK_NE(mutable_resolver, nullptr);
  if (!model || !model->impl) {
    return nullptr;
  }

  // Use the caller-provided error-reporter callback when one was registered;
  // otherwise the process-wide default reporter is used below.
  std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
  if (optional_options &&
      optional_options->error_reporter_callback.error_reporter != nullptr) {
    optional_error_reporter.reset(
        new CallbackErrorReporter(optional_options->error_reporter_callback));
  }

  // By default, we use the provided mutable_op_resolver, adding any builtin or
  // custom ops registered with `TfLiteInterpreterOptionsAddBuiltinOp` and/or
  // `TfLiteInterpreterOptionsAddCustomOp`.
  tflite::OpResolver* op_resolver = mutable_resolver;
  if (optional_options) {
    mutable_resolver->AddAll(optional_options->mutable_op_resolver);
  }
  // However, if `TfLiteInterpreterOptionsSetOpResolver` has been called with
  // a non-null callback parameter, then we instead use a
  // `CallbackOpResolver` that will forward to the callbacks provided there.
  // NOTE: `callback_op_resolver` is stack-local; it is only used (via
  // `op_resolver`) during `InterpreterBuilder` construction below.
  CallbackOpResolver callback_op_resolver;
  if (optional_options &&
      (optional_options->op_resolver_callbacks.find_builtin_op != nullptr ||
       optional_options->op_resolver_callbacks.find_custom_op != nullptr)) {
    callback_op_resolver.SetCallbacks(optional_options->op_resolver_callbacks);
    op_resolver = &callback_op_resolver;
  }

  tflite::ErrorReporter* error_reporter = optional_error_reporter
                                              ? optional_error_reporter.get()
                                              : tflite::DefaultErrorReporter();
  tflite::InterpreterBuilder builder(model->impl->GetModel(), *op_resolver,
                                     error_reporter);

  std::unique_ptr<tflite::Interpreter> interpreter;
  if (builder(&interpreter) != kTfLiteOk) {
    return nullptr;
  }

  if (optional_options) {
    // Only override the thread count when the caller changed it from the
    // default sentinel value.
    if (optional_options->num_threads !=
        TfLiteInterpreterOptions::kDefaultNumThreads) {
      interpreter->SetNumThreads(optional_options->num_threads);
    }

    // Legacy NNAPI flag: apply the NNAPI delegate before any
    // caller-supplied delegates.
    if (optional_options->use_nnapi) {
      if (interpreter->ModifyGraphWithDelegate(tflite::NnApiDelegate()) !=
          kTfLiteOk) {
        return nullptr;
      }
    }

    // Apply caller-supplied delegates in registration order; any failure
    // aborts interpreter creation (the unique_ptr cleans up).
    for (auto* delegate : optional_options->delegates) {
      if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
        return nullptr;
      }
    }
  }

  bool enable_delegate_fallback =
      optional_options != nullptr && optional_options->enable_delegate_fallback;

  // The returned wrapper shares ownership of the model and takes ownership
  // of the interpreter and the optional error reporter.
  return new TfLiteInterpreter{model->impl, std::move(optional_error_reporter),
                               std::move(interpreter),
                               enable_delegate_fallback};
}
309
310 } // namespace internal
311 } // namespace tflite
312