1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for lite.py functionality related to select TF op usage."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23from absl.testing import parameterized
24import numpy as np
25
26from tensorflow.core.framework import graph_pb2
27from tensorflow.lite.python import lite
28from tensorflow.lite.python import test_util as tflite_test_util
29from tensorflow.lite.python.convert import register_custom_opdefs
30from tensorflow.lite.python.interpreter import Interpreter
31from tensorflow.lite.python.testdata import double_op
32from tensorflow.python.client import session
33from tensorflow.python.eager import def_function
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import test_util
38from tensorflow.python.framework.importer import import_graph_def
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import variables
41from tensorflow.python.platform import test
42from tensorflow.python.saved_model import saved_model
43from tensorflow.python.training.tracking import tracking
44
45
class FromSessionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests `TFLiteConverter.from_session` with the SELECT_TF_OPS op set."""

  def _createConverter(self):
    """Returns a converter for a session graph that computes `x + x`.

    The graph has a single float32 placeholder of shape [1, 4] as its input
    and the sum of that placeholder with itself as its output.
    """
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(shape=[1, 4], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()
    # Release the session at test teardown instead of leaking it.
    self.addCleanup(sess.close)
    return lite.TFLiteConverter.from_session(sess, [in_tensor], [out_tensor])

  def _assertModelDoublesInput(self, tflite_model):
    """Runs `tflite_model` and asserts that it doubles its input.

    Args:
      tflite_model: Serialized model passed to the interpreter as
        `model_content`.
    """
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    test_input = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    expected_output = np.array([[2.0, 4.0, 6.0, 8.0]], dtype=np.float32)
    output_data = interpreter.get_tensor(output_details[0]['index'])
    # Unlike `(a == b).all()`, this does not broadcast mismatched shapes and
    # reports the differing values on failure.
    self.assertAllEqual(expected_output, output_data)

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testFlexMode(self, enable_mlir):
    """Converts with SELECT_TF_OPS, with and without the new converter."""
    # Convert model and ensure model is not None.
    converter = self._createConverter()
    converter.target_spec.supported_ops = {lite.OpsSet.SELECT_TF_OPS}
    converter.experimental_new_converter = enable_mlir
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check the model works with TensorFlow ops.
    self._assertModelDoublesInput(tflite_model)

  def testDeprecatedFlags(self):
    """Checks that the deprecated `target_ops` attribute is still honored."""
    converter = self._createConverter()
    converter.target_ops = {lite.OpsSet.SELECT_TF_OPS}

    # Ensure `target_ops` is set to the correct value after flag deprecation.
    self.assertEqual(converter.target_ops, {lite.OpsSet.SELECT_TF_OPS})
    self.assertEqual(converter.target_spec.supported_ops,
                     {lite.OpsSet.SELECT_TF_OPS})

    # Convert model and ensure model is not None.
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check the model works with TensorFlow ops.
    self._assertModelDoublesInput(tflite_model)
109
110
class FromConcreteFunctionTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):
  """Tests `TFLiteConverterV2.from_concrete_functions` with SELECT_TF_OPS."""

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  @test_util.run_v2_only
  def testFloat(self, enable_mlir):
    """Converts `v1 * v2 * x` and checks the interpreter's float output."""
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
    converter.target_spec.supported_ops = {lite.OpsSet.SELECT_TF_OPS}
    converter.experimental_new_converter = enable_mlir
    tflite_model = converter.convert()

    # Check the model works with TensorFlow ops.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    test_input = np.array([4.0], dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    # 3. * 2. * 4. == 24.
    expected_output = np.array([24.0], dtype=np.float32)
    output_data = interpreter.get_tensor(output_details[0]['index'])
    # Unlike `(a == b).all()`, this does not broadcast mismatched shapes and
    # reports the differing values on failure.
    self.assertAllEqual(expected_output, output_data)
144
145
class WithCustomOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests converting user-provided TF ops listed as select TF ops."""

  def _createGraphWithCustomOp(self, opname='CustomAdd'):
    """Builds a GraphDef whose add node is renamed to a custom op.

    Also registers an op definition for `opname` (two DT_FLOAT inputs, one
    DT_FLOAT output) via `register_custom_opdefs`.

    Args:
      opname: Op type to substitute for the graph's add node.

    Returns:
      A tuple `(graph_def, inputs, outputs)`: the modified GraphDef, and the
      dicts mapping 'x' to the input placeholder and 'z' to the output tensor
      of the graph that was originally built.
    """
    custom_opdefs_str = (
        "name: '%s' input_arg: {name: 'Input1' type: DT_FLOAT} "
        "input_arg: {name: 'Input2' type: DT_FLOAT} output_arg: {name: "
        "'Output' type: DT_FLOAT}" % opname)

    # Create a graph that has one add op.
    new_graph = graph_pb2.GraphDef()
    with ops.Graph().as_default():
      with session.Session() as sess:
        in_tensor = array_ops.placeholder(
            shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
        out_tensor = in_tensor + in_tensor
        inputs = {'x': in_tensor}
        outputs = {'z': out_tensor}

        new_graph.CopyFrom(sess.graph_def)

    # Rename Add op name to opname. The prefix match covers any op type that
    # starts with 'Add'.
    for node in new_graph.node:
      if node.op.startswith('Add'):
        node.op = opname
        # The custom op definition above declares no 'T' attr.
        del node.attr['T']

    # Register custom op defs to import modified graph def.
    register_custom_opdefs([custom_opdefs_str])

    return (new_graph, inputs, outputs)

  def testFlexWithCustomOp(self):
    """Checks that a registered custom op is emitted as a Flex op."""
    new_graph, inputs, outputs = self._createGraphWithCustomOp(
        opname='CustomAdd4')

    # Import to load the custom opdef.
    saved_model_dir = os.path.join(self.get_temp_dir(), 'model')
    with ops.Graph().as_default():
      with session.Session() as sess:
        import_graph_def(new_graph, name='')
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = {lite.OpsSet.SELECT_TF_OPS}
    converter.target_spec.experimental_select_user_tf_ops = ['CustomAdd4']
    tflite_model = converter.convert()

    self.assertIn('FlexCustomAdd4', tflite_test_util.get_ops_list(tflite_model))

  def testFlexWithDoubleOp(self):
    """Converts a model using the `Double` test op and runs it."""
    # Create a graph that has one double op.
    saved_model_dir = os.path.join(self.get_temp_dir(), 'model2')
    with ops.Graph().as_default():
      with session.Session() as sess:
        in_tensor = array_ops.placeholder(
            shape=[1, 4], dtype=dtypes.int32, name='input')
        out_tensor = double_op.double(in_tensor)
        inputs = {'x': in_tensor}
        outputs = {'z': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)

    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = {lite.OpsSet.SELECT_TF_OPS}
    converter.target_spec.experimental_select_user_tf_ops = ['Double']
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)
    self.assertIn('FlexDouble', tflite_test_util.get_ops_list(tflite_model))

    # Check the model works with TensorFlow ops.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    # Integer literals, since the placeholder and arrays are int32.
    test_input = np.array([[1, 2, 3, 4]], dtype=np.int32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    expected_output = np.array([[2, 4, 6, 8]], dtype=np.int32)
    output_data = interpreter.get_tensor(output_details[0]['index'])
    # Unlike `(a == b).all()`, this does not broadcast mismatched shapes and
    # reports the differing values on failure.
    self.assertAllEqual(expected_output, output_data)
226
227
if __name__ == '__main__':
  # When executed as a script, delegate to the `main()` entry point of
  # tensorflow.python.platform.test.
  test.main()
230