1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
16
17 #include <sstream>
18 #include <string>
19
20 #include "absl/memory/memory.h"
21 #include "tensorflow/lite/interpreter.h"
22 #include "tensorflow/lite/kernels/register.h"
23 #include "tensorflow/lite/model.h"
24 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
25 #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
26 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
27 #include "tensorflow/lite/string_util.h"
28
// Propagates a TfLiteStatus failure out of the enclosing function by raising
// the Python exception accumulated in the member `error_reporter_`.
#define TFLITE_PY_CHECK(x) \
  if ((x) != kTfLiteOk) { \
    return error_reporter_->exception(); \
  }

// Raises ValueError and returns nullptr when tensor index `i` is out of
// range for the member `interpreter_`'s tensor table.
#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i) \
  if (i >= interpreter_->tensors_size() || i < 0) { \
    PyErr_Format(PyExc_ValueError, \
                 "Invalid tensor index %d exceeds max tensor index %lu", i, \
                 interpreter_->tensors_size()); \
    return nullptr; \
  }

// Raises ValueError and returns nullptr when `interpreter_` (member or, in
// CheckGetTensorArgs below, a like-named parameter) is null.
#define TFLITE_PY_ENSURE_VALID_INTERPRETER() \
  if (!interpreter_) { \
    PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
    return nullptr; \
  }
47
48 namespace tflite {
49 namespace interpreter_wrapper {
50
51 namespace {
52
53 using python_utils::PyDecrefDeleter;
54
CreateInterpreter(const tflite::FlatBufferModel * model,const tflite::ops::builtin::BuiltinOpResolver & resolver)55 std::unique_ptr<tflite::Interpreter> CreateInterpreter(
56 const tflite::FlatBufferModel* model,
57 const tflite::ops::builtin::BuiltinOpResolver& resolver) {
58 if (!model) {
59 return nullptr;
60 }
61
62 ::tflite::python::ImportNumpy();
63
64 std::unique_ptr<tflite::Interpreter> interpreter;
65 if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
66 return nullptr;
67 }
68 return interpreter;
69 }
70
PyArrayFromIntVector(const int * data,npy_intp size)71 PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
72 void* pydata = malloc(size * sizeof(int));
73 memcpy(pydata, data, size * sizeof(int));
74 return PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata);
75 }
76
PyTupleFromQuantizationParam(const TfLiteQuantizationParams & param)77 PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
78 PyObject* result = PyTuple_New(2);
79 PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale));
80 PyTuple_SET_ITEM(result, 1, PyLong_FromLong(param.zero_point));
81 return result;
82 }
83
84 } // namespace
85
CreateInterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model,std::unique_ptr<PythonErrorReporter> error_reporter,std::string * error_msg)86 InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
87 std::unique_ptr<tflite::FlatBufferModel> model,
88 std::unique_ptr<PythonErrorReporter> error_reporter,
89 std::string* error_msg) {
90 if (!model) {
91 *error_msg = error_reporter->message();
92 return nullptr;
93 }
94
95 auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
96 auto interpreter = CreateInterpreter(model.get(), *resolver);
97 if (!interpreter) {
98 *error_msg = error_reporter->message();
99 return nullptr;
100 }
101
102 InterpreterWrapper* wrapper =
103 new InterpreterWrapper(std::move(model), std::move(error_reporter),
104 std::move(resolver), std::move(interpreter));
105 return wrapper;
106 }
107
// Takes ownership of every piece the interpreter depends on: the flatbuffer
// model (whose buffer backs the interpreter), the error reporter used by
// TFLITE_PY_CHECK, the op resolver, and the interpreter itself, so they all
// share the wrapper's lifetime.
InterpreterWrapper::InterpreterWrapper(
    std::unique_ptr<tflite::FlatBufferModel> model,
    std::unique_ptr<PythonErrorReporter> error_reporter,
    std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
    std::unique_ptr<tflite::Interpreter> interpreter)
    : model_(std::move(model)),
      error_reporter_(std::move(error_reporter)),
      resolver_(std::move(resolver)),
      interpreter_(std::move(interpreter)) {}
