1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
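"""TFLite SavedModel conversion test cases.

  - Tests freezing simple SavedModel graphs (one or two inputs) into a frozen
    GraphDef via `convert_saved_model.freeze_saved_model`.
  - Tests selecting inputs and outputs with `input_arrays`, `input_shapes` and
    `output_arrays`, and the ValueErrors raised for invalid array names or an
    invalid signature key.
  - Tests selecting a MetaGraphDef by tag set from a multi-MetaGraphDef
    SavedModel.
"""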
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import os
27from tensorflow.lite.python import convert_saved_model
28from tensorflow.python.client import session
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import test_util
33from tensorflow.python.ops import array_ops
34from tensorflow.python.platform import test
35from tensorflow.python.saved_model import saved_model
36from tensorflow.python.saved_model import signature_constants
37from tensorflow.python.saved_model import tag_constants
38
39
class FreezeSavedModelTest(test_util.TensorFlowTestCase):
  """Tests for `convert_saved_model.freeze_saved_model`."""

  def _createSimpleSavedModel(self, shape):
    """Creates a SavedModel computing `y = x + x` from one placeholder.

    The graph is built in a fresh `ops.Graph()`, so the tensor names the tests
    assert on ("Placeholder:0", "add:0") do not depend on whatever else may be
    in the default graph.

    Args:
      shape: Shape of the float32 input placeholder (may contain None).

    Returns:
      Path of the SavedModel directory.
    """
    saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel")
    with session.Session(graph=ops.Graph()) as sess:
      in_tensor = array_ops.placeholder(shape=shape, dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      inputs = {"x": in_tensor}
      outputs = {"y": out_tensor}
      saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  def _createSavedModelTwoInputArrays(self, shape):
    """Creates a SavedModel computing `z = x + y` from two placeholders.

    Signature input "x" is the placeholder named "inputB" and "y" is the one
    named "inputA". The graph is built in a fresh `ops.Graph()` and written to
    its own directory, so it does not collide with `_createSimpleSavedModel`
    if both are used in one test.

    Args:
      shape: Shape shared by both float32 input placeholders.

    Returns:
      Path of the SavedModel directory.
    """
    saved_model_dir = os.path.join(self.get_temp_dir(), "two_input_savedmodel")
    with session.Session(graph=ops.Graph()) as sess:
      in_tensor_1 = array_ops.placeholder(
          shape=shape, dtype=dtypes.float32, name="inputB")
      in_tensor_2 = array_ops.placeholder(
          shape=shape, dtype=dtypes.float32, name="inputA")
      out_tensor = in_tensor_1 + in_tensor_2
      inputs = {"x": in_tensor_1, "y": in_tensor_2}
      outputs = {"z": out_tensor}
      saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  def _getArrayNames(self, tensors):
    """Returns the `name` of each tensor, in order."""
    return [tensor.name for tensor in tensors]

  def _getArrayShapes(self, tensors):
    """Returns each tensor's static shape as a list of ints / None.

    Iterating a TensorShape may yield `tensor_shape.Dimension` objects or raw
    values; both are normalized to the raw value so shapes compare equal to
    plain Python lists.
    """

    def dim_value(dim):
      return dim.value if isinstance(dim, tensor_shape.Dimension) else dim

    return [[dim_value(dim) for dim in tensor.shape] for tensor in tensors]

  def _convertSavedModel(self,
                         saved_model_dir,
                         input_arrays=None,
                         input_shapes=None,
                         output_arrays=None,
                         tag_set=None,
                         signature_key=None):
    """Freezes a SavedModel via `convert_saved_model.freeze_saved_model`.

    Args:
      saved_model_dir: Directory containing the SavedModel.
      input_arrays: Optional list of input tensor names; forwarded as is.
      input_shapes: Optional dict of input name -> shape; forwarded as is.
      output_arrays: Optional list of output tensor names; forwarded as is.
      tag_set: Set of MetaGraphDef tags. Defaults to {SERVING}.
      signature_key: SignatureDef key. Defaults to the default serving key.

    Returns:
      Tuple (graph_def, in_tensors, out_tensors). The fourth value returned by
      `freeze_saved_model` is discarded.
    """
    if tag_set is None:
      tag_set = {tag_constants.SERVING}
    if signature_key is None:
      signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    graph_def, in_tensors, out_tensors, _ = (
        convert_saved_model.freeze_saved_model(
            saved_model_dir=saved_model_dir,
            input_arrays=input_arrays,
            input_shapes=input_shapes,
            output_arrays=output_arrays,
            tag_set=tag_set,
            signature_key=signature_key))
    return graph_def, in_tensors, out_tensors

  def testSimpleSavedModel(self):
    """Test a SavedModel with a fully defined input shape."""
    saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
    _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir)

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]])

  def testSimpleSavedModelWithNoneBatchSizeInShape(self):
    """Test a SavedModel with None in input tensor's shape."""
    saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3])
    _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir)

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]])

  def testSimpleSavedModelWithInvalidSignatureKey(self):
    """Test a SavedModel that fails due to an invalid signature_key."""
    saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
    with self.assertRaises(ValueError) as error:
      self._convertSavedModel(saved_model_dir, signature_key="invalid-key")
    self.assertEqual(
        "No 'invalid-key' in the SavedModel's SignatureDefs. "
        "Possible values are 'serving_default'.", str(error.exception))

  def testSimpleSavedModelWithInvalidOutputArray(self):
    """Test a SavedModel that fails due to invalid output arrays."""
    saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
    with self.assertRaises(ValueError) as error:
      self._convertSavedModel(saved_model_dir, output_arrays=["invalid-output"])
    self.assertEqual("Invalid tensors 'invalid-output' were found.",
                     str(error.exception))

  def testSimpleSavedModelWithWrongInputArrays(self):
    """Test a SavedModel that fails due to invalid input arrays."""
    saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])

    # Check invalid input_arrays.
    with self.assertRaises(ValueError) as error:
      self._convertSavedModel(saved_model_dir, input_arrays=["invalid-input"])
    self.assertEqual("Invalid tensors 'invalid-input' were found.",
                     str(error.exception))

    # Check valid and invalid input_arrays: only the invalid name is reported.
    with self.assertRaises(ValueError) as error:
      self._convertSavedModel(
          saved_model_dir, input_arrays=["Placeholder", "invalid-input"])
    self.assertEqual("Invalid tensors 'invalid-input' were found.",
                     str(error.exception))

  def testSimpleSavedModelWithCorrectArrays(self):
    """Test a SavedModel with correct input_arrays and output_arrays."""
    saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3])
    _, in_tensors, out_tensors = self._convertSavedModel(
        saved_model_dir=saved_model_dir,
        input_arrays=["Placeholder"],
        output_arrays=["add"])

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]])

  def testSimpleSavedModelWithCorrectInputArrays(self):
    """Test a SavedModel with correct input_arrays and input_shapes."""
    saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
    _, in_tensors, out_tensors = self._convertSavedModel(
        saved_model_dir=saved_model_dir,
        input_arrays=["Placeholder"],
        input_shapes={"Placeholder": [1, 16, 16, 3]})

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]])

  def testTwoInputArrays(self):
    """Test a SavedModel with two inputs selected via input_arrays.

    The returned input tensors come back as (inputA, inputB), which differs
    from the order requested in input_arrays.
    """
    saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3])

    _, in_tensors, out_tensors = self._convertSavedModel(
        saved_model_dir=saved_model_dir, input_arrays=["inputB", "inputA"])

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0", "inputB:0"])
    self.assertEqual(
        self._getArrayShapes(in_tensors), [[1, 16, 16, 3], [1, 16, 16, 3]])

  def testSubsetInputArrays(self):
    """Test a SavedModel with a subset of the input array names of the model."""
    saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3])

    # Check case where input shape is given.
    _, in_tensors, out_tensors = self._convertSavedModel(
        saved_model_dir=saved_model_dir,
        input_arrays=["inputA"],
        input_shapes={"inputA": [1, 16, 16, 3]})

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]])

    # Check case where input shape is None.
    _, in_tensors, out_tensors = self._convertSavedModel(
        saved_model_dir=saved_model_dir, input_arrays=["inputA"])

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]])

  def testMultipleMetaGraphDef(self):
    """Test a SavedModel with multiple MetaGraphDefs.

    The first MetaGraphDef is tagged {SERVING, "additional_test_tag"} and has
    the serving signature; the second is tagged only "tflite". Conversion is
    asked for the first one by its tag set.
    """
    saved_model_dir = os.path.join(self.get_temp_dir(), "savedmodel_two_mgd")
    builder = saved_model.builder.SavedModelBuilder(saved_model_dir)
    with session.Session(graph=ops.Graph()) as sess:
      # MetaGraphDef 1: serving graph with a predict signature.
      in_tensor = array_ops.placeholder(shape=[1, 28, 28], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sig_input_tensor_signature = {
          "x": saved_model.utils.build_tensor_info(in_tensor)
      }
      sig_output_tensor_signature = {
          "y": saved_model.utils.build_tensor_info(out_tensor)
      }
      predict_signature_def = (
          saved_model.signature_def_utils.build_signature_def(
              sig_input_tensor_signature, sig_output_tensor_signature,
              signature_constants.PREDICT_METHOD_NAME))
      signature_def_map = {
          signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              predict_signature_def
      }
      builder.add_meta_graph_and_variables(
          sess,
          tags=[tag_constants.SERVING, "additional_test_tag"],
          signature_def_map=signature_def_map)

      # MetaGraphDef 2: tagged only "tflite", added without a signature map.
      builder.add_meta_graph(tags=["tflite"])
      builder.save(as_text=True)

    # Convert to tflite, selecting MetaGraphDef 1 by its full tag set.
    _, in_tensors, out_tensors = self._convertSavedModel(
        saved_model_dir=saved_model_dir,
        tag_set={tag_constants.SERVING, "additional_test_tag"})

    self.assertEqual(self._getArrayNames(out_tensors), ["add:0"])
    self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"])
    self.assertEqual(self._getArrayShapes(in_tensors), [[1, 28, 28]])
250
251
if __name__ == "__main__":
  # Script entry point: hand control to `test.main()` from
  # tensorflow.python.platform.test to run the test cases above.
  test.main()
254