# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite Python Interface: Sanity check."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ctypes
import io
import sys

import numpy as np
import six

# Force loaded shared object symbols to be globally visible. This is needed so
# that the interpreter_wrapper, in one .so file, can see the test_registerer,
# in a different .so file. Note that this may already be set by default.
# pylint: disable=g-import-not-at-top
if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'):
  sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)

from tensorflow.lite.python import interpreter as interpreter_wrapper
from tensorflow.lite.python.testdata import _pywrap_test_registerer as test_registerer
from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top


class InterpreterCustomOpsTest(test_util.TensorFlowTestCase):

  def testRegistererByName(self):
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'),
        custom_op_registerers=['TF_TestRegisterer'])
    self.assertTrue(interpreter._safe_to_run())
    self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1)

  def testRegistererByFunc(self):
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'),
        custom_op_registerers=[test_registerer.TF_TestRegisterer])
    self.assertTrue(interpreter._safe_to_run())
    self.assertEqual(test_registerer.get_num_test_registerer_calls(), 1)

  def testRegistererFailure(self):
    bogus_name = 'CompletelyBogusRegistererName'
    with self.assertRaisesRegex(
        ValueError, 'Looking up symbol \'' + bogus_name + '\' failed'):
      interpreter_wrapper.InterpreterWithCustomOps(
          model_path=resource_loader.get_path_to_datafile(
              'testdata/permute_float.tflite'),
          custom_op_registerers=[bogus_name])

  def testNoCustomOps(self):
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    self.assertTrue(interpreter._safe_to_run())


