1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for lite.py functionality related to TensorFlow 2.0."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23from tensorflow.lite.python import lite
24from tensorflow.lite.python.interpreter import Interpreter
25from tensorflow.python import keras
26from tensorflow.python.eager import def_function
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import tensor_spec
30from tensorflow.python.framework import test_util
31from tensorflow.python.ops import variables
32from tensorflow.python.platform import test
33from tensorflow.python.saved_model.load import load
34from tensorflow.python.saved_model.save import save
35from tensorflow.python.training.tracking import tracking
36
37
class FromConcreteFunctionTest(test_util.TensorFlowTestCase):
  """Tests for `lite.TFLiteConverterV2.from_concrete_function`."""

  def _evaluateTFLiteModel(self, tflite_model, input_data):
    """Evaluates the model on the `input_data`.

    Args:
      tflite_model: Model content as returned by `converter.convert()`.
      input_data: List of tensors (anything exposing `.numpy()`), one per
        model input, in the order reported by
        `Interpreter.get_input_details()`.

    Returns:
      The value of the model's first output tensor, as returned by
      `Interpreter.get_tensor`.
    """
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # `zip` silently drops unmatched entries, so a count mismatch would leave
    # model inputs unset (or test data unused) without any error.
    self.assertEqual(len(input_details), len(input_data))
    for input_tensor, tensor_data in zip(input_details, input_data):
      interpreter.set_tensor(input_tensor['index'], tensor_data.numpy())
    interpreter.invoke()
    return interpreter.get_tensor(output_details[0]['index'])

  @test_util.run_v2_only
  def testTypeInvalid(self):
    """Passing a non-concrete `def_function.function` raises ValueError."""
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)

    # `root.f` is handed over without calling `get_concrete_function` first.
    with self.assertRaises(ValueError) as error:
      _ = lite.TFLiteConverterV2.from_concrete_function(root.f)
    self.assertIn('call from_concrete_function', str(error.exception))

  @test_util.run_v2_only
  def testFloat(self):
    """Converts a function that multiplies its input by two variables."""
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testSizeNone(self):
    """Converts a function traced with an input created with `shape=None`."""
    input_data = constant_op.constant(1., shape=None)
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.f = def_function.function(lambda x: root.v1 * x)
    concrete_func = root.f.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testConstSavedModel(self):
    """Test a basic model with functions to make sure functions are inlined."""
    self.skipTest('b/124205572')
    # Everything below is unreachable until the skip above is removed.
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    to_save = root.f.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, to_save)
    saved_model = load(save_dir)
    concrete_func = saved_model.signatures['serving_default']

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testVariableSavedModel(self):
    """Test a basic model with Variables with saving/loading the SavedModel."""
    self.skipTest('b/124205572')
    # Everything below is unreachable until the skip above is removed.
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    to_save = root.f.get_concrete_function(input_data)

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save(root, save_dir, to_save)
    saved_model = load(save_dir)
    concrete_func = saved_model.signatures['serving_default']

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.f(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testMultiFunctionModel(self):
    """Converts one of several tf.functions defined on a trackable object."""

    class BasicModel(tracking.AutoTrackable):

      def __init__(self):
        self.y = None
        self.z = None

      # Each function creates its variable lazily, on its first trace.
      @def_function.function
      def add(self, x):
        if self.y is None:
          self.y = variables.Variable(2.)
        return x + self.y

      @def_function.function
      def sub(self, x):
        if self.z is None:
          self.z = variables.Variable(3.)
        return x - self.z

    input_data = constant_op.constant(1., shape=[1])
    root = BasicModel()
    # Only `add` is converted; `sub` is never traced.
    concrete_func = root.add.get_concrete_function(input_data)

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func)
    tflite_model = converter.convert()

    # Check values from converted model.
    expected_value = root.add(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertEqual(expected_value.numpy(), actual_value)

  @test_util.run_v2_only
  def testKerasModel(self):
    """Converts a concrete function wrapping a trained Keras model."""
    input_data = constant_op.constant(1., shape=[1, 1])

    # Create a simple Keras model.
    x = [-1, 0, 1, 2, 3, 4]
    y = [-3, -1, 1, 3, 5, 7]

    model = keras.models.Sequential(
        [keras.layers.Dense(units=1, input_shape=[1])])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=1)

    # Get the concrete function from the Keras model.
    @def_function.function
    def to_save(x):
      return model(x)

    concrete_func = to_save.get_concrete_function(
        tensor_spec.TensorSpec([None, 1], dtypes.float32))

    # Convert model.
    converter = lite.TFLiteConverterV2.from_concrete_function(concrete_func)
    tflite_model = converter.convert()

    # Check values from converted model. The weights are arbitrary floats
    # produced by training, and TensorFlow and TFLite may evaluate the layer
    # with different kernels, so compare with a tolerance instead of exactly.
    # `assertAllClose` also checks shapes and handles multi-element arrays,
    # which `assertEqual` on ndarrays does not.
    expected_value = to_save(input_data)
    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
    self.assertAllClose(expected_value.numpy(), actual_value)
212
213
if __name__ == '__main__':
  # Hand off to `tensorflow.python.platform.test.main()` to run this module's
  # test cases.
  test.main()
216