1# Lint as: python2, python3
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Python TF-Lite interpreter."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import ctypes
22import platform
23import sys
24import os
25
26import numpy as np
27
28# pylint: disable=g-import-not-at-top
# Choose the native wrapper module based on where this file lives: when the
# file path (without extension) ends in `tflite_runtime/interpreter` the
# `tflite_runtime` package is used, otherwise the `tensorflow` package is used.
if not os.path.splitext(__file__)[0].endswith(
    os.path.join('tflite_runtime', 'interpreter')):
  # This file is part of tensorflow package.
  from tensorflow.lite.python.interpreter_wrapper import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper
  from tensorflow.python.util.tf_export import tf_export as _tf_export
else:
  # This file is part of tflite_runtime package.
  from tflite_runtime import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper

  def _tf_export(*x, **kwargs):
    """Stand-in for `tf_export`: ignores its arguments, returns an identity decorator."""
    del x, kwargs
    return lambda x: x
41
42
class Delegate(object):
  """Python wrapper class to manage TfLiteDelegate objects.

  The shared library is expected to have two functions:
    TfLiteDelegate* tflite_plugin_create_delegate(
        char**, char**, size_t, void (*report_error)(const char *))
    void tflite_plugin_destroy_delegate(TfLiteDelegate*)

  The first one creates a delegate object. It may return NULL to indicate an
  error (with a suitable error message reported by calling report_error()).
  The second one destroys delegate object and must be called for every
  created delegate object. Passing NULL as argument value is allowed, i.e.

    tflite_plugin_destroy_delegate(tflite_plugin_create_delegate(...))

  always works.
  """

  def __init__(self, library, options=None):
    """Loads delegate from the shared library.

    Args:
      library: Shared library name.
      options: Dictionary of options that are required to load the delegate. All
        keys and values in the dictionary should be serializable. Consult the
        documentation of the specific delegate for required and legal options.
        (default None)

    Raises:
      RuntimeError: This is raised if the Python implementation is not CPython.
      ValueError: If the library returned a NULL delegate. The message is the
        text the library passed to the error-reporting callback.
    """
    # Give the attributes read by __del__ a defined value first, so that the
    # finalizer is harmless when construction fails part-way through.
    self._library = None
    self._delegate_ptr = None

    # TODO(b/136468453): Remove need for __del__ ordering needs of CPython
    # by using explicit closes(). See implementation of Interpreter __del__.
    if platform.python_implementation() != 'CPython':
      raise RuntimeError('Delegates are currently only supported into CPython '
                         'due to missing immediate reference counting.')

    error_callback_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p)

    self._library = ctypes.pydll.LoadLibrary(library)
    self._library.tflite_plugin_create_delegate.argtypes = [
        ctypes.POINTER(ctypes.c_char_p),
        ctypes.POINTER(ctypes.c_char_p), ctypes.c_int, error_callback_type
    ]
    self._library.tflite_plugin_create_delegate.restype = ctypes.c_void_p

    # Convert the options from a dictionary to lists of char pointers.
    options = options or {}
    options_keys = (ctypes.c_char_p * len(options))()
    options_values = (ctypes.c_char_p * len(options))()
    for idx, (key, value) in enumerate(options.items()):
      options_keys[idx] = str(key).encode('utf-8')
      options_values[idx] = str(value).encode('utf-8')

    class ErrorMessageCapture(object):
      """Accumulates the text passed to the library's report_error callback."""

      def __init__(self):
        self.message = ''

      def report(self, x):
        # The callback argument is declared as c_char_p; decode anything that
        # is not already a `str` before appending it.
        self.message += x if isinstance(x, str) else x.decode('utf-8')

    capture = ErrorMessageCapture()
    error_capturer_cb = error_callback_type(capture.report)
    # Do not make a copy of _delegate_ptr. It is freed by Delegate's finalizer.
    self._delegate_ptr = self._library.tflite_plugin_create_delegate(
        options_keys, options_values, len(options), error_capturer_cb)
    # With restype c_void_p, ctypes returns None for a NULL pointer.
    if self._delegate_ptr is None:
      raise ValueError(capture.message)

  def __del__(self):
    # __del__ can not be called multiple times, so if the delegate is destroyed
    # don't try to destroy it twice. getattr() guards against instances whose
    # __init__ never got as far as assigning the attributes.
    library = getattr(self, '_library', None)
    if library is not None:
      library.tflite_plugin_destroy_delegate.argtypes = [ctypes.c_void_p]
      library.tflite_plugin_destroy_delegate(
          getattr(self, '_delegate_ptr', None))
      self._library = None

  def _get_native_delegate_pointer(self):
    """Returns the native TfLiteDelegate pointer.

    It is not safe to copy this pointer because it needs to be freed.

    Returns:
      TfLiteDelegate *
    """
    return self._delegate_ptr
