1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TensorFlow Lite Python Interface: Sanity check.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import io 21import numpy as np 22import six 23 24from tensorflow.lite.python import interpreter as interpreter_wrapper 25from tensorflow.python.framework import test_util 26from tensorflow.python.platform import resource_loader 27from tensorflow.python.platform import test 28 29 30class InterpreterTest(test_util.TensorFlowTestCase): 31 32 def testFloat(self): 33 interpreter = interpreter_wrapper.Interpreter( 34 model_path=resource_loader.get_path_to_datafile( 35 'testdata/permute_float.tflite')) 36 interpreter.allocate_tensors() 37 38 input_details = interpreter.get_input_details() 39 self.assertEqual(1, len(input_details)) 40 self.assertEqual('input', input_details[0]['name']) 41 self.assertEqual(np.float32, input_details[0]['dtype']) 42 self.assertTrue(([1, 4] == input_details[0]['shape']).all()) 43 self.assertEqual((0.0, 0), input_details[0]['quantization']) 44 45 output_details = interpreter.get_output_details() 46 self.assertEqual(1, len(output_details)) 47 self.assertEqual('output', output_details[0]['name']) 48 self.assertEqual(np.float32, output_details[0]['dtype']) 49 self.assertTrue(([1, 4] == output_details[0]['shape']).all()) 50 self.assertEqual((0.0, 0), output_details[0]['quantization']) 51 52 test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32) 53 expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32) 54 interpreter.set_tensor(input_details[0]['index'], test_input) 55 interpreter.invoke() 56 57 output_data = interpreter.get_tensor(output_details[0]['index']) 58 self.assertTrue((expected_output == output_data).all()) 59 60 def testUint8(self): 61 model_path = resource_loader.get_path_to_datafile( 62 'testdata/permute_uint8.tflite') 63 with io.open(model_path, 'rb') as model_file: 64 data = model_file.read() 65 66 interpreter = interpreter_wrapper.Interpreter(model_content=data) 67 interpreter.allocate_tensors() 68 69 input_details = interpreter.get_input_details() 70 self.assertEqual(1, len(input_details)) 71 self.assertEqual('input', input_details[0]['name']) 72 self.assertEqual(np.uint8, input_details[0]['dtype']) 73 self.assertTrue(([1, 4] == input_details[0]['shape']).all()) 74 self.assertEqual((1.0, 0), input_details[0]['quantization']) 75 76 output_details = interpreter.get_output_details() 77 self.assertEqual(1, len(output_details)) 78 self.assertEqual('output', output_details[0]['name']) 79 self.assertEqual(np.uint8, output_details[0]['dtype']) 80 self.assertTrue(([1, 4] == output_details[0]['shape']).all()) 81 self.assertEqual((1.0, 0), output_details[0]['quantization']) 82 83 test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8) 84 expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8) 85 interpreter.resize_tensor_input(input_details[0]['index'], 86 test_input.shape) 87 interpreter.allocate_tensors() 88 interpreter.set_tensor(input_details[0]['index'], test_input) 89 interpreter.invoke() 90 91 output_data = interpreter.get_tensor(output_details[0]['index']) 92 self.assertTrue((expected_output == output_data).all()) 93 94 def testString(self): 95 interpreter = interpreter_wrapper.Interpreter( 96 model_path=resource_loader.get_path_to_datafile( 97 'testdata/gather_string.tflite')) 98 interpreter.allocate_tensors() 99 100 input_details = interpreter.get_input_details() 101 self.assertEqual(2, len(input_details)) 102 self.assertEqual('input', input_details[0]['name']) 103 self.assertEqual(np.string_, input_details[0]['dtype']) 104 self.assertTrue(([10] == input_details[0]['shape']).all()) 105 self.assertEqual((0.0, 0), input_details[0]['quantization']) 106 self.assertEqual('indices', input_details[1]['name']) 107 self.assertEqual(np.int64, input_details[1]['dtype']) 108 self.assertTrue(([3] == input_details[1]['shape']).all()) 109 self.assertEqual((0.0, 0), input_details[1]['quantization']) 110 111 output_details = interpreter.get_output_details() 112 self.assertEqual(1, len(output_details)) 113 self.assertEqual('output', output_details[0]['name']) 114 self.assertEqual(np.string_, output_details[0]['dtype']) 115 self.assertTrue(([3] == output_details[0]['shape']).all()) 116 self.assertEqual((0.0, 0), output_details[0]['quantization']) 117 118 test_input = np.array([1, 2, 3], dtype=np.int64) 119 interpreter.set_tensor(input_details[1]['index'], test_input) 120 121 test_input = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']) 122 expected_output = np.array([b'b', b'c', b'd']) 123 interpreter.set_tensor(input_details[0]['index'], test_input) 124 interpreter.invoke() 125 126 output_data = interpreter.get_tensor(output_details[0]['index']) 127 self.assertTrue((expected_output == output_data).all()) 128 129 130class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase): 131 132 def testInvalidModelContent(self): 133 with self.assertRaisesRegexp(ValueError, 134 'Model provided has model identifier \''): 135 interpreter_wrapper.Interpreter(model_content=six.b('garbage')) 136 137 def testInvalidModelFile(self): 138 with self.assertRaisesRegexp( 139 ValueError, 'Could not open \'totally_invalid_file_name\''): 140 interpreter_wrapper.Interpreter( 141 model_path='totally_invalid_file_name') 142 143 def testInvokeBeforeReady(self): 144 interpreter = interpreter_wrapper.Interpreter( 145 model_path=resource_loader.get_path_to_datafile( 146 'testdata/permute_float.tflite')) 147 with self.assertRaisesRegexp(RuntimeError, 148 'Invoke called on model that is not ready'): 149 interpreter.invoke() 150 151 152class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase): 153 154 def setUp(self): 155 self.interpreter = interpreter_wrapper.Interpreter( 156 model_path=resource_loader.get_path_to_datafile( 157 'testdata/permute_float.tflite')) 158 self.interpreter.allocate_tensors() 159 self.input0 = self.interpreter.get_input_details()[0]['index'] 160 self.initial_data = np.array([[-1., -2., -3., -4.]], np.float32) 161 162 def testTensorAccessor(self): 163 """Check that tensor returns a reference.""" 164 array_ref = self.interpreter.tensor(self.input0) 165 np.copyto(array_ref(), self.initial_data) 166 self.assertAllEqual(array_ref(), self.initial_data) 167 self.assertAllEqual( 168 self.interpreter.get_tensor(self.input0), self.initial_data) 169 170 def testGetTensorAccessor(self): 171 """Check that get_tensor returns a copy.""" 172 self.interpreter.set_tensor(self.input0, self.initial_data) 173 array_initial_copy = self.interpreter.get_tensor(self.input0) 174 new_value = np.add(1., array_initial_copy) 175 self.interpreter.set_tensor(self.input0, new_value) 176 self.assertAllEqual(array_initial_copy, self.initial_data) 177 self.assertAllEqual(self.interpreter.get_tensor(self.input0), new_value) 178 179 def testBase(self): 180 self.assertTrue(self.interpreter._safe_to_run()) 181 _ = self.interpreter.tensor(self.input0) 182 self.assertTrue(self.interpreter._safe_to_run()) 183 in0 = self.interpreter.tensor(self.input0)() 184 self.assertFalse(self.interpreter._safe_to_run()) 185 in0b = self.interpreter.tensor(self.input0)() 186 self.assertFalse(self.interpreter._safe_to_run()) 187 # Now get rid of the buffers so that we can evaluate. 188 del in0 189 del in0b 190 self.assertTrue(self.interpreter._safe_to_run()) 191 192 def testBaseProtectsFunctions(self): 193 in0 = self.interpreter.tensor(self.input0)() 194 # Make sure we get an exception if we try to run an unsafe operation 195 with self.assertRaisesRegexp( 196 RuntimeError, 'There is at least 1 reference'): 197 _ = self.interpreter.allocate_tensors() 198 # Make sure we get an exception if we try to run an unsafe operation 199 with self.assertRaisesRegexp( 200 RuntimeError, 'There is at least 1 reference'): 201 _ = self.interpreter.invoke() 202 # Now test that we can run 203 del in0 # this is our only buffer reference, so now it is safe to change 204 in0safe = self.interpreter.tensor(self.input0) 205 _ = self.interpreter.allocate_tensors() 206 del in0safe # make sure in0Safe is held but lint doesn't complain 207 208if __name__ == '__main__': 209 test.main() 210