1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
16 
17 #include <stdarg.h>
18 
19 #include <functional>
20 #include <sstream>
21 #include <string>
22 
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/lite/c/common.h"
26 #include "tensorflow/lite/core/api/error_reporter.h"
27 #include "tensorflow/lite/interpreter.h"
28 #include "tensorflow/lite/kernels/register.h"
29 #include "tensorflow/lite/model.h"
30 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
31 #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
32 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
33 #include "tensorflow/lite/shared_library.h"
34 #include "tensorflow/lite/string_util.h"
35 #include "tensorflow/lite/util.h"
36 
// Evaluates the TfLiteStatus expression `x`; if it is not kTfLiteOk, returns
// error_reporter_->exception() from the enclosing function. Requires an
// `error_reporter_` in scope. The do/while(0) wrapper makes the macro a single
// statement so it composes safely with unbraced if/else.
#define TFLITE_PY_CHECK(x)                 \
  do {                                     \
    if ((x) != kTfLiteOk) {                \
      return error_reporter_->exception(); \
    }                                      \
  } while (0)
41 
// Returns nullptr from the enclosing function, with a ValueError set, unless
// 0 <= i < interpreter_->tensors_size(). Requires an `interpreter_` in scope.
// The do/while(0) wrapper makes the macro a single statement, the argument is
// parenthesized, and %zu matches the size_t returned by tensors_size().
#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i)                                   \
  do {                                                                     \
    if ((i) >= interpreter_->tensors_size() || (i) < 0) {                  \
      PyErr_Format(PyExc_ValueError,                                       \
                   "Invalid tensor index %d exceeds max tensor index %zu", \
                   (i), interpreter_->tensors_size());                     \
      return nullptr;                                                      \
    }                                                                      \
  } while (0)
49 
// Returns nullptr from the enclosing function, with a ValueError set, unless
// 0 <= i < interpreter_->nodes_size(). Requires an `interpreter_` in scope.
// The do/while(0) wrapper makes the macro a single statement; the message has
// no format arguments, so PyErr_SetString is used.
#define TFLITE_PY_NODES_BOUNDS_CHECK(i)                        \
  do {                                                         \
    if ((i) >= interpreter_->nodes_size() || (i) < 0) {        \
      PyErr_SetString(PyExc_ValueError, "Invalid node index"); \
      return nullptr;                                          \
    }                                                          \
  } while (0)
55 
// Returns nullptr from the enclosing function, with a ValueError set, if
// `interpreter_` (which must be in scope) is null. The do/while(0) wrapper
// makes the macro a single statement so it composes safely with if/else.
#define TFLITE_PY_ENSURE_VALID_INTERPRETER()                                 \
  do {                                                                       \
    if (!interpreter_) {                                                     \
      PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
      return nullptr;                                                        \
    }                                                                        \
  } while (0)