130
131
132@_tf_export('lite.experimental.load_delegate')
def load_delegate(library, options=None):
  """Builds a `Delegate` for the given shared library and returns it.

  Args:
    library: Name of shared library containing the
      [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates).
    options: Dictionary of options that are required to load the delegate. All
      keys and values in the dictionary should be convertible to str. Consult
      the documentation of the specific delegate for required and legal options.
      (default None)

  Returns:
    Delegate object.

  Raises:
    ValueError: Delegate failed to load.
    RuntimeError: If delegate loading is used on unsupported platform.
  """
  try:
    return Delegate(library, options)
  except ValueError as load_error:
    # Re-raise with the library name prepended to the delegate's own message.
    details = str(load_error)
    raise ValueError('Failed to load delegate from {}\n{}'.format(
        library, details))
157
158
class SignatureRunner(object):
  """SignatureRunner class for running TFLite models using SignatureDef.

  This class should be instantiated through TFLite Interpreter only using
  get_signature_runner method on Interpreter.
  Example,
  signature = interpreter.get_signature_runner("my_signature")
  result = signature(input_1=my_input_1, input_2=my_input_2)
  print(result["my_output"])
  print(result["my_second_output"])
  All names used are this specific SignatureDef names.

  Notes:
    No other function on this object or on the interpreter provided should be
    called while this object call has not finished.
  """

  def __init__(self, interpreter=None, signature_def_name=None):
    """Constructor.

    Args:
      interpreter: Interpreter object that is already initialized with the
        requested model.
      signature_def_name: SignatureDef names to be used.

    Raises:
      ValueError: If `interpreter` or `signature_def_name` is missing, or if
        `signature_def_name` is not one of the interpreter's SignatureDefs.
    """
    if not interpreter:
      raise ValueError('None interpreter provided.')
    if not signature_def_name:
      raise ValueError('None signature_def_name provided.')
    self._interpreter = interpreter
    self._signature_def_name = signature_def_name
    signature_defs = interpreter._get_full_signature_list()  # pylint: disable=protected-access
    if signature_def_name not in signature_defs:
      raise ValueError('Invalid signature_def_name provided.')
    self._signature_def = signature_defs[signature_def_name]
    # Both mappings go from SignatureDef tensor name to tensor index.
    self._outputs = self._signature_def['outputs'].items()
    self._inputs = self._signature_def['inputs']

  def __call__(self, **kwargs):
    """Runs the SignatureDef given the provided inputs in arguments.

    Args:
      **kwargs: key,value for inputs to the model. Key is the SignatureDef input
        name. Value is numpy array with the value.

    Returns:
      dictionary of the results from the model invoke.
      Key in the dictionary is SignatureDef output name.
      Value is the result Tensor.

    Raises:
      ValueError: If the number of inputs does not match the SignatureDef, or
        if an input name is not part of the SignatureDef.
    """

    if len(kwargs) != len(self._inputs):
      # "expected" is what the SignatureDef declares; "provided" is what the
      # caller passed.
      raise ValueError(
          'Invalid number of inputs provided for running a SignatureDef, '
          'expected %s vs provided %s' % (len(self._inputs), len(kwargs)))
    # Resize input tensors
    for input_name, value in kwargs.items():
      if input_name not in self._inputs:
        raise ValueError('Invalid Input name (%s) for SignatureDef' %
                         input_name)
      self._interpreter.resize_tensor_input(self._inputs[input_name],
                                            value.shape)
    # Allocate tensors.
    self._interpreter.allocate_tensors()
    # Set the input values.
    for input_name, value in kwargs.items():
      self._interpreter._set_input_tensor(  # pylint: disable=protected-access
          input_name, value=value, method_name=self._signature_def_name)
    self._interpreter.invoke()
    result = {}
    for output_name, output_index in self._outputs:
      result[output_name] = self._interpreter.get_tensor(output_index)
    return result
