1# Lint as: python2, python3
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""TensorFlow Lite Python Interface: Sanity check."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import ctypes
22import io
23import sys
24
25import numpy as np
26import six
27
# Force loaded shared object symbols to be globally visible. This is needed so
# that the interpreter_wrapper, in one .so file, can see the test_registerer,
# in a different .so file. Note that this may already be set by default.
# The hasattr guard skips the adjustment when `sys` lacks either dlopen-flag
# accessor on the current platform.
# pylint: disable=g-import-not-at-top
if all(hasattr(sys, accessor)
       for accessor in ('setdlopenflags', 'getdlopenflags')):
  sys.setdlopenflags(ctypes.RTLD_GLOBAL | sys.getdlopenflags())
34
35from tensorflow.lite.python import interpreter as interpreter_wrapper
36from tensorflow.lite.python.testdata import _pywrap_test_registerer as test_registerer
37from tensorflow.python.framework import test_util
38from tensorflow.python.platform import resource_loader
39from tensorflow.python.platform import test
40# pylint: enable=g-import-not-at-top
41
42
class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):
  """Tests `InterpreterWithCustomOps` and its `custom_op_registerers` arg."""

  def _assertRegistererCalledOnce(self, custom_op_registerers):
    """Builds an interpreter with `custom_op_registerers`; expects one call.

    `get_num_test_registerer_calls()` reads a counter kept by the
    test_registerer extension. Nothing in this file resets it, and several
    tests here invoke `TF_TestRegisterer`, so an absolute expected value
    would depend on test execution order. Instead, the count is compared
    against a snapshot taken just before the interpreter is constructed.
    NOTE(review): assumes the getter itself does not reset the counter.

    Args:
      custom_op_registerers: Registerer names or functions, passed through
        unchanged to `InterpreterWithCustomOps`.
    """
    calls_before = test_registerer.get_num_test_registerer_calls()
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'),
        custom_op_registerers=custom_op_registerers)
    self.assertTrue(interpreter._safe_to_run())
    self.assertEqual(calls_before + 1,
                     test_registerer.get_num_test_registerer_calls())

  def testRegistererByName(self):
    self._assertRegistererCalledOnce(['TF_TestRegisterer'])

  def testRegistererByFunc(self):
    self._assertRegistererCalledOnce([test_registerer.TF_TestRegisterer])

  def testRegistererFailure(self):
    bogus_name = 'CompletelyBogusRegistererName'
    with self.assertRaisesRegex(
        ValueError, 'Looking up symbol \'' + bogus_name + '\' failed'):
      interpreter_wrapper.InterpreterWithCustomOps(
          model_path=resource_loader.get_path_to_datafile(
              'testdata/permute_float.tflite'),
          custom_op_registerers=[bogus_name])

  def testNoCustomOps(self):
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    self.assertTrue(interpreter._safe_to_run())
75
76
class InterpreterTest(test_util.TensorFlowTestCase):
  """Tests for `Interpreter` using the small models under testdata/."""

  def assertQuantizationParamsEqual(self, scales, zero_points,
                                    quantized_dimension, params):
    """Asserts a `quantization_parameters` dict matches expected values.

    Args:
      scales: Expected contents of `params['scales']`.
      zero_points: Expected contents of `params['zero_points']`.
      quantized_dimension: Expected `params['quantized_dimension']`.
      params: The `quantization_parameters` entry of a tensor details dict.
    """
    self.assertAllEqual(scales, params['scales'])
    self.assertAllEqual(zero_points, params['zero_points'])
    self.assertEqual(quantized_dimension, params['quantized_dimension'])

  def testThreads_NegativeValue(self):
    with self.assertRaisesRegex(ValueError, 'num_threads should >= 1'):
      interpreter_wrapper.Interpreter(
          model_path=resource_loader.get_path_to_datafile(
              'testdata/permute_float.tflite'),
          num_threads=-1)

  def testThreads_WrongType(self):
    with self.assertRaisesRegex(ValueError,
                                'type of num_threads should be int'):
      interpreter_wrapper.Interpreter(
          model_path=resource_loader.get_path_to_datafile(
              'testdata/permute_float.tflite'),
          num_threads=4.2)

  def testFloat(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 4], input_details[0]['shape'])
    self.assertEqual((0.0, 0), input_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, input_details[0]['quantization_parameters'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 4], output_details[0]['shape'])
    self.assertEqual((0.0, 0), output_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, output_details[0]['quantization_parameters'])

    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
    expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertAllEqual(expected_output, output_data)

  def testFloatWithTwoThreads(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'),
        num_threads=2)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
    expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertAllEqual(expected_output, output_data)

  def testUint8(self):
    model_path = resource_loader.get_path_to_datafile(
        'testdata/permute_uint8.tflite')
    with io.open(model_path, 'rb') as model_file:
      data = model_file.read()

    interpreter = interpreter_wrapper.Interpreter(model_content=data)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertAllEqual([1, 4], input_details[0]['shape'])
    self.assertEqual((1.0, 0), input_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [1.0], [0], 0, input_details[0]['quantization_parameters'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.uint8, output_details[0]['dtype'])
    self.assertAllEqual([1, 4], output_details[0]['shape'])
    self.assertEqual((1.0, 0), output_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [1.0], [0], 0, output_details[0]['quantization_parameters'])

    test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
    expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
    interpreter.resize_tensor_input(input_details[0]['index'], test_input.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertAllEqual(expected_output, output_data)

  def testString(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/gather_string.tflite'))
    interpreter.allocate_tensors()

    # `np.bytes_` is the same type as the legacy `np.string_` alias, which
    # NumPy 2.0 removed; using it keeps the test working on all versions.
    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.bytes_, input_details[0]['dtype'])
    self.assertAllEqual([10], input_details[0]['shape'])
    self.assertEqual((0.0, 0), input_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, input_details[0]['quantization_parameters'])
    self.assertEqual('indices', input_details[1]['name'])
    self.assertEqual(np.int64, input_details[1]['dtype'])
    self.assertAllEqual([3], input_details[1]['shape'])
    self.assertEqual((0.0, 0), input_details[1]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, input_details[1]['quantization_parameters'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.bytes_, output_details[0]['dtype'])
    self.assertAllEqual([3], output_details[0]['shape'])
    self.assertEqual((0.0, 0), output_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, output_details[0]['quantization_parameters'])

    test_input = np.array([1, 2, 3], dtype=np.int64)
    interpreter.set_tensor(input_details[1]['index'], test_input)

    test_input = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'])
    expected_output = np.array([b'b', b'c', b'd'])
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertAllEqual(expected_output, output_data)

  def testStringZeroDim(self):
    # Four printable bytes followed by 16 NUL bytes. The length comparison
    # below checks that the whole value, trailing NULs included, round-trips.
    # An explicit literal is used because `bytes(16)` is the 2-char string
    # '16' on Python 2 rather than 16 zero bytes.
    data = b'abcd' + b'\x00' * 16
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/gather_string_0d.tflite'))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    interpreter.set_tensor(input_details[0]['index'], np.array(data))
    test_input_tensor = interpreter.get_tensor(input_details[0]['index'])
    self.assertEqual(len(data), len(test_input_tensor.item(0)))

  def testPerChannelParams(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin'))
    interpreter.allocate_tensors()

    # Tensor index 1 is the weight.
    weight_details = interpreter.get_tensor_details()[1]
    qparams = weight_details['quantization_parameters']
    # Ensure that we retrieve per channel quantization params correctly.
    self.assertEqual(len(qparams['scales']), 128)

  def testDenseTensorAccess(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin'))
    interpreter.allocate_tensors()
    weight_details = interpreter.get_tensor_details()[1]
    s_params = weight_details['sparsity_parameters']
    # A dense tensor reports no sparsity parameters at all.
    self.assertEqual(s_params, {})

  def testSparseTensorAccess(self):
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_path=resource_loader.get_path_to_datafile(
            '../testdata/sparse_tensor.bin'),
        custom_op_registerers=['TF_TestRegisterer'])
    interpreter.allocate_tensors()

    # Tensor at index 0 is sparse.
    compressed_buffer = interpreter.get_tensor(0)
    # Ensure that the buffer is of correct size and value.
    self.assertEqual(len(compressed_buffer), 12)
    sparse_value = [1, 0, 0, 4, 2, 3, 0, 0, 5, 0, 0, 6]
    self.assertAllEqual(compressed_buffer, sparse_value)

    tensor_details = interpreter.get_tensor_details()[0]
    s_params = tensor_details['sparsity_parameters']

    # Ensure sparsity parameter returned is correct
    self.assertAllEqual(s_params['traversal_order'], [0, 1, 2, 3])
    self.assertAllEqual(s_params['block_map'], [0, 1])
    dense_dim_metadata = {'format': 0, 'dense_size': 2}
    self.assertAllEqual(s_params['dim_metadata'][0], dense_dim_metadata)
    self.assertAllEqual(s_params['dim_metadata'][2], dense_dim_metadata)
    self.assertAllEqual(s_params['dim_metadata'][3], dense_dim_metadata)
    self.assertEqual(s_params['dim_metadata'][1]['format'], 1)
    self.assertAllEqual(s_params['dim_metadata'][1]['array_segments'],
                        [0, 2, 3])
    self.assertAllEqual(s_params['dim_metadata'][1]['array_indices'], [0, 1, 1])
286
287
class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
  """Checks that interpreter failures surface as Python exceptions."""

  @staticmethod
  def _permuteFloatPath():
    """Returns the path of the bundled permute_float test model."""
    return resource_loader.get_path_to_datafile(
        'testdata/permute_float.tflite')

  def testInvalidModelContent(self):
    # Arbitrary bytes rather than a serialized model.
    bogus_content = six.b('garbage')
    with self.assertRaisesRegex(ValueError,
                                'Model provided has model identifier \''):
      interpreter_wrapper.Interpreter(model_content=bogus_content)

  def testInvalidModelFile(self):
    missing_path = 'totally_invalid_file_name'
    with self.assertRaisesRegex(ValueError,
                                'Could not open \'totally_invalid_file_name\''):
      interpreter_wrapper.Interpreter(model_path=missing_path)

  def testInvokeBeforeReady(self):
    # allocate_tensors() is deliberately skipped before invoke().
    interpreter = interpreter_wrapper.Interpreter(
        model_path=self._permuteFloatPath())
    with self.assertRaisesRegex(RuntimeError,
                                'Invoke called on model that is not ready'):
      interpreter.invoke()

  def testInvalidModelFileContent(self):
    # Neither a model path nor in-memory model content is supplied.
    with self.assertRaisesRegex(
        ValueError, '`model_path` or `model_content` must be specified.'):
      interpreter_wrapper.Interpreter(model_path=None, model_content=None)

  def testInvalidIndex(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=self._permuteFloatPath())
    interpreter.allocate_tensors()
    # Index 4 is used both as an invalid tensor index and an invalid node
    # index for this model.
    bad_index = 4
    with self.assertRaisesRegex(ValueError, 'Tensor with no shape found.'):
      interpreter._get_tensor_details(bad_index)
    with self.assertRaisesRegex(ValueError, 'Invalid node index'):
      interpreter._get_op_details(bad_index)
323
324
class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase):
  """Tests the by-reference `tensor()` and by-copy `get_tensor()` accessors."""

  def setUp(self):
    model_path = resource_loader.get_path_to_datafile(
        'testdata/permute_float.tflite')
    self.interpreter = interpreter_wrapper.Interpreter(model_path=model_path)
    self.interpreter.allocate_tensors()
    first_input = self.interpreter.get_input_details()[0]
    self.input0 = first_input['index']
    self.initial_data = np.array([[-1., -2., -3., -4.]], np.float32)

  def testTensorAccessor(self):
    """Check that tensor returns a reference."""
    get_view = self.interpreter.tensor(self.input0)
    # Writing through the view must be visible to every other reader.
    np.copyto(get_view(), self.initial_data)
    self.assertAllEqual(get_view(), self.initial_data)
    self.assertAllEqual(
        self.interpreter.get_tensor(self.input0), self.initial_data)

  def testGetTensorAccessor(self):
    """Check that get_tensor returns a copy."""
    interp = self.interpreter
    interp.set_tensor(self.input0, self.initial_data)
    snapshot = interp.get_tensor(self.input0)
    bumped = np.add(1., snapshot)
    interp.set_tensor(self.input0, bumped)
    # The earlier snapshot must be unaffected by the later set_tensor().
    self.assertAllEqual(snapshot, self.initial_data)
    self.assertAllEqual(interp.get_tensor(self.input0), bumped)

  def testBase(self):
    interpreter = self.interpreter
    self.assertTrue(interpreter._safe_to_run())
    # Holding only the accessor function does not make running unsafe.
    _ = interpreter.tensor(self.input0)
    self.assertTrue(interpreter._safe_to_run())
    # Each materialized numpy view does, until it is released.
    first_view = interpreter.tensor(self.input0)()
    self.assertFalse(interpreter._safe_to_run())
    second_view = interpreter.tensor(self.input0)()
    self.assertFalse(interpreter._safe_to_run())
    # Now get rid of the buffers so that we can evaluate.
    del first_view, second_view
    self.assertTrue(interpreter._safe_to_run())

  def testBaseProtectsFunctions(self):
    view = self.interpreter.tensor(self.input0)()
    # While a view is alive, each potentially unsafe operation must raise.
    unsafe_ops = (self.interpreter.allocate_tensors, self.interpreter.invoke)
    for unsafe_op in unsafe_ops:
      with self.assertRaisesRegex(RuntimeError,
                                  'There is at least 1 reference'):
        unsafe_op()
    # `view` was the only buffer reference, so releasing it makes it safe.
    del view
    accessor = self.interpreter.tensor(self.input0)
    self.interpreter.allocate_tensors()
    del accessor  # Keeps the accessor held without an unused-variable lint.
378
379
class InterpreterDelegateTest(test_util.TensorFlowTestCase):
  """Tests delegate loading, sharing, options and lifetime via test_delegate."""

  def setUp(self):
    self._delegate_file = resource_loader.get_path_to_datafile(
        'testdata/test_delegate.so')
    self._model_file = resource_loader.get_path_to_datafile(
        'testdata/permute_float.tflite')

    # Load the library to reset the counters.
    library = ctypes.pydll.LoadLibrary(self._delegate_file)
    library.initialize_counters()

  def _TestInterpreter(self, model_path, options=None):
    """Test wrapper function that creates an interpreter with the delegate."""
    delegate = interpreter_wrapper.load_delegate(self._delegate_file, options)
    return interpreter_wrapper.Interpreter(
        model_path=model_path, experimental_delegates=[delegate])

  def testDelegate(self):
    """Tests the delegate creation and destruction."""
    interpreter = self._TestInterpreter(model_path=self._model_file)
    lib = interpreter._delegates[0]._library

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 1)

    del interpreter

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 1)
    self.assertEqual(lib.get_num_delegates_invoked(), 1)

  def testMultipleInterpreters(self):
    delegate = interpreter_wrapper.load_delegate(self._delegate_file)
    lib = delegate._library

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 0)

    interpreter_a = interpreter_wrapper.Interpreter(
        model_path=self._model_file, experimental_delegates=[delegate])

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 1)

    interpreter_b = interpreter_wrapper.Interpreter(
        model_path=self._model_file, experimental_delegates=[delegate])

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 2)

    # interpreter_b still references the delegate, so it must survive.
    del delegate
    del interpreter_a

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 2)

    del interpreter_b

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 1)
    self.assertEqual(lib.get_num_delegates_invoked(), 2)

  def testDestructionOrder(self):
    """Make sure internal _interpreter object is destroyed before delegate."""
    self.skipTest('TODO(b/142136355): fix flakiness and re-enable')
    # Track the order in which destructions happen.
    destructions = []

    def register_destruction(x):
      destructions.append(
          x if isinstance(x, str) else six.ensure_text(x, 'utf-8'))
      return 0

    delegate = interpreter_wrapper.load_delegate(self._delegate_file)
    # Make an interpreter with the delegate.
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'),
        experimental_delegates=[delegate])

    class InterpreterDestroyCallback(object):

      def __del__(self):
        register_destruction('interpreter')

    interpreter._interpreter.stuff = InterpreterDestroyCallback()
    library = delegate._library
    # Wrap the Python callback so it can be handed to C through ctypes.
    # The wrapper must stay referenced for as long as C may call it. A
    # temporary could be garbage collected as soon as set_destroy_callback()
    # returns. The library would then keep a dangling function pointer,
    # presumably invoked when the delegate is destroyed below.
    prototype = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p)
    destroy_callback = prototype(register_destruction)
    library.set_destroy_callback(destroy_callback)
    # Destroy both delegate and interpreter.
    del delegate
    del interpreter
    library.set_destroy_callback(None)
    # Only release the wrapper once C no longer holds the pointer.
    del destroy_callback
    # Check the interpreter was destroyed before the delegate.
    self.assertEqual(destructions, ['interpreter', 'test_delegate'])

  def testOptions(self):
    delegate_a = interpreter_wrapper.load_delegate(self._delegate_file)
    lib = delegate_a._library

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 0)
    self.assertEqual(lib.get_options_counter(), 0)

    delegate_b = interpreter_wrapper.load_delegate(
        self._delegate_file, options={
            'unused': False,
            'options_counter': 2
        })
    lib = delegate_b._library

    self.assertEqual(lib.get_num_delegates_created(), 2)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 0)
    self.assertEqual(lib.get_options_counter(), 2)

    del delegate_a
    del delegate_b

    self.assertEqual(lib.get_num_delegates_created(), 2)
    self.assertEqual(lib.get_num_delegates_destroyed(), 2)
    self.assertEqual(lib.get_num_delegates_invoked(), 0)
    self.assertEqual(lib.get_options_counter(), 2)

  def testFail(self):
    # Due to exception chaining in PY3, we can't be more specific here and
    # check that the phrase 'Fail argument sent' is present.
    with self.assertRaisesRegex(ValueError, r'Failed to load delegate from'):
      interpreter_wrapper.load_delegate(
          self._delegate_file, options={'fail': 'fail'})
520
521
if __name__ == '__main__':
  # Hand off to TensorFlow's test entry point (tensorflow.python.platform.test)
  # when this file is executed directly.
  test.main()
524