1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
16 
17 #include <sstream>
18 #include <string>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/lite/interpreter.h"
22 #include "tensorflow/lite/kernels/register.h"
23 #include "tensorflow/lite/model.h"
24 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
25 #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
26 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
27 #include "tensorflow/lite/string_util.h"
28 
// Evaluates `x` (an expression yielding a TfLiteStatus). If it is not
// kTfLiteOk, returns error_reporter_->exception() from the enclosing function.
// Requires `error_reporter_` to be in scope.
// The do/while(0) wrapper makes the macro a single statement, so it is safe
// inside an un-braced if/else and requires a trailing semicolon.
#define TFLITE_PY_CHECK(x)                 \
  do {                                     \
    if ((x) != kTfLiteOk) {                \
      return error_reporter_->exception(); \
    }                                      \
  } while (0)
33 
// Returns nullptr from the enclosing function, with a Python ValueError set,
// when tensor index `i` is outside [0, interpreter_->tensors_size()).
// Requires `interpreter_` to be in scope and non-null.
// `i` is evaluated exactly once. Negatives are rejected first, so the
// unsigned comparison is well-defined. `%zu` matches the size_t returned by
// tensors_size(); `%lu` is the wrong width where `long` is 32 bits.
#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i)                                   \
  do {                                                                     \
    const int tflite_py_tensor_index = (i);                                \
    if (tflite_py_tensor_index < 0 ||                                      \
        static_cast<size_t>(tflite_py_tensor_index) >=                     \
            interpreter_->tensors_size()) {                                \
      PyErr_Format(PyExc_ValueError,                                       \
                   "Invalid tensor index %d exceeds max tensor index %zu", \
                   tflite_py_tensor_index, interpreter_->tensors_size());  \
      return nullptr;                                                      \
    }                                                                      \
  } while (0)
41 
// Returns nullptr from the enclosing function, with a Python ValueError set,
// when `interpreter_` is null. The do/while(0) wrapper makes the macro a
// single statement, so it is safe inside an un-braced if/else.
#define TFLITE_PY_ENSURE_VALID_INTERPRETER()                                 \
  do {                                                                       \
    if (!interpreter_) {                                                     \
      PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
      return nullptr;                                                        \
    }                                                                        \
  } while (0)