232
233
234@_tf_export('lite.Interpreter')
class Interpreter(object):
  """Interpreter interface for TensorFlow Lite Models.

  This makes the TensorFlow Lite interpreter accessible in Python.
  It is possible to use this interpreter in a multithreaded Python environment,
  but you must be sure to call functions of a particular instance from only
  one thread at a time. So if you want to have 4 threads running different
  inferences simultaneously, create an interpreter for each one as thread-local
  data. Similarly, if you are calling invoke() in one thread on a single
  interpreter but you want to use tensor() on another thread once it is done,
  you must use a synchronization primitive between the threads to ensure invoke
  has returned before calling tensor().
  """

  def __init__(self,
               model_path=None,
               model_content=None,
               experimental_delegates=None,
               num_threads=None):
    """Constructor.

    Args:
      model_path: Path to TF-Lite Flatbuffer file.
      model_content: Content of model.
      experimental_delegates: Experimental. Subject to change. List of
        [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
          objects returned by lite.load_delegate().
      num_threads: Sets the number of threads used by the interpreter and
        available to CPU kernels. If not set, the interpreter will use an
        implementation-dependent default number of threads. Currently, only a
        subset of kernels, such as conv, support multi-threading.

    Raises:
      ValueError: If the interpreter was unable to create, if neither or both
        of `model_path` and `model_content` are given, or if `num_threads` is
        not an int >= 1.
    """
    # Subclasses may set `_custom_op_registerers` before calling this
    # constructor; default to none otherwise.
    if not hasattr(self, '_custom_op_registerers'):
      self._custom_op_registerers = []

    # Exactly one model source must be supplied. Check this, and the cheap
    # `num_threads` argument, before any native object is created.
    if not model_path and not model_content:
      raise ValueError('`model_path` or `model_content` must be specified.')
    if model_path and model_content:
      raise ValueError('Can\'t both provide `model_path` and `model_content`')
    if num_threads is not None:
      if not isinstance(num_threads, int):
        raise ValueError('type of num_threads should be int')
      if num_threads < 1:
        raise ValueError('num_threads should >= 1')

    # The wrapper takes registerers given by symbol name (str) separately
    # from those given as callables.
    custom_op_registerers_by_name = [
        x for x in self._custom_op_registerers if isinstance(x, str)
    ]
    custom_op_registerers_by_func = [
        x for x in self._custom_op_registerers if not isinstance(x, str)
    ]

    if model_path:
      self._interpreter = (
          _interpreter_wrapper.CreateWrapperFromFile(
              model_path, custom_op_registerers_by_name,
              custom_op_registerers_by_func))
      if not self._interpreter:
        raise ValueError('Failed to open {}'.format(model_path))
    else:
      # Take a reference, so the pointer remains valid.
      # Since python strings are immutable then PyString_XX functions
      # will always return the same pointer.
      self._model_content = model_content
      self._interpreter = (
          _interpreter_wrapper.CreateWrapperFromBuffer(
              model_content, custom_op_registerers_by_name,
              custom_op_registerers_by_func))
      if not self._interpreter:
        raise ValueError('Failed to create interpreter from `model_content`')

    if num_threads is not None:
      self._interpreter.SetNumThreads(num_threads)

    # Each delegate is a wrapper that owns the delegates that have been loaded
    # as plugins. The interpreter wrapper will be using them, but we need to
    # hold them in a list so that the lifetime is preserved at least as long as
    # the interpreter wrapper.
    self._delegates = []
    if experimental_delegates:
      self._delegates = experimental_delegates
      for delegate in self._delegates:
        self._interpreter.ModifyGraphWithDelegate(
            delegate._get_native_delegate_pointer())  # pylint: disable=protected-access
    self._signature_defs = self.get_signature_list()

  def __del__(self):
    # Must make sure the interpreter is destroyed before things that
    # are used by it like the delegates. NOTE this only works on CPython
    # probably.
    # TODO(b/136468453): Remove need for __del__ ordering needs of CPython
    # by using explicit closes(). See implementation of Interpreter __del__.
    self._interpreter = None
    self._delegates = None

  def allocate_tensors(self):
    """Calls `AllocateTensors()` on the native interpreter.

    Returns:
      Whatever the native `AllocateTensors()` call returns.

    Raises:
      RuntimeError: If numpy arrays obtained through `tensor()` are still
        alive (see `_ensure_safe`).
    """
    self._ensure_safe()
    return self._interpreter.AllocateTensors()

  def _safe_to_run(self):
    """Returns true if there exist no numpy array buffers.

    This means it is safe to run tflite calls that may destroy internally
    allocated memory. This works, because in the wrapper.cc we have made
    the numpy base be the self._interpreter.
    """
    # NOTE, our tensor() call in cpp will use _interpreter as a base pointer.
    # If this environment is the only _interpreter, then the ref count should be
    # 2 (1 in self and 1 in temporary of sys.getrefcount).
    return sys.getrefcount(self._interpreter) == 2

  def _ensure_safe(self):
    """Makes sure no numpy arrays pointing to internal buffers are active.

    This should be called from any function that will call a function on
    _interpreter that may reallocate memory e.g. invoke(), ...

    Raises:
      RuntimeError: If there exist numpy objects pointing to internal memory
        then we throw.
    """
    if not self._safe_to_run():
      raise RuntimeError("""There is at least 1 reference to internal data
      in the interpreter in the form of a numpy array or slice. Be sure to
      only hold the function returned from tensor() if you are using raw
      data access.""")

  # Experimental and subject to change
  def _get_op_details(self, op_index):
    """Gets a dictionary with arrays of ids for tensors involved with an op.

    Args:
      op_index: Operation/node index of node to query.

    Returns:
      a dictionary containing the index, op name, and arrays with lists of the
      indices for the inputs and outputs of the op/node.
    """
    op_index = int(op_index)
    op_name = self._interpreter.NodeName(op_index)
    op_inputs = self._interpreter.NodeInputs(op_index)
    op_outputs = self._interpreter.NodeOutputs(op_index)

    details = {
        'index': op_index,
        'op_name': op_name,
        'inputs': op_inputs,
        'outputs': op_outputs,
    }

    return details

  def _get_tensor_details(self, tensor_index):
    """Gets tensor details.

    Args:
      tensor_index: Tensor index of tensor to query.

    Returns:
      A dictionary containing the following fields of the tensor:
        'name': The tensor name.
        'index': The tensor index in the interpreter.
        'shape': The shape of the tensor.
        'shape_signature': The value of the wrapper's `TensorSizeSignature`.
        'dtype': The value of the wrapper's `TensorType`.
        'quantization': Deprecated, use 'quantization_parameters'. This field
            only works for per-tensor quantization, whereas
            'quantization_parameters' works in all cases.
        'quantization_parameters': The parameters used to quantize the tensor:
          'scales': List of scales (one if per-tensor quantization)
          'zero_points': List of zero_points (one if per-tensor quantization)
          'quantized_dimension': Specifies the dimension of per-axis
              quantization, in the case of multiple scales/zero_points.
        'sparsity_parameters': The value of the wrapper's
            `TensorSparsityParameters`.

    Raises:
      ValueError: If tensor_index is invalid.
    """
    tensor_index = int(tensor_index)
    tensor_name = self._interpreter.TensorName(tensor_index)
    tensor_size = self._interpreter.TensorSize(tensor_index)
    tensor_size_signature = self._interpreter.TensorSizeSignature(tensor_index)
    tensor_type = self._interpreter.TensorType(tensor_index)
    tensor_quantization = self._interpreter.TensorQuantization(tensor_index)
    tensor_quantization_params = self._interpreter.TensorQuantizationParameters(
        tensor_index)
    tensor_sparsity_params = self._interpreter.TensorSparsityParameters(
        tensor_index)

    # A falsy type is treated as "no usable details for this tensor".
    if not tensor_type:
      raise ValueError('Could not get tensor details')

    details = {
        'name': tensor_name,
        'index': tensor_index,
        'shape': tensor_size,
        'shape_signature': tensor_size_signature,
        'dtype': tensor_type,
        'quantization': tensor_quantization,
        'quantization_parameters': {
            'scales': tensor_quantization_params[0],
            'zero_points': tensor_quantization_params[1],
            'quantized_dimension': tensor_quantization_params[2],
        },
        'sparsity_parameters': tensor_sparsity_params
    }

    return details

  # Experimental and subject to change
  def _get_ops_details(self):
    """Gets op details for every node.

    Returns:
      A list of dictionaries containing arrays with lists of tensor ids for
      tensors involved in the op.
    """
    return [
        self._get_op_details(idx) for idx in range(self._interpreter.NumNodes())
    ]

  def get_tensor_details(self):
    """Gets tensor details for every tensor with valid tensor details.

    Tensors where required information about the tensor is not found are not
    added to the list. This includes temporary tensors without a name.

    Returns:
      A list of dictionaries containing tensor information.
    """
    tensor_details = []
    for idx in range(self._interpreter.NumTensors()):
      try:
        tensor_details.append(self._get_tensor_details(idx))
      except ValueError:
        # Deliberately skipped: see the docstring above.
        pass
    return tensor_details

  def get_input_details(self):
    """Gets model input details.

    Returns:
      A list of input details.
    """
    return [
        self._get_tensor_details(i) for i in self._interpreter.InputIndices()
    ]

  def set_tensor(self, tensor_index, value):
    """Sets the value of the input tensor.

    Note this copies data in `value`.

    If you want to avoid copying, you can use the `tensor()` function to get a
    numpy buffer pointing to the input buffer in the tflite interpreter.

    Args:
      tensor_index: Tensor index of tensor to set. This value can be gotten from
        the 'index' field in get_input_details.
      value: Value of tensor to set.

    Raises:
      ValueError: If the interpreter could not set the tensor.
    """
    self._interpreter.SetTensor(tensor_index, value)

  def resize_tensor_input(self, input_index, tensor_size, strict=False):
    """Resizes an input tensor.

    Args:
      input_index: Tensor index of input to set. This value can be gotten from
        the 'index' field in get_input_details.
      tensor_size: The tensor_shape to resize the input to.
      strict: Only unknown dimensions can be resized when `strict` is True.
        Unknown dimensions are indicated as `-1` in the `shape_signature`
        attribute of a given tensor. (default False)

    Raises:
      ValueError: If the interpreter could not resize the input tensor.
      RuntimeError: If numpy arrays obtained through `tensor()` are still
        alive (see `_ensure_safe`).

    Usage:
    ```
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.resize_tensor_input(0, [num_test_images, 224, 224, 3])
    interpreter.allocate_tensors()
    interpreter.set_tensor(0, test_images)
    interpreter.invoke()
    ```
    """
    self._ensure_safe()
    # `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size`
    # parameter.
    tensor_size = np.array(tensor_size, dtype=np.int32)
    self._interpreter.ResizeInputTensor(input_index, tensor_size, strict)

  def get_output_details(self):
    """Gets model output details.

    Returns:
      A list of output details.
    """
    return [
        self._get_tensor_details(i) for i in self._interpreter.OutputIndices()
    ]

  def get_signature_list(self):
    """Gets list of SignatureDefs in the model.

    Example,
    ```
    signatures = interpreter.get_signature_list()
    print(signatures)

    # {
    #   'add': {'inputs': ['x', 'y'], 'outputs': ['output_0']}
    # }
    ```

    Then using the names in the signature list you can get a callable from
    get_signature_runner().

    Returns:
      A list of SignatureDef details in a dictionary structure.
      It is keyed on the SignatureDef method name, and the value holds
      dictionary of inputs and outputs.
    """
    full_signature_defs = self._interpreter.GetSignatureDefs()
    # Keep only the tensor names; the name -> index mapping is available from
    # `_get_full_signature_list()`.
    for _, signature_def in full_signature_defs.items():
      signature_def['inputs'] = list(signature_def['inputs'].keys())
      signature_def['outputs'] = list(signature_def['outputs'].keys())
    return full_signature_defs

  def _get_full_signature_list(self):
    """Gets list of SignatureDefs in the model.

    Example,
    ```
    signatures = interpreter._get_full_signature_list()
    print(signatures)

    # {
    #   'add': {'inputs': {'x': 1, 'y': 0}, 'outputs': {'output_0': 4}}
    # }
    ```

    Then using the names in the signature list you can get a callable from
    get_signature_runner().

    Returns:
      A list of SignatureDef details in a dictionary structure.
      It is keyed on the SignatureDef method name, and the value holds
      dictionary of inputs and outputs.
    """
    return self._interpreter.GetSignatureDefs()

  def _resolve_method_name(self, method_name):
    """Returns `method_name`, or the model's only SignatureDef name if None.

    Args:
      method_name: The exported method name for the SignatureDef, or None.

    Returns:
      `method_name` unchanged when it is not None; otherwise the single key of
      the SignatureDef dictionary captured at construction time.

    Raises:
      ValueError: If `method_name` is None and the model does not have exactly
        one SignatureDef.
    """
    if method_name is not None:
      return method_name
    if len(self._signature_defs) != 1:
      raise ValueError(
          'SignatureDef method_name is None and model has {0} Signatures. '
          'None is only allowed when the model has 1 SignatureDef'.format(
              len(self._signature_defs)))
    return next(iter(self._signature_defs))

  def _set_input_tensor(self, input_name, value, method_name=None):
    """Sets the value of the input tensor.

    Input tensor is identified by `input_name` in the SignatureDef identified
    by `method_name`.
    If the model has a single SignatureDef then you can pass None as
    `method_name`.

    Note this copies data in `value`.

    Example,
    ```
    input_data = np.array([1.2, 1.4], np.float32)
    signatures = interpreter._get_full_signature_list()
    print(signatures)
    # {
    #   'add': {'inputs': {'x': 1, 'y': 0}, 'outputs': {'output_0': 4}}
    # }
    interpreter._set_input_tensor(input_name='x', value=input_data,
    method_name='add')
    ```

    Args:
      input_name: Name of the input tensor in the SignatureDef.
      value: Value of tensor to set as a numpy array.
      method_name: The exported method name for the SignatureDef, it can be None
        if and only if the model has a single SignatureDef. Default value is
        None.

    Raises:
      ValueError: If the interpreter could not set the tensor. Or
      if `method_name` is None and model doesn't have a single
      Signature.
    """
    method_name = self._resolve_method_name(method_name)
    self._interpreter.SetInputTensorFromSignatureDefName(
        input_name, method_name, value)

  def get_signature_runner(self, method_name=None):
    """Gets callable for inference of specific SignatureDef.

    Example usage,
    ```
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    fn = interpreter.get_signature_runner('div_with_remainder')
    output = fn(x=np.array([3]), y=np.array([2]))
    print(output)
    # {
    #   'quotient': array([1.], dtype=float32)
    #   'remainder': array([1.], dtype=float32)
    # }
    ```

    None can be passed for method_name if the model has a single Signature only.

    All names used are this specific SignatureDef names.


    Args:
      method_name: The exported method name for the SignatureDef, it can be None
        if and only if the model has a single SignatureDef. Default value is
        None.

    Returns:
      This returns a callable that can run inference for SignatureDef defined
      by argument 'method_name'.
      The callable will take key arguments corresponding to the arguments of the
      SignatureDef, that should have numpy values.
      The callable will return a dictionary that maps from output names to numpy
      values of the computed results.

    Raises:
      ValueError: If passed method_name is invalid.
    """
    method_name = self._resolve_method_name(method_name)
    return SignatureRunner(interpreter=self, signature_def_name=method_name)

  def get_tensor(self, tensor_index):
    """Gets the value of the output tensor (get a copy).

    If you wish to avoid the copy, use `tensor()`. This function cannot be used
    to read intermediate results.

    Args:
      tensor_index: Tensor index of tensor to get. This value can be gotten from
        the 'index' field in get_output_details.

    Returns:
      a numpy array.
    """
    return self._interpreter.GetTensor(tensor_index)

  def tensor(self, tensor_index):
    """Returns function that gives a numpy view of the current tensor buffer.

    This allows reading and writing to this tensors w/o copies. This more
    closely mirrors the C++ Interpreter class interface's tensor() member, hence
    the name. Be careful to not hold these output references through calls
    to `allocate_tensors()` and `invoke()`. This function cannot be used to read
    intermediate results.

    Usage:

    ```
    interpreter.allocate_tensors()
    input = interpreter.tensor(interpreter.get_input_details()[0]["index"])
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
    for i in range(10):
      input().fill(3.)
      interpreter.invoke()
      print("inference %s" % output())
    ```

    Notice how this function avoids making a numpy array directly. This is
    because it is important to not hold actual numpy views to the data longer
    than necessary. If you do, then the interpreter can no longer be invoked,
    because it is possible the interpreter would resize and invalidate the
    referenced tensors. The NumPy API doesn't allow any mutability of the
    underlying buffers.

    WRONG:

    ```
    input = interpreter.tensor(interpreter.get_input_details()[0]["index"])()
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])()
    interpreter.allocate_tensors()  # This will throw RuntimeError
    for i in range(10):
      input.fill(3.)
      interpreter.invoke()  # this will throw RuntimeError since input,output
    ```

    Args:
      tensor_index: Tensor index of tensor to get. This value can be gotten from
        the 'index' field in get_output_details.

    Returns:
      A function that can return a new numpy array pointing to the internal
      TFLite tensor state at any point. It is safe to hold the function forever,
      but it is not safe to hold the numpy array forever.
    """
    # The wrapper is passed to itself so the returned array can use it as its
    # base object; `_safe_to_run` relies on that reference being counted.
    return lambda: self._interpreter.tensor(self._interpreter, tensor_index)

  def invoke(self):
    """Invoke the interpreter.

    Be sure to set the input sizes, allocate tensors and fill values before
    calling this. Also, note that this function releases the GIL so heavy
    computation can be done in the background while the Python interpreter
    continues. No other function on this object should be called while the
    invoke() call has not finished.

    Raises:
      ValueError: When the underlying interpreter fails raise ValueError.
      RuntimeError: If numpy arrays obtained through `tensor()` are still
        alive (see `_ensure_safe`).
    """
    self._ensure_safe()
    self._interpreter.Invoke()

  def reset_all_variables(self):
    """Calls `ResetVariableTensors()` on the native interpreter.

    Returns:
      Whatever the native `ResetVariableTensors()` call returns.
    """
    return self._interpreter.ResetVariableTensors()

  # Experimental and subject to change.
  def _native_handle(self):
    """Returns a pointer to the underlying tflite::Interpreter instance.

    This allows extending tflite.Interpreter's functionality in a custom C++
    function. Consider how that may work in a custom pybind wrapper:

      m.def("SomeNewFeature", ([](py::object handle) {
        auto* interpreter =
          reinterpret_cast<tflite::Interpreter*>(handle.cast<intptr_t>());
        ...
      }))

    and corresponding Python call:

      SomeNewFeature(interpreter._native_handle())

    Note: This approach is fragile. Users must guarantee the C++ extension build
    is consistent with the tflite.Interpreter's underlying C++ build.
    """
    return self._interpreter.interpreter()