class InterpreterTest(test_util.TensorFlowTestCase):

  def assertQuantizationParamsEqual(self, scales, zero_points,
                                    quantized_dimension, params):
    self.assertAllEqual(scales, params['scales'])
    self.assertAllEqual(zero_points, params['zero_points'])
    self.assertEqual(quantized_dimension, params['quantized_dimension'])

  def testThreads_NegativeValue(self):
    with self.assertRaisesRegex(ValueError, 'num_threads should >= 1'):
      interpreter_wrapper.Interpreter(
          model_path=resource_loader.get_path_to_datafile(
              'testdata/permute_float.tflite'),
          num_threads=-1)

  def testThreads_WrongType(self):
    with self.assertRaisesRegex(ValueError,
                                'type of num_threads should be int'):
      interpreter_wrapper.Interpreter(
          model_path=resource_loader.get_path_to_datafile(
              'testdata/permute_float.tflite'),
          num_threads=4.2)

  def testFloat(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 4] == input_details[0]['shape']).all())
    self.assertEqual((0.0, 0), input_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, input_details[0]['quantization_parameters'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 4] == output_details[0]['shape']).all())
    self.assertEqual((0.0, 0), output_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, output_details[0]['quantization_parameters'])

    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
    expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertTrue((expected_output == output_data).all())

  def testFloatWithTwoThreads(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'),
        num_threads=2)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
    expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertTrue((expected_output == output_data).all())

  def testUint8(self):
    model_path = resource_loader.get_path_to_datafile(
        'testdata/permute_uint8.tflite')
    with io.open(model_path, 'rb') as model_file:
      data = model_file.read()

    interpreter = interpreter_wrapper.Interpreter(model_content=data)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertTrue(([1, 4] == input_details[0]['shape']).all())
    self.assertEqual((1.0, 0), input_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [1.0], [0], 0, input_details[0]['quantization_parameters'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.uint8, output_details[0]['dtype'])
    self.assertTrue(([1, 4] == output_details[0]['shape']).all())
    self.assertEqual((1.0, 0), output_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [1.0], [0], 0, output_details[0]['quantization_parameters'])

    test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
    expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
    interpreter.resize_tensor_input(input_details[0]['index'],
                                    test_input.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertTrue((expected_output == output_data).all())

  def testString(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/gather_string.tflite'))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.string_, input_details[0]['dtype'])
    self.assertTrue(([10] == input_details[0]['shape']).all())
    self.assertEqual((0.0, 0), input_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, input_details[0]['quantization_parameters'])
    self.assertEqual('indices', input_details[1]['name'])
    self.assertEqual(np.int64, input_details[1]['dtype'])
    self.assertTrue(([3] == input_details[1]['shape']).all())
    self.assertEqual((0.0, 0), input_details[1]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, input_details[1]['quantization_parameters'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.string_, output_details[0]['dtype'])
    self.assertTrue(([3] == output_details[0]['shape']).all())
    self.assertEqual((0.0, 0), output_details[0]['quantization'])
    self.assertQuantizationParamsEqual(
        [], [], 0, output_details[0]['quantization_parameters'])

    test_input = np.array([1, 2, 3], dtype=np.int64)
    interpreter.set_tensor(input_details[1]['index'], test_input)

    test_input = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'])
    expected_output = np.array([b'b', b'c', b'd'])
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertTrue((expected_output == output_data).all())

  def testStringZeroDim(self):
    data = b'abcd' + bytes(16)
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/gather_string_0d.tflite'))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    interpreter.set_tensor(input_details[0]['index'], np.array(data))
    test_input_tensor = interpreter.get_tensor(input_details[0]['index'])
    self.assertEqual(len(data), len(test_input_tensor.item(0)))

  def testPerChannelParams(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin'))
    interpreter.allocate_tensors()

    # Tensor index 1 is the weight.
    weight_details = interpreter.get_tensor_details()[1]
    qparams = weight_details['quantization_parameters']
    # Ensure that we retrieve per channel quantization params correctly.
    self.assertEqual(len(qparams['scales']), 128)

  def testDenseTensorAccess(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile('testdata/pc_conv.bin'))
    interpreter.allocate_tensors()
    weight_details = interpreter.get_tensor_details()[1]
    s_params = weight_details['sparsity_parameters']
    self.assertEqual(s_params, {})

  def testSparseTensorAccess(self):
    interpreter = interpreter_wrapper.InterpreterWithCustomOps(
        model_path=resource_loader.get_path_to_datafile(
            '../testdata/sparse_tensor.bin'),
        custom_op_registerers=['TF_TestRegisterer'])
    interpreter.allocate_tensors()

    # Tensor at index 0 is sparse.
    compressed_buffer = interpreter.get_tensor(0)
    # Ensure that the buffer is of the correct size and value.
    self.assertEqual(len(compressed_buffer), 12)
    sparse_value = [1, 0, 0, 4, 2, 3, 0, 0, 5, 0, 0, 6]
    self.assertAllEqual(compressed_buffer, sparse_value)

    tensor_details = interpreter.get_tensor_details()[0]
    s_params = tensor_details['sparsity_parameters']

    # Ensure that the sparsity parameters returned are correct.
    self.assertAllEqual(s_params['traversal_order'], [0, 1, 2, 3])
    self.assertAllEqual(s_params['block_map'], [0, 1])
    dense_dim_metadata = {'format': 0, 'dense_size': 2}
    self.assertAllEqual(s_params['dim_metadata'][0], dense_dim_metadata)
    self.assertAllEqual(s_params['dim_metadata'][2], dense_dim_metadata)
    self.assertAllEqual(s_params['dim_metadata'][3], dense_dim_metadata)
    self.assertEqual(s_params['dim_metadata'][1]['format'], 1)
    self.assertAllEqual(s_params['dim_metadata'][1]['array_segments'],
                        [0, 2, 3])
    self.assertAllEqual(s_params['dim_metadata'][1]['array_indices'], [0, 1, 1])


class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):

  def testInvalidModelContent(self):
    with self.assertRaisesRegex(ValueError,
                                'Model provided has model identifier \''):
      interpreter_wrapper.Interpreter(model_content=six.b('garbage'))

  def testInvalidModelFile(self):
    with self.assertRaisesRegex(ValueError,
                                'Could not open \'totally_invalid_file_name\''):
      interpreter_wrapper.Interpreter(model_path='totally_invalid_file_name')

  def testInvokeBeforeReady(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    with self.assertRaisesRegex(RuntimeError,
                                'Invoke called on model that is not ready'):
      interpreter.invoke()

  def testInvalidModelFileContent(self):
    with self.assertRaisesRegex(
        ValueError, '`model_path` or `model_content` must be specified.'):
      interpreter_wrapper.Interpreter(model_path=None, model_content=None)

  def testInvalidIndex(self):
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    interpreter.allocate_tensors()
    # Invalid tensor index passed.
    with self.assertRaisesRegex(ValueError, 'Tensor with no shape found.'):
      interpreter._get_tensor_details(4)
    with self.assertRaisesRegex(ValueError, 'Invalid node index'):
      interpreter._get_op_details(4)


class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase):

  def setUp(self):
    self.interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    self.interpreter.allocate_tensors()
    self.input0 = self.interpreter.get_input_details()[0]['index']
    self.initial_data = np.array([[-1., -2., -3., -4.]], np.float32)

  def testTensorAccessor(self):
    """Check that tensor returns a reference."""
    array_ref = self.interpreter.tensor(self.input0)
    np.copyto(array_ref(), self.initial_data)
    self.assertAllEqual(array_ref(), self.initial_data)
    self.assertAllEqual(
        self.interpreter.get_tensor(self.input0), self.initial_data)

  def testGetTensorAccessor(self):
    """Check that get_tensor returns a copy."""
    self.interpreter.set_tensor(self.input0, self.initial_data)
    array_initial_copy = self.interpreter.get_tensor(self.input0)
    new_value = np.add(1., array_initial_copy)
    self.interpreter.set_tensor(self.input0, new_value)
    self.assertAllEqual(array_initial_copy, self.initial_data)
    self.assertAllEqual(self.interpreter.get_tensor(self.input0), new_value)

  def testBase(self):
    self.assertTrue(self.interpreter._safe_to_run())
    _ = self.interpreter.tensor(self.input0)
    self.assertTrue(self.interpreter._safe_to_run())
    in0 = self.interpreter.tensor(self.input0)()
    self.assertFalse(self.interpreter._safe_to_run())
    in0b = self.interpreter.tensor(self.input0)()
    self.assertFalse(self.interpreter._safe_to_run())
    # Now get rid of the buffers so that we can evaluate.
    del in0
    del in0b
    self.assertTrue(self.interpreter._safe_to_run())

  def testBaseProtectsFunctions(self):
    in0 = self.interpreter.tensor(self.input0)()
    # Make sure we get an exception if we try to run an unsafe operation.
    with self.assertRaisesRegex(RuntimeError, 'There is at least 1 reference'):
      _ = self.interpreter.allocate_tensors()
    # Make sure we get an exception if we try to run an unsafe operation.
    with self.assertRaisesRegex(RuntimeError, 'There is at least 1 reference'):
      _ = self.interpreter.invoke()
    # Now test that we can run.
    del in0  # This is our only buffer reference, so now it is safe to change.
    in0safe = self.interpreter.tensor(self.input0)
    _ = self.interpreter.allocate_tensors()
    del in0safe  # Make sure in0safe is held but lint doesn't complain.


class InterpreterDelegateTest(test_util.TensorFlowTestCase):

  def setUp(self):
    self._delegate_file = resource_loader.get_path_to_datafile(
        'testdata/test_delegate.so')
    self._model_file = resource_loader.get_path_to_datafile(
        'testdata/permute_float.tflite')

    # Load the library to reset the counters.
    library = ctypes.pydll.LoadLibrary(self._delegate_file)
    library.initialize_counters()

  def _TestInterpreter(self, model_path, options=None):
    """Test wrapper function that creates an interpreter with the delegate."""
    delegate = interpreter_wrapper.load_delegate(self._delegate_file, options)
    return interpreter_wrapper.Interpreter(
        model_path=model_path, experimental_delegates=[delegate])

  def testDelegate(self):
    """Tests the delegate creation and destruction."""
    interpreter = self._TestInterpreter(model_path=self._model_file)
    lib = interpreter._delegates[0]._library

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 1)

    del interpreter

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 1)
    self.assertEqual(lib.get_num_delegates_invoked(), 1)

  def testMultipleInterpreters(self):
    delegate = interpreter_wrapper.load_delegate(self._delegate_file)
    lib = delegate._library

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 0)

    interpreter_a = interpreter_wrapper.Interpreter(
        model_path=self._model_file, experimental_delegates=[delegate])

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 1)

    interpreter_b = interpreter_wrapper.Interpreter(
        model_path=self._model_file, experimental_delegates=[delegate])

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 2)

    del delegate
    del interpreter_a

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 2)

    del interpreter_b

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 1)
    self.assertEqual(lib.get_num_delegates_invoked(), 2)

  def testDestructionOrder(self):
    """Make sure internal _interpreter object is destroyed before delegate."""
    self.skipTest('TODO(b/142136355): fix flakiness and re-enable')
    # Track the order in which destructions were done.
    destructions = []

    def register_destruction(x):
      destructions.append(
          x if isinstance(x, str) else six.ensure_text(x, 'utf-8'))
      return 0

    # Make a wrapper for the callback so we can send this to ctypes.
    delegate = interpreter_wrapper.load_delegate(self._delegate_file)
    # Make an interpreter with the delegate.
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'),
        experimental_delegates=[delegate])

    class InterpreterDestroyCallback(object):

      def __del__(self):
        register_destruction('interpreter')

    interpreter._interpreter.stuff = InterpreterDestroyCallback()
    # Destroy both delegate and interpreter.
    library = delegate._library
    prototype = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_char_p)
    library.set_destroy_callback(prototype(register_destruction))
    del delegate
    del interpreter
    library.set_destroy_callback(None)
    # Check that the interpreter was destroyed before the delegate.
    self.assertEqual(destructions, ['interpreter', 'test_delegate'])

  def testOptions(self):
    delegate_a = interpreter_wrapper.load_delegate(self._delegate_file)
    lib = delegate_a._library

    self.assertEqual(lib.get_num_delegates_created(), 1)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 0)
    self.assertEqual(lib.get_options_counter(), 0)

    delegate_b = interpreter_wrapper.load_delegate(
        self._delegate_file, options={
            'unused': False,
            'options_counter': 2
        })
    lib = delegate_b._library

    self.assertEqual(lib.get_num_delegates_created(), 2)
    self.assertEqual(lib.get_num_delegates_destroyed(), 0)
    self.assertEqual(lib.get_num_delegates_invoked(), 0)
    self.assertEqual(lib.get_options_counter(), 2)

    del delegate_a
    del delegate_b

    self.assertEqual(lib.get_num_delegates_created(), 2)
    self.assertEqual(lib.get_num_delegates_destroyed(), 2)
    self.assertEqual(lib.get_num_delegates_invoked(), 0)
    self.assertEqual(lib.get_options_counter(), 2)

  def testFail(self):
    with self.assertRaisesRegex(
        # Due to exception chaining in PY3, we can't be more specific here and
        # check that the phrase 'Fail argument sent' is present.
        ValueError,
        r'Failed to load delegate from'):
      interpreter_wrapper.load_delegate(
          self._delegate_file, options={'fail': 'fail'})


if __name__ == '__main__':
  test.main()