47 
48 namespace tflite {
49 namespace interpreter_wrapper {
50 
51 namespace {
52 
53 using python_utils::PyDecrefDeleter;
54 
CreateInterpreter(const tflite::FlatBufferModel * model,const tflite::ops::builtin::BuiltinOpResolver & resolver)55 std::unique_ptr<tflite::Interpreter> CreateInterpreter(
56     const tflite::FlatBufferModel* model,
57     const tflite::ops::builtin::BuiltinOpResolver& resolver) {
58   if (!model) {
59     return nullptr;
60   }
61 
62   ::tflite::python::ImportNumpy();
63 
64   std::unique_ptr<tflite::Interpreter> interpreter;
65   if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
66     return nullptr;
67   }
68   return interpreter;
69 }
70 
PyArrayFromIntVector(const int * data,npy_intp size)71 PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
72   void* pydata = malloc(size * sizeof(int));
73   memcpy(pydata, data, size * sizeof(int));
74   return PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata);
75 }
76 
PyTupleFromQuantizationParam(const TfLiteQuantizationParams & param)77 PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
78   PyObject* result = PyTuple_New(2);
79   PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale));
80   PyTuple_SET_ITEM(result, 1, PyLong_FromLong(param.zero_point));
81   return result;
82 }
83 
84 }  // namespace
85 
CreateInterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model,std::unique_ptr<PythonErrorReporter> error_reporter,std::string * error_msg)86 InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
87     std::unique_ptr<tflite::FlatBufferModel> model,
88     std::unique_ptr<PythonErrorReporter> error_reporter,
89     std::string* error_msg) {
90   if (!model) {
91     *error_msg = error_reporter->message();
92     return nullptr;
93   }
94 
95   auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
96   auto interpreter = CreateInterpreter(model.get(), *resolver);
97   if (!interpreter) {
98     *error_msg = error_reporter->message();
99     return nullptr;
100   }
101 
102   InterpreterWrapper* wrapper =
103       new InterpreterWrapper(std::move(model), std::move(error_reporter),
104                              std::move(resolver), std::move(interpreter));
105   return wrapper;
106 }
107 
InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model,std::unique_ptr<PythonErrorReporter> error_reporter,std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,std::unique_ptr<tflite::Interpreter> interpreter)108 InterpreterWrapper::InterpreterWrapper(
109     std::unique_ptr<tflite::FlatBufferModel> model,
110     std::unique_ptr<PythonErrorReporter> error_reporter,
111     std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
112     std::unique_ptr<tflite::Interpreter> interpreter)
113     : model_(std::move(model)),
114       error_reporter_(std::move(error_reporter)),
115       resolver_(std::move(resolver)),
116       interpreter_(std::move(interpreter)) {}
117 
~InterpreterWrapper()118 InterpreterWrapper::~InterpreterWrapper() {}
119 
AllocateTensors()120 PyObject* InterpreterWrapper::AllocateTensors() {
121   TFLITE_PY_ENSURE_VALID_INTERPRETER();
122   TFLITE_PY_CHECK(interpreter_->AllocateTensors());
123   Py_RETURN_NONE;
124 }
125 
Invoke()126 PyObject* InterpreterWrapper::Invoke() {
127   TFLITE_PY_ENSURE_VALID_INTERPRETER();
128   TFLITE_PY_CHECK(interpreter_->Invoke());
129   Py_RETURN_NONE;
130 }
131 
InputIndices() const132 PyObject* InterpreterWrapper::InputIndices() const {
133   TFLITE_PY_ENSURE_VALID_INTERPRETER();
134   PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
135                                             interpreter_->inputs().size());
136 
137   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
138 }
139 
OutputIndices() const140 PyObject* InterpreterWrapper::OutputIndices() const {
141   PyObject* np_array = PyArrayFromIntVector(interpreter_->outputs().data(),
142                                             interpreter_->outputs().size());
143 
144   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
145 }
146 
ResizeInputTensor(int i,PyObject * value)147 PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
148   TFLITE_PY_ENSURE_VALID_INTERPRETER();
149 
150   std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
151       PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
152   if (!array_safe) {
153     PyErr_SetString(PyExc_ValueError,
154                     "Failed to convert numpy value into readable tensor.");
155     return nullptr;
156   }
157 
158   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
159 
160   if (PyArray_NDIM(array) != 1) {
161     PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
162                  PyArray_NDIM(array));
163     return nullptr;
164   }
165 
166   if (PyArray_TYPE(array) != NPY_INT32) {
167     PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
168                  PyArray_TYPE(array));
169     return nullptr;
170   }
171 
172   std::vector<int> dims(PyArray_SHAPE(array)[0]);
173   memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
174 
175   TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
176   Py_RETURN_NONE;
177 }
178 
NumTensors() const179 int InterpreterWrapper::NumTensors() const {
180   if (!interpreter_) {
181     return 0;
182   }
183   return interpreter_->tensors_size();
184 }
185 
TensorName(int i) const186 std::string InterpreterWrapper::TensorName(int i) const {
187   if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
188     return "";
189   }
190 
191   const TfLiteTensor* tensor = interpreter_->tensor(i);
192   return tensor->name ? tensor->name : "";
193 }
194 
TensorType(int i) const195 PyObject* InterpreterWrapper::TensorType(int i) const {
196   TFLITE_PY_ENSURE_VALID_INTERPRETER();
197   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
198 
199   const TfLiteTensor* tensor = interpreter_->tensor(i);
200   if (tensor->type == kTfLiteNoType) {
201     PyErr_Format(PyExc_ValueError, "Tensor with no type found.");
202     return nullptr;
203   }
204 
205   int code = python_utils::TfLiteTypeToPyArrayType(tensor->type);
206   if (code == -1) {
207     PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
208     return nullptr;
209   }
210   return PyArray_TypeObjectFromType(code);
211 }
212 
TensorSize(int i) const213 PyObject* InterpreterWrapper::TensorSize(int i) const {
214   TFLITE_PY_ENSURE_VALID_INTERPRETER();
215   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
216 
217   const TfLiteTensor* tensor = interpreter_->tensor(i);
218   if (tensor->dims == nullptr) {
219     PyErr_Format(PyExc_ValueError, "Tensor with no shape found.");
220     return nullptr;
221   }
222   PyObject* np_array =
223       PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
224 
225   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
226 }
227 
TensorQuantization(int i) const228 PyObject* InterpreterWrapper::TensorQuantization(int i) const {
229   TFLITE_PY_ENSURE_VALID_INTERPRETER();
230   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
231   const TfLiteTensor* tensor = interpreter_->tensor(i);
232   return PyTupleFromQuantizationParam(tensor->params);
233 }
234 
SetTensor(int i,PyObject * value)235 PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
236   TFLITE_PY_ENSURE_VALID_INTERPRETER();
237   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
238 
239   std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
240       PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
241   if (!array_safe) {
242     PyErr_SetString(PyExc_ValueError,
243                     "Failed to convert value into readable tensor.");
244     return nullptr;
245   }
246 
247   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
248   TfLiteTensor* tensor = interpreter_->tensor(i);
249 
250   if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
251     PyErr_Format(PyExc_ValueError,
252                  "Cannot set tensor:"
253                  " Got tensor of type %d"
254                  " but expected type %d for input %d ",
255                  python_utils::TfLiteTypeFromPyArray(array), tensor->type, i);
256     return nullptr;
257   }
258 
259   if (PyArray_NDIM(array) != tensor->dims->size) {
260     PyErr_Format(PyExc_ValueError,
261                  "Cannot set tensor: Dimension mismatch."
262                  " Got %d"
263                  " but expected %d for input %d.",
264                  PyArray_NDIM(array), tensor->dims->size, i);
265     return nullptr;
266   }
267 
268   for (int j = 0; j < PyArray_NDIM(array); j++) {
269     if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
270       PyErr_Format(PyExc_ValueError,
271                    "Cannot set tensor: Dimension mismatch."
272                    " Got %ld"
273                    " but expected %d for dimension %d of input %d.",
274                    PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i);
275       return nullptr;
276     }
277   }
278 
279   if (tensor->type != kTfLiteString) {
280     size_t size = PyArray_NBYTES(array);
281     if (size != tensor->bytes) {
282       PyErr_Format(PyExc_ValueError,
283                    "numpy array had %zu bytes but expected %zu bytes.", size,
284                    tensor->bytes);
285       return nullptr;
286     }
287     memcpy(tensor->data.raw, PyArray_DATA(array), size);
288   } else {
289     DynamicBuffer dynamic_buffer;
290     if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
291       return nullptr;
292     }
293     dynamic_buffer.WriteToTensor(tensor, nullptr);
294   }
295   Py_RETURN_NONE;
296 }
297 
298 namespace {
299 
300 // Checks to see if a tensor access can succeed (returns nullptr on error).
301 // Otherwise returns Py_None.
CheckGetTensorArgs(Interpreter * interpreter_,int tensor_index,TfLiteTensor ** tensor,int * type_num)302 PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
303                              TfLiteTensor** tensor, int* type_num) {
304   TFLITE_PY_ENSURE_VALID_INTERPRETER();
305   TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index);
306 
307   *tensor = interpreter_->tensor(tensor_index);
308   if ((*tensor)->bytes == 0) {
309     PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
310     return nullptr;
311   }
312 
313   *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
314   if (*type_num == -1) {
315     PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
316     return nullptr;
317   }
318 
319   if (!(*tensor)->data.raw) {
320     PyErr_SetString(PyExc_ValueError, "Tensor data is null.");
321     return nullptr;
322   }
323 
324   Py_RETURN_NONE;
325 }
326 
327 }  // namespace
328 
GetTensor(int i) const329 PyObject* InterpreterWrapper::GetTensor(int i) const {
330   // Sanity check accessor
331   TfLiteTensor* tensor = nullptr;
332   int type_num = 0;
333 
334   PyObject* check_result =
335       CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
336   if (check_result == nullptr) return check_result;
337   Py_XDECREF(check_result);
338 
339   std::vector<npy_intp> dims(tensor->dims->data,
340                              tensor->dims->data + tensor->dims->size);
341   if (tensor->type != kTfLiteString) {
342     // Make a buffer copy but we must tell Numpy It owns that data or else
343     // it will leak.
344     void* data = malloc(tensor->bytes);
345     if (!data) {
346       PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
347       return nullptr;
348     }
349     memcpy(data, tensor->data.raw, tensor->bytes);
350     PyObject* np_array =
351         PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
352     PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
353                         NPY_ARRAY_OWNDATA);
354     return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
355   } else {
356     // Create a C-order array so the data is contiguous in memory.
357     const int32_t kCOrder = 0;
358     PyObject* py_object =
359         PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder);
360 
361     if (py_object == nullptr) {
362       PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray.");
363       return nullptr;
364     }
365 
366     PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
367     PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
368     auto num_strings = GetStringCount(tensor->data.raw);
369     for (int j = 0; j < num_strings; ++j) {
370       auto ref = GetString(tensor->data.raw, j);
371 
372       PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
373       if (bytes == nullptr) {
374         Py_DECREF(py_object);
375         PyErr_Format(PyExc_ValueError,
376                      "Could not create PyBytes from string %d of input %d.", j,
377                      i);
378         return nullptr;
379       }
380       // PyArray_EMPTY produces an array full of Py_None, which we must decref.
381       Py_DECREF(data[j]);
382       data[j] = bytes;
383     }
384     return py_object;
385   }
386 }
387 
tensor(PyObject * base_object,int i)388 PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
389   // Sanity check accessor
390   TfLiteTensor* tensor = nullptr;
391   int type_num = 0;
392 
393   PyObject* check_result =
394       CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
395   if (check_result == nullptr) return check_result;
396   Py_XDECREF(check_result);
397 
398   std::vector<npy_intp> dims(tensor->dims->data,
399                              tensor->dims->data + tensor->dims->size);
400   PyArrayObject* np_array =
401       reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(
402           dims.size(), dims.data(), type_num, tensor->data.raw));
403   Py_INCREF(base_object);  // SetBaseObject steals, so we need to add.
404   PyArray_SetBaseObject(np_array, base_object);
405   return PyArray_Return(np_array);
406 }
407 
CreateWrapperCPPFromFile(const char * model_path,std::string * error_msg)408 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
409     const char* model_path, std::string* error_msg) {
410   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
411   std::unique_ptr<tflite::FlatBufferModel> model =
412       tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get());
413   return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
414                                   error_msg);
415 }
416 
CreateWrapperCPPFromBuffer(PyObject * data,std::string * error_msg)417 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
418     PyObject* data, std::string* error_msg) {
419   char * buf = nullptr;
420   Py_ssize_t length;
421   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
422 
423   if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
424     return nullptr;
425   }
426   std::unique_ptr<tflite::FlatBufferModel> model =
427       tflite::FlatBufferModel::BuildFromBuffer(buf, length,
428                                                error_reporter.get());
429   return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
430                                   error_msg);
431 }
432 
ResetVariableTensors()433 PyObject* InterpreterWrapper::ResetVariableTensors() {
434   TFLITE_PY_ENSURE_VALID_INTERPRETER();
435   TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
436   Py_RETURN_NONE;
437 }
438 
439 }  // namespace interpreter_wrapper
440 }  // namespace tflite
441