117
~InterpreterWrapper()118 InterpreterWrapper::~InterpreterWrapper() {}
119
// Python binding for Interpreter::AllocateTensors(). Returns None on
// success; on failure returns nullptr with the reporter's exception raised.
PyObject* InterpreterWrapper::AllocateTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  Py_RETURN_NONE;
}
125
// Python binding for Interpreter::Invoke() (runs inference). Returns None on
// success; on failure returns nullptr with the reporter's exception raised.
PyObject* InterpreterWrapper::Invoke() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->Invoke());
  Py_RETURN_NONE;
}
131
// Returns the interpreter's input tensor indices as a 1-D int32 numpy array
// (a copy of the index vector, not a view).
PyObject* InterpreterWrapper::InputIndices() const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
                                            interpreter_->inputs().size());

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
139
OutputIndices() const140 PyObject* InterpreterWrapper::OutputIndices() const {
141 PyObject* np_array = PyArrayFromIntVector(interpreter_->outputs().data(),
142 interpreter_->outputs().size());
143
144 return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
145 }
146
// Python binding for Interpreter::ResizeInputTensor(). `value` must be a
// 1-D int32 numpy array (or convertible to one) holding the new shape for
// input tensor `i`. Returns None on success; raises ValueError and returns
// nullptr on malformed input.
PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  // Coerce `value` into a C-contiguous array; the unique_ptr decrefs it when
  // this function returns on any path.
  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert numpy value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());

  // The shape itself must be one-dimensional...
  if (PyArray_NDIM(array) != 1) {
    PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
                 PyArray_NDIM(array));
    return nullptr;
  }

  // ...and exactly int32, since it is memcpy'd into a vector<int> below.
  if (PyArray_TYPE(array) != NPY_INT32) {
    PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
                 PyArray_TYPE(array));
    return nullptr;
  }

  // Copy the dimension values out of the numpy buffer.
  std::vector<int> dims(PyArray_SHAPE(array)[0]);
  memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));

  TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
  Py_RETURN_NONE;
}
178
NumTensors() const179 int InterpreterWrapper::NumTensors() const {
180 if (!interpreter_) {
181 return 0;
182 }
183 return interpreter_->tensors_size();
184 }
185
TensorName(int i) const186 std::string InterpreterWrapper::TensorName(int i) const {
187 if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
188 return "";
189 }
190
191 const TfLiteTensor* tensor = interpreter_->tensor(i);
192 return tensor->name ? tensor->name : "";
193 }
194
TensorType(int i) const195 PyObject* InterpreterWrapper::TensorType(int i) const {
196 TFLITE_PY_ENSURE_VALID_INTERPRETER();
197 TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
198
199 const TfLiteTensor* tensor = interpreter_->tensor(i);
200 if (tensor->type == kTfLiteNoType) {
201 PyErr_Format(PyExc_ValueError, "Tensor with no type found.");
202 return nullptr;
203 }
204
205 int code = python_utils::TfLiteTypeToPyArrayType(tensor->type);
206 if (code == -1) {
207 PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
208 return nullptr;
209 }
210 return PyArray_TypeObjectFromType(code);
211 }
212
TensorSize(int i) const213 PyObject* InterpreterWrapper::TensorSize(int i) const {
214 TFLITE_PY_ENSURE_VALID_INTERPRETER();
215 TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
216
217 const TfLiteTensor* tensor = interpreter_->tensor(i);
218 if (tensor->dims == nullptr) {
219 PyErr_Format(PyExc_ValueError, "Tensor with no shape found.");
220 return nullptr;
221 }
222 PyObject* np_array =
223 PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
224
225 return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
226 }
227
// Returns tensor `i`'s quantization parameters as a (scale, zero_point)
// Python tuple.
PyObject* InterpreterWrapper::TensorQuantization(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
  const TfLiteTensor* tensor = interpreter_->tensor(i);
  return PyTupleFromQuantizationParam(tensor->params);
}
234
// Copies the numpy value `value` into tensor `i`, validating dtype, rank,
// per-dimension sizes, and (for non-string tensors) total byte count first.
// Returns None on success; raises ValueError and returns nullptr otherwise.
PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);

  // Coerce `value` into a C-contiguous array; the unique_ptr decrefs it on
  // every return path.
  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
  TfLiteTensor* tensor = interpreter_->tensor(i);

  // dtype must match the tensor's declared type exactly (no implicit casts).
  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor:"
                 " Got tensor of type %d"
                 " but expected type %d for input %d ",
                 python_utils::TfLiteTypeFromPyArray(array), tensor->type, i);
    return nullptr;
  }

  // Rank must match...
  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor: Dimension mismatch."
                 " Got %d"
                 " but expected %d for input %d.",
                 PyArray_NDIM(array), tensor->dims->size, i);
    return nullptr;
  }

  // ...and so must every individual dimension.
  for (int j = 0; j < PyArray_NDIM(array); j++) {
    if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Dimension mismatch."
                   " Got %ld"
                   " but expected %d for dimension %d of input %d.",
                   PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i);
      return nullptr;
    }
  }

  if (tensor->type != kTfLiteString) {
    // Fixed-size types: a single memcpy after verifying the byte count.
    size_t size = PyArray_NBYTES(array);
    if (size != tensor->bytes) {
      PyErr_Format(PyExc_ValueError,
                   "numpy array had %zu bytes but expected %zu bytes.", size,
                   tensor->bytes);
      return nullptr;
    }
    memcpy(tensor->data.raw, PyArray_DATA(array), size);
  } else {
    // String tensors use TFLite's dynamic string packing; note this reads
    // from the original `value`, not the coerced `array`.
    DynamicBuffer dynamic_buffer;
    if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
      return nullptr;
    }
    dynamic_buffer.WriteToTensor(tensor, nullptr);
  }
  Py_RETURN_NONE;
}
297
298 namespace {
299
// Checks to see if a tensor access can succeed (returns nullptr on error).
// Otherwise returns Py_None.
// The parameter is deliberately named `interpreter_` so that the TFLITE_PY_*
// macros, which reference that identifier, expand correctly in this free
// function. On success, out-params `tensor` and `type_num` are populated.
PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
                             TfLiteTensor** tensor, int* type_num) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index);

  *tensor = interpreter_->tensor(tensor_index);
  if ((*tensor)->bytes == 0) {
    PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
    return nullptr;
  }

  *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
  if (*type_num == -1) {
    PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
    return nullptr;
  }

  // A null data pointer typically means AllocateTensors() has not run —
  // TODO confirm; from here we only know the buffer is unusable.
  if (!(*tensor)->data.raw) {
    PyErr_SetString(PyExc_ValueError, "Tensor data is null.");
    return nullptr;
  }

  Py_RETURN_NONE;
}
326
327 } // namespace
328
// Returns a COPY of tensor `i`'s contents as a new numpy array (contrast
// with tensor(), which returns an aliasing view). String tensors become an
// object array of bytes; all other types are memcpy'd into a buffer that
// numpy owns. Returns nullptr with a Python error set on failure.
PyObject* InterpreterWrapper::GetTensor(int i) const {
  // Sanity check accessor
  TfLiteTensor* tensor = nullptr;
  int type_num = 0;

  PyObject* check_result =
      CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
  if (check_result == nullptr) return check_result;
  // Success sentinel is Py_None (a new reference) — drop it.
  Py_XDECREF(check_result);

  std::vector<npy_intp> dims(tensor->dims->data,
                             tensor->dims->data + tensor->dims->size);
  if (tensor->type != kTfLiteString) {
    // Make a buffer copy but we must tell Numpy It owns that data or else
    // it will leak.
    void* data = malloc(tensor->bytes);
    if (!data) {
      PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
      return nullptr;
    }
    memcpy(data, tensor->data.raw, tensor->bytes);
    PyObject* np_array =
        PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
    // OWNDATA hands the malloc'd buffer to numpy for deallocation.
    PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
                        NPY_ARRAY_OWNDATA);
    return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
  } else {
    // Create a C-order array so the data is contiguous in memory.
    const int32_t kCOrder = 0;
    PyObject* py_object =
        PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder);

    if (py_object == nullptr) {
      PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray.");
      return nullptr;
    }

    PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
    PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
    // Unpack each string from TFLite's packed string-tensor format into a
    // Python bytes object.
    auto num_strings = GetStringCount(tensor->data.raw);
    for (int j = 0; j < num_strings; ++j) {
      auto ref = GetString(tensor->data.raw, j);

      PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
      if (bytes == nullptr) {
        Py_DECREF(py_object);
        PyErr_Format(PyExc_ValueError,
                     "Could not create PyBytes from string %d of input %d.", j,
                     i);
        return nullptr;
      }
      // PyArray_EMPTY produces an array full of Py_None, which we must decref.
      Py_DECREF(data[j]);
      data[j] = bytes;
    }
    return py_object;
  }
}
387
// Returns a numpy array that ALIASES tensor `i`'s underlying buffer (no
// copy). `base_object` is installed as the array's base object so the
// Python-level owner is kept alive as long as the view exists; callers are
// expected to pass the wrapper that owns this interpreter.
PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
  // Sanity check accessor
  TfLiteTensor* tensor = nullptr;
  int type_num = 0;

  PyObject* check_result =
      CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
  if (check_result == nullptr) return check_result;
  // Success sentinel is Py_None (a new reference) — drop it.
  Py_XDECREF(check_result);

  std::vector<npy_intp> dims(tensor->dims->data,
                             tensor->dims->data + tensor->dims->size);
  // NOTE(review): np_array is not null-checked before PyArray_SetBaseObject;
  // presumably allocation failure is treated as unrecoverable here — confirm.
  PyArrayObject* np_array =
      reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(
          dims.size(), dims.data(), type_num, tensor->data.raw));
  Py_INCREF(base_object);  // SetBaseObject steals, so we need to add.
  PyArray_SetBaseObject(np_array, base_object);
  return PyArray_Return(np_array);
}
407
CreateWrapperCPPFromFile(const char * model_path,std::string * error_msg)408 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
409 const char* model_path, std::string* error_msg) {
410 std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
411 std::unique_ptr<tflite::FlatBufferModel> model =
412 tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get());
413 return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
414 error_msg);
415 }
416
CreateWrapperCPPFromBuffer(PyObject * data,std::string * error_msg)417 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
418 PyObject* data, std::string* error_msg) {
419 char * buf = nullptr;
420 Py_ssize_t length;
421 std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
422
423 if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
424 return nullptr;
425 }
426 std::unique_ptr<tflite::FlatBufferModel> model =
427 tflite::FlatBufferModel::BuildFromBuffer(buf, length,
428 error_reporter.get());
429 return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
430 error_msg);
431 }
432
// Python binding for Interpreter::ResetVariableTensors(). Returns None on
// success; on failure returns nullptr with the reporter's exception raised.
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}
438
439 } // namespace interpreter_wrapper
440 } // namespace tflite
441