1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python TF-Lite interpreter."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import sys
21import numpy as np
22
23# pylint: disable=g-import-not-at-top
try:
  from tensorflow.python.util.lazy_loader import LazyLoader
  from tensorflow.python.util.tf_export import tf_export as _tf_export

  # Lazy load since some of the performance benchmark skylark rules
  # break dependencies. Must use double quotes to match code internal rewrite
  # rule.
  # pylint: disable=g-inconsistent-quotes
  _interpreter_wrapper = LazyLoader(
      "_interpreter_wrapper", globals(),
      "tensorflow.lite.python.interpreter_wrapper."
      "tensorflow_wrap_interpreter_wrapper")
  # pylint: enable=g-inconsistent-quotes

  # LazyLoader is only needed to build `_interpreter_wrapper`; drop the name so
  # it does not linger in this module's namespace.
  del LazyLoader
except ImportError:
  # When the full TensorFlow Python PIP package is not available, do not use
  # lazy loading and instead use the tflite_runtime path.
  from tflite_runtime.lite.python import interpreter_wrapper as _interpreter_wrapper

  # Stand-in for `tf_export` when TensorFlow itself cannot be imported: it
  # accepts and discards any positional/keyword arguments and returns an
  # identity decorator, so decorated objects are left unchanged.
  def tf_export_dummy(*x, **kwargs):
    del x, kwargs
    return lambda x: x
  _tf_export = tf_export_dummy
