1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TFLite SavedModel conversion test cases.
16
17  - Tests converting simple SavedModel graph to TFLite FlatBuffer.
18  - Tests converting simple SavedModel graph to frozen graph.
19  - Tests converting MNIST SavedModel to TFLite FlatBuffer.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import os
27from tensorflow.lite.python import convert_saved_model
28from tensorflow.python.client import session
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import test_util
33from tensorflow.python.ops import array_ops
34from tensorflow.python.platform import test
35from tensorflow.python.saved_model import saved_model
36from tensorflow.python.saved_model import signature_constants
37from tensorflow.python.saved_model import tag_constants
38
39
class TensorFunctionsTest(test_util.TensorFlowTestCase):
  """Tests for the tensor lookup and shape helpers in convert_saved_model."""

  @test_util.run_v1_only("b/120545219")
  def testGetTensorsValid(self):
    """A tensor name present in the graph resolves to its output tensor."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    _ = in_tensor + in_tensor
    # Close the session deterministically instead of leaking it until GC.
    with session.Session() as sess:
      tensors = convert_saved_model.get_tensors_from_tensor_names(
          sess.graph, ["Placeholder"])
    self.assertEqual("Placeholder:0", tensors[0].name)

  @test_util.run_v1_only("b/120545219")
  def testGetTensorsInvalid(self):
    """An unknown tensor name raises ValueError naming the bad tensor."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    _ = in_tensor + in_tensor
    # Close the session deterministically instead of leaking it until GC.
    with session.Session() as sess:
      with self.assertRaises(ValueError) as error:
        convert_saved_model.get_tensors_from_tensor_names(sess.graph,
                                                          ["invalid-input"])
    self.assertEqual("Invalid tensors 'invalid-input' were found.",
                     str(error.exception))

  @test_util.run_v1_only("b/120545219")
  def testSetTensorShapeValid(self):
    """A partially known shape can be refined to a compatible full shape."""
    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
    self.assertEqual([None, 3, 5], tensor.shape.as_list())

    convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]})
    self.assertEqual([5, 3, 5], tensor.shape.as_list())

  @test_util.run_v1_only("b/120545219")
  def testSetTensorShapeNoneValid(self):
    """A tensor of unknown rank can be given a concrete shape."""
    tensor = array_ops.placeholder(dtype=dtypes.float32)
    self.assertEqual(None, tensor.shape)

    convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]})
    self.assertEqual([1, 3, 5], tensor.shape.as_list())

  @test_util.run_v1_only("b/120545219")
  def testSetTensorShapeArrayInvalid(self):
    """A shape-map key that names no tensor raises and leaves shapes as-is."""
    # Tests set_tensor_shape where the tensor name passed in doesn't exist.
    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
    self.assertEqual([None, 3, 5], tensor.shape.as_list())

    with self.assertRaises(ValueError) as error:
      convert_saved_model.set_tensor_shapes([tensor],
                                            {"invalid-input": [5, 3, 5]})
    self.assertEqual(
        "Invalid tensor 'invalid-input' found in tensor shapes map.",
        str(error.exception))
    self.assertEqual([None, 3, 5], tensor.shape.as_list())

  @test_util.run_deprecated_v1
  def testSetTensorShapeDimensionInvalid(self):
    """An incompatible shape raises and leaves the tensor's shape as-is."""
    # Tests set_tensor_shape where the shape passed in is incompatible.
    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
    self.assertEqual([None, 3, 5], tensor.shape.as_list())

    with self.assertRaises(ValueError) as error:
      convert_saved_model.set_tensor_shapes([tensor],
                                            {"Placeholder": [1, 5, 5]})
    self.assertIn("The shape of tensor 'Placeholder' cannot be changed",
                  str(error.exception))
    self.assertEqual([None, 3, 5], tensor.shape.as_list())

  @test_util.run_v1_only("b/120545219")
  def testSetTensorShapeEmpty(self):
    """An empty shape map is a no-op."""
    tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
    self.assertEqual([None, 3, 5], tensor.shape.as_list())

    convert_saved_model.set_tensor_shapes([tensor], {})
    self.assertEqual([None, 3, 5], tensor.shape.as_list())
