1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for lite.py functionality related to TensorFlow 2.0.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23from tensorflow.lite.python import lite 24from tensorflow.lite.python.interpreter import Interpreter 25from tensorflow.python import keras 26from tensorflow.python.eager import def_function 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import tensor_spec 30from tensorflow.python.framework import test_util 31from tensorflow.python.ops import variables 32from tensorflow.python.platform import test 33from tensorflow.python.saved_model.load import load 34from tensorflow.python.saved_model.save import save 35from tensorflow.python.training.tracking import tracking 36 37 38class FromConcreteFunctionTest(test_util.TensorFlowTestCase): 39 40 def _evaluateTFLiteModel(self, tflite_model, input_data): 41 """Evaluates the model on the `input_data`.""" 42 interpreter = Interpreter(model_content=tflite_model) 43 interpreter.allocate_tensors() 44 45 input_details = interpreter.get_input_details() 46 output_details = interpreter.get_output_details() 47 48 for input_tensor, tensor_data in zip(input_details, input_data): 49 interpreter.set_tensor(input_tensor['index'], tensor_data.numpy()) 50 interpreter.invoke() 51 return interpreter.get_tensor(output_details[0]['index']) 52 53 @test_util.run_v2_only 54 def testTypeInvalid(self): 55 root = tracking.AutoTrackable() 56 root.v1 = variables.Variable(3.) 57 root.v2 = variables.Variable(2.) 58 root.f = def_function.function(lambda x: root.v1 * root.v2 * x) 59 60 with self.assertRaises(ValueError) as error: 61 _ = lite.TFLiteConverterV2.from_concrete_function(root.f) 62 self.assertIn('call from_concrete_function', str(error.exception)) 63 64 @test_util.run_v2_only 65 def testFloat(self): 66 input_data = constant_op.constant(1., shape=[1]) 67 root = tracking.AutoTrackable() 68 root.v1 = variables.Variable(3.) 69 root.v2 = variables.Variable(2.) 70 root.f = def_function.function(lambda x: root.v1 * root.v2 * x) 71 concrete_func = root.f.get_concrete_function(input_data) 72 73 # Convert model. 74 converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func) 75 tflite_model = converter.convert() 76 77 # Check values from converted model. 78 expected_value = root.f(input_data) 79 actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) 80 self.assertEqual(expected_value.numpy(), actual_value) 81 82 @test_util.run_v2_only 83 def testSizeNone(self): 84 # Test with a shape of None 85 input_data = constant_op.constant(1., shape=None) 86 root = tracking.AutoTrackable() 87 root.v1 = variables.Variable(3.) 88 root.f = def_function.function(lambda x: root.v1 * x) 89 concrete_func = root.f.get_concrete_function(input_data) 90 91 # Convert model. 92 converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func) 93 tflite_model = converter.convert() 94 95 # Check values from converted model. 96 expected_value = root.f(input_data) 97 actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) 98 self.assertEqual(expected_value.numpy(), actual_value) 99 100 @test_util.run_v2_only 101 def testConstSavedModel(self): 102 """Test a basic model with functions to make sure functions are inlined.""" 103 self.skipTest('b/124205572') 104 input_data = constant_op.constant(1., shape=[1]) 105 root = tracking.AutoTrackable() 106 root.f = def_function.function(lambda x: 2. * x) 107 to_save = root.f.get_concrete_function(input_data) 108 109 save_dir = os.path.join(self.get_temp_dir(), 'saved_model') 110 save(root, save_dir, to_save) 111 saved_model = load(save_dir) 112 concrete_func = saved_model.signatures['serving_default'] 113 114 # Convert model and ensure model is not None. 115 converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func) 116 tflite_model = converter.convert() 117 118 # Check values from converted model. 119 expected_value = root.f(input_data) 120 actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) 121 self.assertEqual(expected_value.numpy(), actual_value) 122 123 @test_util.run_v2_only 124 def testVariableSavedModel(self): 125 """Test a basic model with Variables with saving/loading the SavedModel.""" 126 self.skipTest('b/124205572') 127 input_data = constant_op.constant(1., shape=[1]) 128 root = tracking.AutoTrackable() 129 root.v1 = variables.Variable(3.) 130 root.v2 = variables.Variable(2.) 131 root.f = def_function.function(lambda x: root.v1 * root.v2 * x) 132 to_save = root.f.get_concrete_function(input_data) 133 134 save_dir = os.path.join(self.get_temp_dir(), 'saved_model') 135 save(root, save_dir, to_save) 136 saved_model = load(save_dir) 137 concrete_func = saved_model.signatures['serving_default'] 138 139 # Convert model and ensure model is not None. 140 converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func) 141 tflite_model = converter.convert() 142 143 # Check values from converted model. 144 expected_value = root.f(input_data) 145 actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) 146 self.assertEqual(expected_value.numpy(), actual_value) 147 148 @test_util.run_v2_only 149 def testMultiFunctionModel(self): 150 """Test a basic model with Variables.""" 151 152 class BasicModel(tracking.AutoTrackable): 153 154 def __init__(self): 155 self.y = None 156 self.z = None 157 158 @def_function.function 159 def add(self, x): 160 if self.y is None: 161 self.y = variables.Variable(2.) 162 return x + self.y 163 164 @def_function.function 165 def sub(self, x): 166 if self.z is None: 167 self.z = variables.Variable(3.) 168 return x - self.z 169 170 input_data = constant_op.constant(1., shape=[1]) 171 root = BasicModel() 172 concrete_func = root.add.get_concrete_function(input_data) 173 174 # Convert model and ensure model is not None. 175 converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func) 176 tflite_model = converter.convert() 177 178 # Check values from converted model. 179 expected_value = root.add(input_data) 180 actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) 181 self.assertEqual(expected_value.numpy(), actual_value) 182 183 @test_util.run_v2_only 184 def testKerasModel(self): 185 input_data = constant_op.constant(1., shape=[1, 1]) 186 187 # Create a simple Keras model. 188 x = [-1, 0, 1, 2, 3, 4] 189 y = [-3, -1, 1, 3, 5, 7] 190 191 model = keras.models.Sequential( 192 [keras.layers.Dense(units=1, input_shape=[1])]) 193 model.compile(optimizer='sgd', loss='mean_squared_error') 194 model.fit(x, y, epochs=1) 195 196 # Get the concrete function from the Keras model. 197 @def_function.function 198 def to_save(x): 199 return model(x) 200 201 concrete_func = to_save.get_concrete_function( 202 tensor_spec.TensorSpec([None, 1], dtypes.float32)) 203 204 # Convert model and ensure model is not None. 205 converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func) 206 tflite_model = converter.convert() 207 208 # Check values from converted model. 209 expected_value = to_save(input_data) 210 actual_value = self._evaluateTFLiteModel(tflite_model, [input_data]) 211 self.assertEqual(expected_value.numpy(), actual_value) 212 213 214if __name__ == '__main__': 215 test.main() 216