48
49
@_tf_export('lite.Interpreter')
class Interpreter(object):
  """Interpreter interface for TF-Lite Models.

  Thin Python wrapper around the native interpreter wrapper object stored in
  `self._interpreter`. Methods that may reallocate internal memory first check
  that no numpy views of internal buffers are still alive (see
  `_ensure_safe`).
  """

  def __init__(self, model_path=None, model_content=None):
    """Constructor.

    Exactly one of `model_path` and `model_content` must be provided.

    Args:
      model_path: Path to TF-Lite Flatbuffer file.
      model_content: Content of model.

    Raises:
      ValueError: If the interpreter was unable to create, if neither
        argument is given, or if both arguments are given.
    """
    if model_path and not model_content:
      self._interpreter = (
          _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile(
              model_path))
      if not self._interpreter:
        raise ValueError('Failed to open {}'.format(model_path))
    elif model_content and not model_path:
      # Take a reference, so the pointer remains valid.
      # Since python strings are immutable then PyString_XX functions
      # will always return the same pointer.
      self._model_content = model_content
      self._interpreter = (
          _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer(
              model_content))
      # Mirror the from-file branch: surface a creation failure right away
      # instead of failing later on the first call through `_interpreter`.
      if not self._interpreter:
        raise ValueError(
            'Failed to create model from {} bytes'.format(len(model_content)))
    elif not model_path and not model_content:
      raise ValueError('`model_path` or `model_content` must be specified.')
    else:
      raise ValueError('Can\'t both provide `model_path` and `model_content`')

  def allocate_tensors(self):
    """Allocates the interpreter's tensors.

    Returns:
      Whatever the underlying wrapper's `AllocateTensors()` returns.

    Raises:
      RuntimeError: If numpy arrays pointing to internal buffers are still
        alive (see `_ensure_safe`).
    """
    self._ensure_safe()
    return self._interpreter.AllocateTensors()

  def _safe_to_run(self):
    """Returns true if there exist no numpy array buffers.

    This means it is safe to run tflite calls that may destroy internally
    allocated memory. This works, because in the wrapper.cc we have made
    the numpy base be the self._interpreter.
    """
    # NOTE, our tensor() call in cpp will use _interpreter as a base pointer.
    # If this environment is the only _interpreter, then the ref count should be
    # 2 (1 in self and 1 in temporary of sys.getrefcount).
    return sys.getrefcount(self._interpreter) == 2

  def _ensure_safe(self):
    """Makes sure no numpy arrays pointing to internal buffers are active.

    This should be called from any function that will call a function on
    _interpreter that may reallocate memory e.g. invoke(), ...

    Raises:
      RuntimeError: If there exist numpy objects pointing to internal memory
        then we throw.
    """
    if not self._safe_to_run():
      raise RuntimeError("""There is at least 1 reference to internal data
      in the interpreter in the form of a numpy array or slice. Be sure to
      only hold the function returned from tensor() if you are using raw
      data access.""")

  def _get_tensor_details(self, tensor_index):
    """Gets tensor details.

    Args:
      tensor_index: Tensor index of tensor to query.

    Returns:
      A dictionary with the keys 'name', 'index', 'shape', 'dtype' and
      'quantization' describing the tensor.

    Raises:
      ValueError: If tensor_index is invalid, i.e. the wrapper reports no
        name or no type for it.
    """
    tensor_index = int(tensor_index)
    tensor_name = self._interpreter.TensorName(tensor_index)
    tensor_size = self._interpreter.TensorSize(tensor_index)
    tensor_type = self._interpreter.TensorType(tensor_index)
    tensor_quantization = self._interpreter.TensorQuantization(tensor_index)

    if not tensor_name or not tensor_type:
      raise ValueError('Could not get tensor details')

    return {
        'name': tensor_name,
        'index': tensor_index,
        'shape': tensor_size,
        'dtype': tensor_type,
        'quantization': tensor_quantization,
    }

  def get_tensor_details(self):
    """Gets tensor details for every tensor with valid tensor details.

    Tensors where required information about the tensor is not found are not
    added to the list. This includes temporary tensors without a name.

    Returns:
      A list of dictionaries containing tensor information.
    """
    tensor_details = []
    for idx in range(self._interpreter.NumTensors()):
      try:
        tensor_details.append(self._get_tensor_details(idx))
      except ValueError:
        # Deliberate best effort: tensors without a name/type are skipped
        # rather than failing the whole listing.
        pass
    return tensor_details

  def get_input_details(self):
    """Gets model input details.

    Returns:
      A list of input details, one dictionary per input tensor.

    Raises:
      ValueError: If details for one of the input tensors cannot be obtained.
    """
    return [
        self._get_tensor_details(i) for i in self._interpreter.InputIndices()
    ]

  def set_tensor(self, tensor_index, value):
    """Sets the value of the input tensor. Note this copies data in `value`.

    If you want to avoid copying, you can use the `tensor()` function to get a
    numpy buffer pointing to the input buffer in the tflite interpreter.

    Args:
      tensor_index: Tensor index of tensor to set. This value can be gotten from
                    the 'index' field in get_input_details.
      value: Value of tensor to set.

    Raises:
      ValueError: If the interpreter could not set the tensor.
    """
    self._interpreter.SetTensor(tensor_index, value)

  def resize_tensor_input(self, input_index, tensor_size):
    """Resizes an input tensor.

    Args:
      input_index: Tensor index of input to set. This value can be gotten from
                   the 'index' field in get_input_details.
      tensor_size: The tensor_shape to resize the input to.

    Raises:
      ValueError: If the interpreter could not resize the input tensor.
      RuntimeError: If numpy arrays pointing to internal buffers are still
        alive (see `_ensure_safe`).
    """
    self._ensure_safe()
    # `ResizeInputTensor` now only accepts an int32 numpy array as the
    # `tensor_size` parameter.
    tensor_size = np.array(tensor_size, dtype=np.int32)
    self._interpreter.ResizeInputTensor(input_index, tensor_size)

  def get_output_details(self):
    """Gets model output details.

    Returns:
      A list of output details, one dictionary per output tensor.

    Raises:
      ValueError: If details for one of the output tensors cannot be obtained.
    """
    return [
        self._get_tensor_details(i) for i in self._interpreter.OutputIndices()
    ]

  def get_tensor(self, tensor_index):
    """Gets the value of a tensor, typically an output tensor (get a copy).

    If you wish to avoid the copy, use `tensor()`. This function cannot be used
    to read intermediate results.

    Args:
      tensor_index: Tensor index of tensor to get. This value can be gotten from
                    the 'index' field in get_output_details.

    Returns:
      a numpy array.
    """
    return self._interpreter.GetTensor(tensor_index)

  def tensor(self, tensor_index):
    """Returns function that gives a numpy view of the current tensor buffer.

    This allows reading and writing to this tensor w/o copies. This more
    closely mirrors the C++ Interpreter class interface's tensor() member, hence
    the name. Be careful to not hold these output references through calls
    to `allocate_tensors()` and `invoke()`. This function cannot be used to read
    intermediate results.

    Usage:

    ```
    interpreter.allocate_tensors()
    input = interpreter.tensor(interpreter.get_input_details()[0]["index"])
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
    for i in range(10):
      input().fill(3.)
      interpreter.invoke()
      print("inference %s" % output())
    ```

    Notice how this function avoids making a numpy array directly. This is
    because it is important to not hold actual numpy views to the data longer
    than necessary. If you do, then the interpreter can no longer be invoked,
    because it is possible the interpreter would resize and invalidate the
    referenced tensors. The NumPy API doesn't allow any mutability of the
    underlying buffers.

    WRONG:

    ```
    input = interpreter.tensor(interpreter.get_input_details()[0]["index"])()
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])()
    interpreter.allocate_tensors()  # This will throw RuntimeError
    for i in range(10):
      input.fill(3.)
      interpreter.invoke()  # this will throw RuntimeError since input, output
                            # are numpy views that are still being held
    ```

    Args:
      tensor_index: Tensor index of tensor to get. This value can be gotten from
                    the 'index' field in get_output_details.

    Returns:
      A function that can return a new numpy array pointing to the internal
      TFLite tensor state at any point. It is safe to hold the function forever,
      but it is not safe to hold the numpy array forever.
    """
    # The wrapper object itself is passed along so that the native side can use
    # it as the base of the returned array (see `_safe_to_run`).
    return lambda: self._interpreter.tensor(self._interpreter, tensor_index)

  def invoke(self):
    """Invoke the interpreter.

    Be sure to set the input sizes, allocate tensors and fill values before
    calling this.

    Raises:
      ValueError: When the underlying interpreter fails raise ValueError.
      RuntimeError: If numpy arrays pointing to internal buffers are still
        alive (see `_ensure_safe`).
    """
    self._ensure_safe()
    self._interpreter.Invoke()

  def reset_all_variables(self):
    """Resets variable tensors via the wrapper's `ResetVariableTensors()`.

    Returns:
      Whatever the underlying wrapper's `ResetVariableTensors()` returns.
    """
    return self._interpreter.ResetVariableTensors()