116
117
class FreezeSavedModelTest(test_util.TensorFlowTestCase):
  """Tests for convert_saved_model.freeze_saved_model."""

  def _createSimpleSavedModel(self, shape):
    """Create a single-input SavedModel on the fly.

    The graph computes `Placeholder + Placeholder` and is exported with
    `simple_save`, mapping signature input "x" and output "y".

    Args:
      shape: Shape of the float32 input placeholder (may contain None).

    Returns:
      Path to the exported SavedModel directory.
    """
    saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel")
    with session.Session() as sess:
      in_tensor = array_ops.placeholder(shape=shape, dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      inputs = {"x": in_tensor}
      outputs = {"y": out_tensor}
      saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  def _createSavedModelTwoInputArrays(self, shape):
    """Create a SavedModel with two float32 input placeholders.

    The placeholders are named "inputB" and "inputA" and are created in that
    order. Their sum is the signature output "z". testTwoInputArrays expects
    the frozen inputs back as ["inputA:0", "inputB:0"].

    Args:
      shape: Shape shared by both input placeholders.

    Returns:
      Path to the exported SavedModel directory.
    """
    saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel")
    with session.Session() as sess:
      in_tensor_1 = array_ops.placeholder(
          shape=shape, dtype=dtypes.float32, name="inputB")
      in_tensor_2 = array_ops.placeholder(
          shape=shape, dtype=dtypes.float32, name="inputA")
      out_tensor = in_tensor_1 + in_tensor_2
      inputs = {"x": in_tensor_1, "y": in_tensor_2}
      outputs = {"z": out_tensor}
      saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  def _getArrayNames(self, tensors):
    """Return the names of `tensors`, in order."""
    return [tensor.name for tensor in tensors]

  def _getArrayShapes(self, tensors):
    """Return each tensor's shape as a list of dimension values.

    Iterating a TensorShape may yield `tensor_shape.Dimension` objects or
    plain values depending on TF behavior. Both are normalized to the raw
    value (an int, or None for an unknown dimension).
    """
    return [[dim.value if isinstance(dim, tensor_shape.Dimension) else dim
             for dim in tensor.shape]
            for tensor in tensors]

  def _convertSavedModel(self,
                         saved_model_dir,
                         input_arrays=None,
                         input_shapes=None,
                         output_arrays=None,
                         tag_set=None,
                         signature_key=None):
    """Freeze the SavedModel at `saved_model_dir`.

    `tag_set` defaults to {SERVING} and `signature_key` defaults to the
    default serving signature. All other arguments are forwarded unchanged
    to `convert_saved_model.freeze_saved_model`.

    Returns:
      The (graph_def, in_tensors, out_tensors) tuple from freeze_saved_model.
    """
    if tag_set is None:
      tag_set = {tag_constants.SERVING}
    if signature_key is None:
      signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model(
        saved_model_dir=saved_model_dir,
        input_arrays=input_arrays,
        input_shapes=input_shapes,
        output_arrays=output_arrays,
        tag_set=tag_set,
        signature_key=signature_key)
    return graph_def, in_tensors, out_tensors

  def testSimpleSavedModel(self):
    """Test a SavedModel."""
    saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
    _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir)

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]])

  def testSimpleSavedModelWithNoneBatchSizeInShape(self):
    """Test a SavedModel with None in input tensor's shape."""
    saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3])
    _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir)

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]])

  def testSimpleSavedModelWithInvalidSignatureKey(self):
    """Test a SavedModel that fails due to an invalid signature_key."""
    saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
    with self.assertRaises(ValueError) as error:
      self._convertSavedModel(saved_model_dir, signature_key="invalid-key")
    self.assertEqual(
        "No 'invalid-key' in the SavedModel's SignatureDefs. "
        "Possible values are 'serving_default'.", str(error.exception))

  def testSimpleSavedModelWithInvalidOutputArray(self):
    """Test a SavedModel that fails due to invalid output arrays."""
    saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
    with self.assertRaises(ValueError) as error:
      self._convertSavedModel(saved_model_dir, output_arrays=["invalid-output"])
    self.assertEqual("Invalid tensors 'invalid-output' were found.",
                     str(error.exception))

  def testSimpleSavedModelWithWrongInputArrays(self):
    """Test a SavedModel that fails due to invalid input arrays."""
    saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])

    # Check invalid input_arrays.
    with self.assertRaises(ValueError) as error:
      self._convertSavedModel(saved_model_dir, input_arrays=["invalid-input"])
    self.assertEqual("Invalid tensors 'invalid-input' were found.",
                     str(error.exception))

    # Check valid and invalid input_arrays.
    with self.assertRaises(ValueError) as error:
      self._convertSavedModel(
          saved_model_dir, input_arrays=["Placeholder", "invalid-input"])
    self.assertEqual("Invalid tensors 'invalid-input' were found.",
                     str(error.exception))

  def testSimpleSavedModelWithCorrectArrays(self):
    """Test a SavedModel with correct input_arrays and output_arrays."""
    saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3])
    _, in_tensors, out_tensors = self._convertSavedModel(
        saved_model_dir=saved_model_dir,
        input_arrays=["Placeholder"],
        output_arrays=["add"])

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]])

  def testSimpleSavedModelWithCorrectInputArrays(self):
    """Test a SavedModel with correct input_arrays and input_shapes."""
    saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
    _, in_tensors, out_tensors = self._convertSavedModel(
        saved_model_dir=saved_model_dir,
        input_arrays=["Placeholder"],
        input_shapes={"Placeholder": [1, 16, 16, 3]})

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]])

  def testTwoInputArrays(self):
    """Test a SavedModel with two inputs passed out of name order."""
    saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3])

    _, in_tensors, out_tensors = self._convertSavedModel(
        saved_model_dir=saved_model_dir, input_arrays=["inputB", "inputA"])

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0", "inputB:0"])
    self.assertEqual(
        self._getArrayShapes(in_tensors), [[1, 16, 16, 3], [1, 16, 16, 3]])

  def testSubsetInputArrays(self):
    """Test a SavedModel with a subset of the input array names of the model."""
    saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3])

    # Check case where input shape is given.
    _, in_tensors, out_tensors = self._convertSavedModel(
        saved_model_dir=saved_model_dir,
        input_arrays=["inputA"],
        input_shapes={"inputA": [1, 16, 16, 3]})

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]])

    # Check case where input shape is None.
    _, in_tensors, out_tensors = self._convertSavedModel(
        saved_model_dir=saved_model_dir, input_arrays=["inputA"])

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]])

  def testMultipleMetaGraphDef(self):
    """Test a SavedModel with multiple MetaGraphDefs.

    The first MetaGraphDef is tagged {SERVING, "additional_test_tag"} and
    carries the predict signature. The second is tagged only "tflite". The
    tag_set passed to the conversion selects the first one.
    """
    saved_model_dir = os.path.join(self.get_temp_dir(), "savedmodel_two_mgd")
    builder = saved_model.builder.SavedModelBuilder(saved_model_dir)
    with session.Session(graph=ops.Graph()) as sess:
      # MetaGraphDef 1: serving tags plus a predict signature.
      in_tensor = array_ops.placeholder(shape=[1, 28, 28], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sig_input_tensor = saved_model.utils.build_tensor_info(in_tensor)
      sig_input_tensor_signature = {"x": sig_input_tensor}
      sig_output_tensor = saved_model.utils.build_tensor_info(out_tensor)
      sig_output_tensor_signature = {"y": sig_output_tensor}
      predict_signature_def = (
          saved_model.signature_def_utils.build_signature_def(
              sig_input_tensor_signature, sig_output_tensor_signature,
              signature_constants.PREDICT_METHOD_NAME))
      signature_def_map = {
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              predict_signature_def
      }
      builder.add_meta_graph_and_variables(
          sess,
          tags=[tag_constants.SERVING, "additional_test_tag"],
          signature_def_map=signature_def_map)

      # MetaGraphDef 2: a different tag and no signature_def_map.
      builder.add_meta_graph(tags=["tflite"])
      builder.save(as_text=True)

    # Convert to tflite
    _, in_tensors, out_tensors = self._convertSavedModel(
        saved_model_dir=saved_model_dir,
        tag_set={tag_constants.SERVING, "additional_test_tag"})

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[1, 28, 28]])
327
328
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner for the test cases in this module.
  test.main()
331