1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for lite.py functionality related to select TF op usage.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23from absl.testing import parameterized 24import numpy as np 25 26from tensorflow.core.framework import graph_pb2 27from tensorflow.lite.python import lite 28from tensorflow.lite.python import test_util as tflite_test_util 29from tensorflow.lite.python.convert import register_custom_opdefs 30from tensorflow.lite.python.interpreter import Interpreter 31from tensorflow.lite.python.testdata import double_op 32from tensorflow.python.client import session 33from tensorflow.python.eager import def_function 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import ops 37from tensorflow.python.framework import test_util 38from tensorflow.python.framework.importer import import_graph_def 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import variables 41from tensorflow.python.platform import test 42from tensorflow.python.saved_model import saved_model 43from tensorflow.python.training.tracking import tracking 44 45 46class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase): 47 48 @parameterized.named_parameters( 49 ('EnableMlirConverter', True), # enable mlir 50 ('DisableMlirConverter', False)) # disable mlir 51 def testFlexMode(self, enable_mlir): 52 with ops.Graph().as_default(): 53 in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32) 54 out_tensor = in_tensor + in_tensor 55 sess = session.Session() 56 57 # Convert model and ensure model is not None. 58 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 59 [out_tensor]) 60 converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS]) 61 converter.experimental_new_converter = enable_mlir 62 tflite_model = converter.convert() 63 self.assertTrue(tflite_model) 64 65 # Check the model works with TensorFlow ops. 66 interpreter = Interpreter(model_content=tflite_model) 67 interpreter.allocate_tensors() 68 input_details = interpreter.get_input_details() 69 test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32) 70 interpreter.set_tensor(input_details[0]['index'], test_input) 71 interpreter.invoke() 72 73 output_details = interpreter.get_output_details() 74 expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32) 75 output_data = interpreter.get_tensor(output_details[0]['index']) 76 self.assertTrue((expected_output == output_data).all()) 77 78 def testDeprecatedFlags(self): 79 with ops.Graph().as_default(): 80 in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32) 81 out_tensor = in_tensor + in_tensor 82 sess = session.Session() 83 84 # Convert model and ensure model is not None. 85 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 86 [out_tensor]) 87 converter.target_ops = set([lite.OpsSet.SELECT_TF_OPS]) 88 89 # Ensure `target_ops` is set to the correct value after flag deprecation. 90 self.assertEqual(converter.target_ops, set([lite.OpsSet.SELECT_TF_OPS])) 91 self.assertEqual(converter.target_spec.supported_ops, 92 set([lite.OpsSet.SELECT_TF_OPS])) 93 94 tflite_model = converter.convert() 95 self.assertTrue(tflite_model) 96 97 # Check the model works with TensorFlow ops. 98 interpreter = Interpreter(model_content=tflite_model) 99 interpreter.allocate_tensors() 100 input_details = interpreter.get_input_details() 101 test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32) 102 interpreter.set_tensor(input_details[0]['index'], test_input) 103 interpreter.invoke() 104 105 output_details = interpreter.get_output_details() 106 expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32) 107 output_data = interpreter.get_tensor(output_details[0]['index']) 108 self.assertTrue((expected_output == output_data).all()) 109 110 111class FromConcreteFunctionTest(test_util.TensorFlowTestCase, 112 parameterized.TestCase): 113 114 @parameterized.named_parameters( 115 ('EnableMlirConverter', True), # enable mlir 116 ('DisableMlirConverter', False)) # disable mlir 117 @test_util.run_v2_only 118 def testFloat(self, enable_mlir): 119 input_data = constant_op.constant(1., shape=[1]) 120 root = tracking.AutoTrackable() 121 root.v1 = variables.Variable(3.) 122 root.v2 = variables.Variable(2.) 123 root.f = def_function.function(lambda x: root.v1 * root.v2 * x) 124 concrete_func = root.f.get_concrete_function(input_data) 125 126 # Convert model. 127 converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func]) 128 converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS]) 129 converter.experimental_new_converter = enable_mlir 130 tflite_model = converter.convert() 131 132 # Check the model works with TensorFlow ops. 133 interpreter = Interpreter(model_content=tflite_model) 134 interpreter.allocate_tensors() 135 input_details = interpreter.get_input_details() 136 test_input = np.array([4.0], dtype=np.float32) 137 interpreter.set_tensor(input_details[0]['index'], test_input) 138 interpreter.invoke() 139 140 output_details = interpreter.get_output_details() 141 expected_output = np.array([24.0], dtype=np.float32) 142 output_data = interpreter.get_tensor(output_details[0]['index']) 143 self.assertTrue((expected_output == output_data).all()) 144 145 146class WithCustomOpTest(test_util.TensorFlowTestCase, parameterized.TestCase): 147 148 def _createGraphWithCustomOp(self, opname='CustomAdd'): 149 custom_opdefs_str = ( 150 'name: \'' + opname + '\' input_arg: {name: \'Input1\' type: DT_FLOAT} ' 151 'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: ' 152 '\'Output\' type: DT_FLOAT}') 153 154 # Create a graph that has one add op. 155 new_graph = graph_pb2.GraphDef() 156 with ops.Graph().as_default(): 157 with session.Session() as sess: 158 in_tensor = array_ops.placeholder( 159 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input') 160 out_tensor = in_tensor + in_tensor 161 inputs = {'x': in_tensor} 162 outputs = {'z': out_tensor} 163 164 new_graph.CopyFrom(sess.graph_def) 165 166 # Rename Add op name to opname. 167 for node in new_graph.node: 168 if node.op.startswith('Add'): 169 node.op = opname 170 del node.attr['T'] 171 172 # Register custom op defs to import modified graph def. 173 register_custom_opdefs([custom_opdefs_str]) 174 175 return (new_graph, inputs, outputs) 176 177 def testFlexWithCustomOp(self): 178 new_graph, inputs, outputs = self._createGraphWithCustomOp( 179 opname='CustomAdd4') 180 181 # Import to load the custom opdef. 182 saved_model_dir = os.path.join(self.get_temp_dir(), 'model') 183 with ops.Graph().as_default(): 184 with session.Session() as sess: 185 import_graph_def(new_graph, name='') 186 saved_model.simple_save(sess, saved_model_dir, inputs, outputs) 187 188 converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir) 189 converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS]) 190 converter.target_spec.experimental_select_user_tf_ops = ['CustomAdd4'] 191 tflite_model = converter.convert() 192 193 self.assertIn('FlexCustomAdd4', tflite_test_util.get_ops_list(tflite_model)) 194 195 def testFlexWithDoubleOp(self): 196 # Create a graph that has one double op. 197 saved_model_dir = os.path.join(self.get_temp_dir(), 'model2') 198 with ops.Graph().as_default(): 199 with session.Session() as sess: 200 in_tensor = array_ops.placeholder( 201 shape=[1, 4], dtype=dtypes.int32, name='input') 202 out_tensor = double_op.double(in_tensor) 203 inputs = {'x': in_tensor} 204 outputs = {'z': out_tensor} 205 saved_model.simple_save(sess, saved_model_dir, inputs, outputs) 206 207 converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir) 208 converter.target_spec.supported_ops = set([lite.OpsSet.SELECT_TF_OPS]) 209 converter.target_spec.experimental_select_user_tf_ops = ['Double'] 210 tflite_model = converter.convert() 211 self.assertTrue(tflite_model) 212 self.assertIn('FlexDouble', tflite_test_util.get_ops_list(tflite_model)) 213 214 # Check the model works with TensorFlow ops. 215 interpreter = Interpreter(model_content=tflite_model) 216 interpreter.allocate_tensors() 217 input_details = interpreter.get_input_details() 218 test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.int32) 219 interpreter.set_tensor(input_details[0]['index'], test_input) 220 interpreter.invoke() 221 222 output_details = interpreter.get_output_details() 223 expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.int32) 224 output_data = interpreter.get_tensor(output_details[0]['index']) 225 self.assertTrue((expected_output == output_data).all()) 226 227 228if __name__ == '__main__': 229 test.main() 230