786
787
class InterpreterWithCustomOps(Interpreter):
  """Interpreter interface for TensorFlow Lite Models that accepts custom ops.

  The interface provided by this class is experimental and therefore not exposed
  as part of the public API.

  Wraps the tf.lite.Interpreter class and adds the ability to load custom ops
  by providing the names of functions that take a pointer to a BuiltinOpResolver
  and add a custom op.
  """

  def __init__(self,
               model_path=None,
               model_content=None,
               experimental_delegates=None,
               custom_op_registerers=None,
               num_threads=None):
    """Constructor.

    Args:
      model_path: Path to TF-Lite Flatbuffer file.
      model_content: Content of model.
      experimental_delegates: Experimental. Subject to change. List of
        [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
          objects returned by lite.load_delegate().
      custom_op_registerers: List of str (symbol names) or functions that take a
        pointer to a MutableOpResolver and register a custom op. When passing
        functions, use a pybind function that takes a uintptr_t that can be
        recast as a pointer to a MutableOpResolver.
      num_threads: Forwarded unchanged to the `Interpreter` constructor; see
        its documentation. (default None)

    Raises:
      ValueError: If the interpreter was unable to create.
    """
    # Must be assigned before the base constructor runs: the base class reads
    # this attribute while creating the native interpreter.
    self._custom_op_registerers = custom_op_registerers or []
    super(InterpreterWithCustomOps, self).__init__(
        model_path=model_path,
        model_content=model_content,
        experimental_delegates=experimental_delegates,
        num_threads=num_threads)
825