61 
62 namespace tflite {
63 namespace interpreter_wrapper {
64 
65 namespace {
66 
67 using python_utils::PyDecrefDeleter;
68 
CreateInterpreter(const InterpreterWrapper::Model * model,const tflite::ops::builtin::BuiltinOpResolver & resolver)69 std::unique_ptr<Interpreter> CreateInterpreter(
70     const InterpreterWrapper::Model* model,
71     const tflite::ops::builtin::BuiltinOpResolver& resolver) {
72   if (!model) {
73     return nullptr;
74   }
75 
76   ::tflite::python::ImportNumpy();
77 
78   std::unique_ptr<Interpreter> interpreter;
79   if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
80     return nullptr;
81   }
82   return interpreter;
83 }
84 
PyArrayFromFloatVector(const float * data,npy_intp size)85 PyObject* PyArrayFromFloatVector(const float* data, npy_intp size) {
86   void* pydata = malloc(size * sizeof(float));
87   memcpy(pydata, data, size * sizeof(float));
88   PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_FLOAT32, pydata);
89   PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
90   return obj;
91 }
92 
PyArrayFromIntVector(const int * data,npy_intp size)93 PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
94   void* pydata = malloc(size * sizeof(int));
95   memcpy(pydata, data, size * sizeof(int));
96   PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata);
97   PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
98   return obj;
99 }
100 
PyTupleFromQuantizationParam(const TfLiteQuantizationParams & param)101 PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
102   PyObject* result = PyTuple_New(2);
103   PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale));
104   PyTuple_SET_ITEM(result, 1, PyLong_FromLong(param.zero_point));
105   return result;
106 }
107 
PyDictFromSparsityParam(const TfLiteSparsity & param)108 PyObject* PyDictFromSparsityParam(const TfLiteSparsity& param) {
109   PyObject* result = PyDict_New();
110   PyDict_SetItemString(result, "traversal_order",
111                        PyArrayFromIntVector(param.traversal_order->data,
112                                             param.traversal_order->size));
113   PyDict_SetItemString(
114       result, "block_map",
115       PyArrayFromIntVector(param.block_map->data, param.block_map->size));
116   PyObject* dim_metadata = PyList_New(param.dim_metadata_size);
117   for (int i = 0; i < param.dim_metadata_size; i++) {
118     PyObject* dim_metadata_i = PyDict_New();
119     if (param.dim_metadata[i].format == kTfLiteDimDense) {
120       PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(0));
121       PyDict_SetItemString(dim_metadata_i, "dense_size",
122                            PyLong_FromSize_t(param.dim_metadata[i].dense_size));
123     } else {
124       PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(1));
125       const auto* array_segments = param.dim_metadata[i].array_segments;
126       const auto* array_indices = param.dim_metadata[i].array_indices;
127       PyDict_SetItemString(
128           dim_metadata_i, "array_segments",
129           PyArrayFromIntVector(array_segments->data, array_segments->size));
130       PyDict_SetItemString(
131           dim_metadata_i, "array_indices",
132           PyArrayFromIntVector(array_indices->data, array_indices->size));
133     }
134     PyList_SetItem(dim_metadata, i, dim_metadata_i);
135   }
136   PyDict_SetItemString(result, "dim_metadata", dim_metadata);
137   return result;
138 }
139 
RegisterCustomOpByName(const char * registerer_name,tflite::MutableOpResolver * resolver,std::string * error_msg)140 bool RegisterCustomOpByName(const char* registerer_name,
141                             tflite::MutableOpResolver* resolver,
142                             std::string* error_msg) {
143   // Registerer functions take a pointer to a BuiltinOpResolver as an input
144   // parameter and return void.
145   // TODO(b/137576229): We should implement this functionality in a more
146   // principled way.
147   typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*);
148 
149   // Look for the Registerer function by name.
150   RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
151       SharedLibrary::GetSymbol(registerer_name));
152 
153   // Fail in an informative way if the function was not found.
154   if (registerer == nullptr) {
155     *error_msg =
156         absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
157                         registerer_name, SharedLibrary::GetError());
158     return false;
159   }
160 
161   // Call the registerer with the resolver.
162   registerer(resolver);
163   return true;
164 }
165 
166 }  // namespace
167 
CreateInterpreterWrapper(std::unique_ptr<InterpreterWrapper::Model> model,std::unique_ptr<PythonErrorReporter> error_reporter,const std::vector<std::string> & registerers_by_name,const std::vector<std::function<void (uintptr_t)>> & registerers_by_func,std::string * error_msg)168 InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
169     std::unique_ptr<InterpreterWrapper::Model> model,
170     std::unique_ptr<PythonErrorReporter> error_reporter,
171     const std::vector<std::string>& registerers_by_name,
172     const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
173     std::string* error_msg) {
174   if (!model) {
175     *error_msg = error_reporter->message();
176     return nullptr;
177   }
178 
179   auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
180   for (const auto& registerer : registerers_by_name) {
181     if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg))
182       return nullptr;
183   }
184   for (const auto& registerer : registerers_by_func) {
185     registerer(reinterpret_cast<uintptr_t>(resolver.get()));
186   }
187   auto interpreter = CreateInterpreter(model.get(), *resolver);
188   if (!interpreter) {
189     *error_msg = error_reporter->message();
190     return nullptr;
191   }
192 
193   InterpreterWrapper* wrapper =
194       new InterpreterWrapper(std::move(model), std::move(error_reporter),
195                              std::move(resolver), std::move(interpreter));
196   return wrapper;
197 }
198 
InterpreterWrapper(std::unique_ptr<InterpreterWrapper::Model> model,std::unique_ptr<PythonErrorReporter> error_reporter,std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,std::unique_ptr<Interpreter> interpreter)199 InterpreterWrapper::InterpreterWrapper(
200     std::unique_ptr<InterpreterWrapper::Model> model,
201     std::unique_ptr<PythonErrorReporter> error_reporter,
202     std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
203     std::unique_ptr<Interpreter> interpreter)
204     : model_(std::move(model)),
205       error_reporter_(std::move(error_reporter)),
206       resolver_(std::move(resolver)),
207       interpreter_(std::move(interpreter)) {}
208 
~InterpreterWrapper()209 InterpreterWrapper::~InterpreterWrapper() {}
210 
AllocateTensors()211 PyObject* InterpreterWrapper::AllocateTensors() {
212   TFLITE_PY_ENSURE_VALID_INTERPRETER();
213   TFLITE_PY_CHECK(interpreter_->AllocateTensors());
214   Py_RETURN_NONE;
215 }
216 
Invoke()217 PyObject* InterpreterWrapper::Invoke() {
218   TFLITE_PY_ENSURE_VALID_INTERPRETER();
219 
220   // Release the GIL so that we can run multiple interpreters in parallel
221   TfLiteStatus status_code = kTfLiteOk;
222   Py_BEGIN_ALLOW_THREADS;  // To return can happen between this and end!
223   status_code = interpreter_->Invoke();
224   Py_END_ALLOW_THREADS;
225 
226   TFLITE_PY_CHECK(
227       status_code);  // don't move this into the Py_BEGIN/Py_End block
228 
229   Py_RETURN_NONE;
230 }
231 
InputIndices() const232 PyObject* InterpreterWrapper::InputIndices() const {
233   TFLITE_PY_ENSURE_VALID_INTERPRETER();
234   PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
235                                             interpreter_->inputs().size());
236 
237   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
238 }
239 
OutputIndices() const240 PyObject* InterpreterWrapper::OutputIndices() const {
241   PyObject* np_array = PyArrayFromIntVector(interpreter_->outputs().data(),
242                                             interpreter_->outputs().size());
243 
244   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
245 }
246 
ResizeInputTensorImpl(int i,PyObject * value)247 PyObject* InterpreterWrapper::ResizeInputTensorImpl(int i, PyObject* value) {
248   TFLITE_PY_ENSURE_VALID_INTERPRETER();
249 
250   std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
251       PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
252   if (!array_safe) {
253     PyErr_SetString(PyExc_ValueError,
254                     "Failed to convert numpy value into readable tensor.");
255     return nullptr;
256   }
257 
258   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
259 
260   if (PyArray_NDIM(array) != 1) {
261     PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
262                  PyArray_NDIM(array));
263     return nullptr;
264   }
265 
266   if (PyArray_TYPE(array) != NPY_INT32) {
267     PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
268                  PyArray_TYPE(array));
269     return nullptr;
270   }
271 
272   PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(array),
273                       NPY_ARRAY_OWNDATA);
274   return PyArray_Return(reinterpret_cast<PyArrayObject*>(array));
275 }
276 
ResizeInputTensor(int i,PyObject * value,bool strict)277 PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value,
278                                                 bool strict) {
279   PyArrayObject* array =
280       reinterpret_cast<PyArrayObject*>(ResizeInputTensorImpl(i, value));
281   if (array == nullptr) {
282     return nullptr;
283   }
284 
285   std::vector<int> dims(PyArray_SHAPE(array)[0]);
286   memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
287 
288   if (strict) {
289     TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(i, dims));
290   } else {
291     TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
292   }
293   Py_RETURN_NONE;
294 }
295 
NumTensors() const296 int InterpreterWrapper::NumTensors() const {
297   if (!interpreter_) {
298     return 0;
299   }
300   return interpreter_->tensors_size();
301 }
302 
TensorName(int i) const303 std::string InterpreterWrapper::TensorName(int i) const {
304   if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
305     return "";
306   }
307 
308   const TfLiteTensor* tensor = interpreter_->tensor(i);
309   return tensor->name ? tensor->name : "";
310 }
311 
TensorType(int i) const312 PyObject* InterpreterWrapper::TensorType(int i) const {
313   TFLITE_PY_ENSURE_VALID_INTERPRETER();
314   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
315 
316   const TfLiteTensor* tensor = interpreter_->tensor(i);
317   if (tensor->type == kTfLiteNoType) {
318     PyErr_Format(PyExc_ValueError, "Tensor with no type found.");
319     return nullptr;
320   }
321 
322   int code = python_utils::TfLiteTypeToPyArrayType(tensor->type);
323   if (code == -1) {
324     PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
325     return nullptr;
326   }
327   return PyArray_TypeObjectFromType(code);
328 }
329 
TensorSize(int i) const330 PyObject* InterpreterWrapper::TensorSize(int i) const {
331   TFLITE_PY_ENSURE_VALID_INTERPRETER();
332   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
333 
334   const TfLiteTensor* tensor = interpreter_->tensor(i);
335   if (tensor->dims == nullptr) {
336     PyErr_Format(PyExc_ValueError, "Tensor with no shape found.");
337     return nullptr;
338   }
339   PyObject* np_array =
340       PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
341 
342   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
343 }
344 
TensorSizeSignature(int i) const345 PyObject* InterpreterWrapper::TensorSizeSignature(int i) const {
346   TFLITE_PY_ENSURE_VALID_INTERPRETER();
347   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
348 
349   const TfLiteTensor* tensor = interpreter_->tensor(i);
350   const int32_t* size_signature_data = nullptr;
351   int32_t size_signature_size = 0;
352   if (tensor->dims_signature != nullptr && tensor->dims_signature->size != 0) {
353     size_signature_data = tensor->dims_signature->data;
354     size_signature_size = tensor->dims_signature->size;
355   } else {
356     size_signature_data = tensor->dims->data;
357     size_signature_size = tensor->dims->size;
358   }
359   PyObject* np_array =
360       PyArrayFromIntVector(size_signature_data, size_signature_size);
361 
362   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
363 }
364 
TensorSparsityParameters(int i) const365 PyObject* InterpreterWrapper::TensorSparsityParameters(int i) const {
366   TFLITE_PY_ENSURE_VALID_INTERPRETER();
367   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
368   const TfLiteTensor* tensor = interpreter_->tensor(i);
369   if (tensor->sparsity == nullptr) {
370     return PyDict_New();
371   }
372 
373   return PyDictFromSparsityParam(*tensor->sparsity);
374 }
375 
TensorQuantization(int i) const376 PyObject* InterpreterWrapper::TensorQuantization(int i) const {
377   TFLITE_PY_ENSURE_VALID_INTERPRETER();
378   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
379   const TfLiteTensor* tensor = interpreter_->tensor(i);
380   return PyTupleFromQuantizationParam(tensor->params);
381 }
382 
TensorQuantizationParameters(int i) const383 PyObject* InterpreterWrapper::TensorQuantizationParameters(int i) const {
384   TFLITE_PY_ENSURE_VALID_INTERPRETER();
385   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
386   const TfLiteTensor* tensor = interpreter_->tensor(i);
387   const TfLiteQuantization quantization = tensor->quantization;
388   float* scales_data = nullptr;
389   int32_t* zero_points_data = nullptr;
390   int32_t scales_size = 0;
391   int32_t zero_points_size = 0;
392   int32_t quantized_dimension = 0;
393   if (quantization.type == kTfLiteAffineQuantization) {
394     const TfLiteAffineQuantization* q_params =
395         reinterpret_cast<const TfLiteAffineQuantization*>(quantization.params);
396     if (q_params->scale) {
397       scales_data = q_params->scale->data;
398       scales_size = q_params->scale->size;
399     }
400     if (q_params->zero_point) {
401       zero_points_data = q_params->zero_point->data;
402       zero_points_size = q_params->zero_point->size;
403     }
404     quantized_dimension = q_params->quantized_dimension;
405   }
406   PyObject* scales_array = PyArrayFromFloatVector(scales_data, scales_size);
407   PyObject* zero_points_array =
408       PyArrayFromIntVector(zero_points_data, zero_points_size);
409 
410   PyObject* result = PyTuple_New(3);
411   PyTuple_SET_ITEM(result, 0, scales_array);
412   PyTuple_SET_ITEM(result, 1, zero_points_array);
413   PyTuple_SET_ITEM(result, 2, PyLong_FromLong(quantized_dimension));
414   return result;
415 }
416 
SetTensor(int i,PyObject * value)417 PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
418   TFLITE_PY_ENSURE_VALID_INTERPRETER();
419   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
420 
421   std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
422       PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
423   if (!array_safe) {
424     PyErr_SetString(PyExc_ValueError,
425                     "Failed to convert value into readable tensor.");
426     return nullptr;
427   }
428 
429   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
430   TfLiteTensor* tensor = interpreter_->tensor(i);
431 
432   if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
433     PyErr_Format(PyExc_ValueError,
434                  "Cannot set tensor:"
435                  " Got value of type %s"
436                  " but expected type %s for input %d, name: %s ",
437                  TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
438                  TfLiteTypeGetName(tensor->type), i, tensor->name);
439     return nullptr;
440   }
441 
442   if (PyArray_NDIM(array) != tensor->dims->size) {
443     PyErr_Format(PyExc_ValueError,
444                  "Cannot set tensor: Dimension mismatch."
445                  " Got %d"
446                  " but expected %d for input %d.",
447                  PyArray_NDIM(array), tensor->dims->size, i);
448     return nullptr;
449   }
450 
451   for (int j = 0; j < PyArray_NDIM(array); j++) {
452     if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
453       PyErr_Format(PyExc_ValueError,
454                    "Cannot set tensor: Dimension mismatch."
455                    " Got %ld"
456                    " but expected %d for dimension %d of input %d.",
457                    PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i);
458       return nullptr;
459     }
460   }
461 
462   if (tensor->type != kTfLiteString) {
463     if (tensor->data.raw == nullptr) {
464       PyErr_Format(PyExc_ValueError,
465                    "Cannot set tensor:"
466                    " Tensor is unallocated. Try calling allocate_tensors()"
467                    " first");
468       return nullptr;
469     }
470 
471     size_t size = PyArray_NBYTES(array);
472     if (size != tensor->bytes) {
473       PyErr_Format(PyExc_ValueError,
474                    "numpy array had %zu bytes but expected %zu bytes.", size,
475                    tensor->bytes);
476       return nullptr;
477     }
478     memcpy(tensor->data.raw, PyArray_DATA(array), size);
479   } else {
480     DynamicBuffer dynamic_buffer;
481     if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
482       return nullptr;
483     }
484     dynamic_buffer.WriteToTensor(tensor, nullptr);
485   }
486   Py_RETURN_NONE;
487 }
488 
NumNodes() const489 int InterpreterWrapper::NumNodes() const {
490   if (!interpreter_) {
491     return 0;
492   }
493   return interpreter_->nodes_size();
494 }
495 
NodeInputs(int i) const496 PyObject* InterpreterWrapper::NodeInputs(int i) const {
497   TFLITE_PY_ENSURE_VALID_INTERPRETER();
498   TFLITE_PY_NODES_BOUNDS_CHECK(i);
499 
500   const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first);
501   PyObject* inputs =
502       PyArrayFromIntVector(node->inputs->data, node->inputs->size);
503   return inputs;
504 }
505 
NodeOutputs(int i) const506 PyObject* InterpreterWrapper::NodeOutputs(int i) const {
507   TFLITE_PY_ENSURE_VALID_INTERPRETER();
508   TFLITE_PY_NODES_BOUNDS_CHECK(i);
509 
510   const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first);
511   PyObject* outputs =
512       PyArrayFromIntVector(node->outputs->data, node->outputs->size);
513   return outputs;
514 }
515 
NodeName(int i) const516 std::string InterpreterWrapper::NodeName(int i) const {
517   if (!interpreter_ || i >= interpreter_->nodes_size() || i < 0) {
518     return "";
519   }
520   // Get op name from registration
521   const TfLiteRegistration* node_registration =
522       &(interpreter_->node_and_registration(i)->second);
523   int32_t op_code = node_registration->builtin_code;
524   std::string op_name;
525   if (op_code == tflite::BuiltinOperator_CUSTOM) {
526     const char* custom_name = node_registration->custom_name;
527     op_name = custom_name ? custom_name : "UnknownCustomOp";
528   } else {
529     op_name = tflite::EnumNamesBuiltinOperator()[op_code];
530   }
531   std::string op_name_str(op_name);
532   return op_name_str;
533 }
534 
535 namespace {
536 
537 // Checks to see if a tensor access can succeed (returns nullptr on error).
538 // Otherwise returns Py_None.
CheckGetTensorArgs(Interpreter * interpreter_,int tensor_index,TfLiteTensor ** tensor,int * type_num)539 PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
540                              TfLiteTensor** tensor, int* type_num) {
541   TFLITE_PY_ENSURE_VALID_INTERPRETER();
542   TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index);
543 
544   *tensor = interpreter_->tensor(tensor_index);
545   if ((*tensor)->bytes == 0) {
546     PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
547     return nullptr;
548   }
549 
550   *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
551   if (*type_num == -1) {
552     PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
553     return nullptr;
554   }
555 
556   if (!(*tensor)->data.raw) {
557     PyErr_SetString(PyExc_ValueError,
558                     "Tensor data is null."
559                     " Run allocate_tensors() first");
560     return nullptr;
561   }
562 
563   Py_RETURN_NONE;
564 }
565 
566 }  // namespace
567 
GetSignatureDefs() const568 PyObject* InterpreterWrapper::GetSignatureDefs() const {
569   PyObject* result = PyDict_New();
570   for (const auto& sig_def_name : interpreter_->signature_def_names()) {
571     PyObject* signature_def = PyDict_New();
572     PyObject* inputs = PyDict_New();
573     PyObject* outputs = PyDict_New();
574     const auto& signature_def_inputs =
575         interpreter_->signature_inputs(sig_def_name->c_str());
576     const auto& signature_def_outputs =
577         interpreter_->signature_outputs(sig_def_name->c_str());
578     for (const auto& input : signature_def_inputs) {
579       PyDict_SetItemString(inputs, input.first.c_str(),
580                            PyLong_FromLong(input.second));
581     }
582     for (const auto& output : signature_def_outputs) {
583       PyDict_SetItemString(outputs, output.first.c_str(),
584                            PyLong_FromLong(output.second));
585     }
586 
587     PyDict_SetItemString(signature_def, "inputs", inputs);
588     PyDict_SetItemString(signature_def, "outputs", outputs);
589     PyDict_SetItemString(result, sig_def_name->c_str(), signature_def);
590   }
591   return result;
592 }
593 
GetOutputTensorFromSignatureDefName(const char * output_name,const char * method_name) const594 PyObject* InterpreterWrapper::GetOutputTensorFromSignatureDefName(
595     const char* output_name, const char* method_name) const {
596   const auto& outputs = interpreter_->signature_outputs(method_name);
597   const auto& output = outputs.find(output_name);
598   if (output == outputs.end()) return nullptr;
599   return GetTensor(output->second);
600 }
601 
SetInputTensorFromSignatureDefName(const char * input_name,const char * method_name,PyObject * value)602 PyObject* InterpreterWrapper::SetInputTensorFromSignatureDefName(
603     const char* input_name, const char* method_name, PyObject* value) {
604   const auto& inputs = interpreter_->signature_inputs(method_name);
605   const auto& input = inputs.find(input_name);
606   if (input == inputs.end()) return nullptr;
607   return SetTensor(input->second, value);
608 }
609 
GetTensor(int i) const610 PyObject* InterpreterWrapper::GetTensor(int i) const {
611   // Sanity check accessor
612   TfLiteTensor* tensor = nullptr;
613   int type_num = 0;
614 
615   PyObject* check_result =
616       CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
617   if (check_result == nullptr) return check_result;
618   Py_XDECREF(check_result);
619 
620   std::vector<npy_intp> dims(tensor->dims->data,
621                              tensor->dims->data + tensor->dims->size);
622   if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
623       tensor->type != kTfLiteVariant) {
624     // Make a buffer copy but we must tell Numpy It owns that data or else
625     // it will leak.
626     void* data = malloc(tensor->bytes);
627     if (!data) {
628       PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
629       return nullptr;
630     }
631     memcpy(data, tensor->data.raw, tensor->bytes);
632     PyObject* np_array;
633     if (tensor->sparsity == nullptr) {
634       np_array =
635           PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
636     } else {
637       std::vector<npy_intp> sparse_buffer_dims(1);
638       size_t size_of_type;
639       if (GetSizeOfType(nullptr, tensor->type, &size_of_type) != kTfLiteOk) {
640         PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
641         free(data);
642         return nullptr;
643       }
644       sparse_buffer_dims[0] = tensor->bytes / size_of_type;
645       np_array = PyArray_SimpleNewFromData(
646           sparse_buffer_dims.size(), sparse_buffer_dims.data(), type_num, data);
647     }
648     PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
649                         NPY_ARRAY_OWNDATA);
650     return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
651   } else {
652     // Create a C-order array so the data is contiguous in memory.
653     const int32_t kCOrder = 0;
654     PyObject* py_object =
655         PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder);
656 
657     if (py_object == nullptr) {
658       PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray.");
659       return nullptr;
660     }
661 
662     PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
663     PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
664     auto num_strings = GetStringCount(tensor);
665     for (int j = 0; j < num_strings; ++j) {
666       auto ref = GetString(tensor, j);
667 
668       PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
669       if (bytes == nullptr) {
670         Py_DECREF(py_object);
671         PyErr_Format(PyExc_ValueError,
672                      "Could not create PyBytes from string %d of input %d.", j,
673                      i);
674         return nullptr;
675       }
676       // PyArray_EMPTY produces an array full of Py_None, which we must decref.
677       Py_DECREF(data[j]);
678       data[j] = bytes;
679     }
680     return py_object;
681   }
682 }
683 
tensor(PyObject * base_object,int i)684 PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
685   // Sanity check accessor
686   TfLiteTensor* tensor = nullptr;
687   int type_num = 0;
688 
689   PyObject* check_result =
690       CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
691   if (check_result == nullptr) return check_result;
692   Py_XDECREF(check_result);
693 
694   std::vector<npy_intp> dims(tensor->dims->data,
695                              tensor->dims->data + tensor->dims->size);
696   PyArrayObject* np_array =
697       reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(
698           dims.size(), dims.data(), type_num, tensor->data.raw));
699   Py_INCREF(base_object);  // SetBaseObject steals, so we need to add.
700   PyArray_SetBaseObject(np_array, base_object);
701   return PyArray_Return(np_array);
702 }
703 
CreateWrapperCPPFromFile(const char * model_path,const std::vector<std::string> & registerers_by_name,const std::vector<std::function<void (uintptr_t)>> & registerers_by_func,std::string * error_msg)704 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
705     const char* model_path, const std::vector<std::string>& registerers_by_name,
706     const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
707     std::string* error_msg) {
708   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
709   std::unique_ptr<InterpreterWrapper::Model> model =
710       Model::BuildFromFile(model_path, error_reporter.get());
711   return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
712                                   registerers_by_name, registerers_by_func,
713                                   error_msg);
714 }
715 
CreateWrapperCPPFromFile(const char * model_path,const std::vector<std::string> & registerers,std::string * error_msg)716 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
717     const char* model_path, const std::vector<std::string>& registerers,
718     std::string* error_msg) {
719   return CreateWrapperCPPFromFile(model_path, registerers, {}, error_msg);
720 }
721 
CreateWrapperCPPFromBuffer(PyObject * data,const std::vector<std::string> & registerers_by_name,const std::vector<std::function<void (uintptr_t)>> & registerers_by_func,std::string * error_msg)722 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
723     PyObject* data, const std::vector<std::string>& registerers_by_name,
724     const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
725     std::string* error_msg) {
726   char* buf = nullptr;
727   Py_ssize_t length;
728   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
729 
730   if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
731     return nullptr;
732   }
733   std::unique_ptr<InterpreterWrapper::Model> model =
734       Model::BuildFromBuffer(buf, length, error_reporter.get());
735   return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
736                                   registerers_by_name, registerers_by_func,
737                                   error_msg);
738 }
739 
CreateWrapperCPPFromBuffer(PyObject * data,const std::vector<std::string> & registerers,std::string * error_msg)740 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
741     PyObject* data, const std::vector<std::string>& registerers,
742     std::string* error_msg) {
743   return CreateWrapperCPPFromBuffer(data, registerers, {}, error_msg);
744 }
745 
ResetVariableTensors()746 PyObject* InterpreterWrapper::ResetVariableTensors() {
747   TFLITE_PY_ENSURE_VALID_INTERPRETER();
748   TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
749   Py_RETURN_NONE;
750 }
751 
SetNumThreads(int num_threads)752 PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
753   TFLITE_PY_ENSURE_VALID_INTERPRETER();
754   interpreter_->SetNumThreads(num_threads);
755   Py_RETURN_NONE;
756 }
757 
ModifyGraphWithDelegate(TfLiteDelegate * delegate)758 PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
759     TfLiteDelegate* delegate) {
760   TFLITE_PY_ENSURE_VALID_INTERPRETER();
761   TFLITE_PY_CHECK(interpreter_->ModifyGraphWithDelegate(delegate));
762   Py_RETURN_NONE;
763 }
764 
765 }  // namespace interpreter_wrapper
766 }  // namespace tflite
767