1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow Lite Python Interface: Sanity check."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import io
21import numpy as np
22import six
23
24from tensorflow.lite.python import interpreter as interpreter_wrapper
25from tensorflow.python.framework import test_util
26from tensorflow.python.platform import resource_loader
27from tensorflow.python.platform import test
28
29
class InterpreterTest(test_util.TensorFlowTestCase):
  """End-to-end checks of the TFLite Python interpreter on small test models.

  Each test loads a model from testdata, verifies the tensor details the
  interpreter reports for its inputs and outputs, runs inference once and
  compares the produced tensor with the expected one.

  Shapes and tensor contents are compared with assertAllEqual instead of
  assertTrue((a == b).all()). The elementwise `==` broadcasts its operands,
  so a reported shape of [] or [3, 3] would wrongly satisfy a check against
  [3]; assertAllEqual also requires the shapes of both operands to match and
  gives a descriptive message on failure.
  """

  def testFloat(self):
    """Runs the float permute model, loaded from a file path."""
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 4], input_details[0]['shape'])
    self.assertEqual((0.0, 0), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 4], output_details[0]['shape'])
    self.assertEqual((0.0, 0), output_details[0]['quantization'])

    # The expected output is the input with its four elements reversed.
    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
    expected_output = np.array([[4.0, 3.0, 2.0, 1.0]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertAllEqual(expected_output, output_data)

  def testUint8(self):
    """Runs the uint8 permute model, loaded from in-memory model content."""
    model_path = resource_loader.get_path_to_datafile(
        'testdata/permute_uint8.tflite')
    with io.open(model_path, 'rb') as model_file:
      data = model_file.read()

    interpreter = interpreter_wrapper.Interpreter(model_content=data)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertAllEqual([1, 4], input_details[0]['shape'])
    self.assertEqual((1.0, 0), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.uint8, output_details[0]['dtype'])
    self.assertAllEqual([1, 4], output_details[0]['shape'])
    self.assertEqual((1.0, 0), output_details[0]['quantization'])

    test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
    expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
    # Exercise the resize path as well: resize the input to the shape of the
    # test data and allocate again before feeding it.
    interpreter.resize_tensor_input(input_details[0]['index'],
                                    test_input.shape)
    interpreter.allocate_tensors()
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertAllEqual(expected_output, output_data)

  def testString(self):
    """Runs the string gather model, which has a string and an index input."""
    interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/gather_string.tflite'))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self.assertEqual('input', input_details[0]['name'])
    self.assertEqual(np.string_, input_details[0]['dtype'])
    self.assertAllEqual([10], input_details[0]['shape'])
    self.assertEqual((0.0, 0), input_details[0]['quantization'])
    self.assertEqual('indices', input_details[1]['name'])
    self.assertEqual(np.int64, input_details[1]['dtype'])
    self.assertAllEqual([3], input_details[1]['shape'])
    self.assertEqual((0.0, 0), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.string_, output_details[0]['dtype'])
    self.assertAllEqual([3], output_details[0]['shape'])
    self.assertEqual((0.0, 0), output_details[0]['quantization'])

    # Feed the indices first, then the strings to gather from.
    test_input = np.array([1, 2, 3], dtype=np.int64)
    interpreter.set_tensor(input_details[1]['index'], test_input)

    test_input = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'])
    # The gathered strings are expected back as bytes.
    expected_output = np.array([b'b', b'c', b'd'])
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertAllEqual(expected_output, output_data)
128
129
class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
  """Checks that interpreter failures surface as Python exceptions."""

  def testInvalidModelContent(self):
    """Building an interpreter from garbage bytes raises ValueError."""
    bogus_content = six.b('garbage')
    expected_error = 'Model provided has model identifier \''
    with self.assertRaisesRegexp(ValueError, expected_error):
      interpreter_wrapper.Interpreter(model_content=bogus_content)

  def testInvalidModelFile(self):
    """Building an interpreter from an unopenable path raises ValueError."""
    missing_path = 'totally_invalid_file_name'
    expected_error = 'Could not open \'totally_invalid_file_name\''
    with self.assertRaisesRegexp(ValueError, expected_error):
      interpreter_wrapper.Interpreter(model_path=missing_path)

  def testInvokeBeforeReady(self):
    """Calling invoke() before allocate_tensors() raises RuntimeError."""
    model_file = resource_loader.get_path_to_datafile(
        'testdata/permute_float.tflite')
    unallocated = interpreter_wrapper.Interpreter(model_path=model_file)
    expected_error = 'Invoke called on model that is not ready'
    with self.assertRaisesRegexp(RuntimeError, expected_error):
      unallocated.invoke()
150
151
class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase):
  """Tests for tensor(), get_tensor() and the safety check guarding them.

  The tests here depend on exactly when arrays returned by calling a
  tensor() accessor are alive, so statement order and the explicit `del`
  statements matter.
  """

  def setUp(self):
    """Loads the float permute model and allocates its tensors."""
    self.interpreter = interpreter_wrapper.Interpreter(
        model_path=resource_loader.get_path_to_datafile(
            'testdata/permute_float.tflite'))
    self.interpreter.allocate_tensors()
    # Index of the model's first input tensor, used by every test below.
    self.input0 = self.interpreter.get_input_details()[0]['index']
    # Reference values written to / compared against the input tensor.
    self.initial_data = np.array([[-1., -2., -3., -4.]], np.float32)

  def testTensorAccessor(self):
    """Check that tensor returns a reference.

    Data written into the array produced by the accessor must be visible
    both through the accessor and through get_tensor().
    """
    # tensor() returns a callable; calling it yields the array itself.
    array_ref = self.interpreter.tensor(self.input0)
    np.copyto(array_ref(), self.initial_data)
    self.assertAllEqual(array_ref(), self.initial_data)
    self.assertAllEqual(
        self.interpreter.get_tensor(self.input0), self.initial_data)

  def testGetTensorAccessor(self):
    """Check that get_tensor returns a copy.

    An array fetched with get_tensor() must keep its values after the
    tensor is overwritten with set_tensor().
    """
    self.interpreter.set_tensor(self.input0, self.initial_data)
    array_initial_copy = self.interpreter.get_tensor(self.input0)
    new_value = np.add(1., array_initial_copy)
    self.interpreter.set_tensor(self.input0, new_value)
    # The earlier fetch is unchanged; a fresh fetch sees the new value.
    self.assertAllEqual(array_initial_copy, self.initial_data)
    self.assertAllEqual(self.interpreter.get_tensor(self.input0), new_value)

  def testBase(self):
    """Check that _safe_to_run() is False exactly while a buffer is held."""
    self.assertTrue(self.interpreter._safe_to_run())
    # Merely obtaining the accessor (without calling it) keeps things safe.
    _ = self.interpreter.tensor(self.input0)
    self.assertTrue(self.interpreter._safe_to_run())
    # Calling the accessor and keeping the result makes running unsafe.
    in0 = self.interpreter.tensor(self.input0)()
    self.assertFalse(self.interpreter._safe_to_run())
    in0b = self.interpreter.tensor(self.input0)()
    self.assertFalse(self.interpreter._safe_to_run())
    # Now get rid of the buffers so that we can evaluate.
    del in0
    del in0b
    self.assertTrue(self.interpreter._safe_to_run())

  def testBaseProtectsFunctions(self):
    """Check that allocate_tensors() and invoke() refuse to run when unsafe."""
    # Hold a buffer obtained by calling the accessor.
    in0 = self.interpreter.tensor(self.input0)()
    # While the buffer is held, allocate_tensors() must raise.
    with self.assertRaisesRegexp(
        RuntimeError, 'There is at least 1 reference'):
      _ = self.interpreter.allocate_tensors()
    # While the buffer is held, invoke() must raise as well.
    with self.assertRaisesRegexp(
        RuntimeError, 'There is at least 1 reference'):
      _ = self.interpreter.invoke()
    # Now test that we can run
    del in0  # this is our only buffer reference, so now it is safe to change
    # Holding only the (uncalled) accessor must not block allocate_tensors().
    in0safe = self.interpreter.tensor(self.input0)
    _ = self.interpreter.allocate_tensors()
    # in0safe stayed alive across allocate_tensors() above; deleting it here
    # also keeps lint from reporting it as an unused variable.
    del in0safe
207
# Run the test cases defined in this module when executed as a script.
if __name__ == '__main__':
  test.main()
210