1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TFLite SavedModel conversion test cases. 16 17 - Tests converting simple SavedModel graph to TFLite FlatBuffer. 18 - Tests converting simple SavedModel graph to frozen graph. 19 - Tests converting MNIST SavedModel to TFLite FlatBuffer. 20""" 21 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import os 27from tensorflow.lite.python import convert_saved_model 28from tensorflow.python.client import session 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.framework import test_util 33from tensorflow.python.ops import array_ops 34from tensorflow.python.platform import test 35from tensorflow.python.saved_model import saved_model 36from tensorflow.python.saved_model import signature_constants 37from tensorflow.python.saved_model import tag_constants 38 39 40class FreezeSavedModelTest(test_util.TensorFlowTestCase): 41 42 def _createSimpleSavedModel(self, shape): 43 """Create a simple SavedModel on the fly.""" 44 saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") 45 with session.Session() as sess: 46 in_tensor = array_ops.placeholder(shape=shape, dtype=dtypes.float32) 47 out_tensor = in_tensor + in_tensor 48 inputs = {"x": in_tensor} 49 outputs = {"y": out_tensor} 50 saved_model.simple_save(sess, saved_model_dir, inputs, outputs) 51 return saved_model_dir 52 53 def _createSavedModelTwoInputArrays(self, shape): 54 """Create a simple SavedModel.""" 55 saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") 56 with session.Session() as sess: 57 in_tensor_1 = array_ops.placeholder( 58 shape=shape, dtype=dtypes.float32, name="inputB") 59 in_tensor_2 = array_ops.placeholder( 60 shape=shape, dtype=dtypes.float32, name="inputA") 61 out_tensor = in_tensor_1 + in_tensor_2 62 inputs = {"x": in_tensor_1, "y": in_tensor_2} 63 outputs = {"z": out_tensor} 64 saved_model.simple_save(sess, saved_model_dir, inputs, outputs) 65 return saved_model_dir 66 67 def _getArrayNames(self, tensors): 68 return [tensor.name for tensor in tensors] 69 70 def _getArrayShapes(self, tensors): 71 dims = [] 72 for tensor in tensors: 73 dim_tensor = [] 74 for dim in tensor.shape: 75 if isinstance(dim, tensor_shape.Dimension): 76 dim_tensor.append(dim.value) 77 else: 78 dim_tensor.append(dim) 79 dims.append(dim_tensor) 80 return dims 81 82 def _convertSavedModel(self, 83 saved_model_dir, 84 input_arrays=None, 85 input_shapes=None, 86 output_arrays=None, 87 tag_set=None, 88 signature_key=None): 89 if tag_set is None: 90 tag_set = set([tag_constants.SERVING]) 91 if signature_key is None: 92 signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 93 graph_def, in_tensors, out_tensors, _ = ( 94 convert_saved_model.freeze_saved_model( 95 saved_model_dir=saved_model_dir, 96 input_arrays=input_arrays, 97 input_shapes=input_shapes, 98 output_arrays=output_arrays, 99 tag_set=tag_set, 100 signature_key=signature_key)) 101 return graph_def, in_tensors, out_tensors 102 103 def testSimpleSavedModel(self): 104 """Test a SavedModel.""" 105 saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) 106 _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir) 107 108 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 109 self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) 110 self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) 111 112 def testSimpleSavedModelWithNoneBatchSizeInShape(self): 113 """Test a SavedModel with None in input tensor's shape.""" 114 saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) 115 _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir) 116 117 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 118 self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) 119 self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]]) 120 121 def testSimpleSavedModelWithInvalidSignatureKey(self): 122 """Test a SavedModel that fails due to an invalid signature_key.""" 123 saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) 124 with self.assertRaises(ValueError) as error: 125 self._convertSavedModel(saved_model_dir, signature_key="invalid-key") 126 self.assertEqual( 127 "No 'invalid-key' in the SavedModel's SignatureDefs. " 128 "Possible values are 'serving_default'.", str(error.exception)) 129 130 def testSimpleSavedModelWithInvalidOutputArray(self): 131 """Test a SavedModel that fails due to invalid output arrays.""" 132 saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) 133 with self.assertRaises(ValueError) as error: 134 self._convertSavedModel(saved_model_dir, output_arrays=["invalid-output"]) 135 self.assertEqual("Invalid tensors 'invalid-output' were found.", 136 str(error.exception)) 137 138 def testSimpleSavedModelWithWrongInputArrays(self): 139 """Test a SavedModel that fails due to invalid input arrays.""" 140 saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) 141 142 # Check invalid input_arrays. 143 with self.assertRaises(ValueError) as error: 144 self._convertSavedModel(saved_model_dir, input_arrays=["invalid-input"]) 145 self.assertEqual("Invalid tensors 'invalid-input' were found.", 146 str(error.exception)) 147 148 # Check valid and invalid input_arrays. 149 with self.assertRaises(ValueError) as error: 150 self._convertSavedModel( 151 saved_model_dir, input_arrays=["Placeholder", "invalid-input"]) 152 self.assertEqual("Invalid tensors 'invalid-input' were found.", 153 str(error.exception)) 154 155 def testSimpleSavedModelWithCorrectArrays(self): 156 """Test a SavedModel with correct input_arrays and output_arrays.""" 157 saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) 158 _, in_tensors, out_tensors = self._convertSavedModel( 159 saved_model_dir=saved_model_dir, 160 input_arrays=["Placeholder"], 161 output_arrays=["add"]) 162 163 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 164 self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) 165 self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]]) 166 167 def testSimpleSavedModelWithCorrectInputArrays(self): 168 """Test a SavedModel with correct input_arrays and input_shapes.""" 169 saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) 170 _, in_tensors, out_tensors = self._convertSavedModel( 171 saved_model_dir=saved_model_dir, 172 input_arrays=["Placeholder"], 173 input_shapes={"Placeholder": [1, 16, 16, 3]}) 174 175 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 176 self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) 177 self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) 178 179 def testTwoInputArrays(self): 180 """Test a simple SavedModel.""" 181 saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3]) 182 183 _, in_tensors, out_tensors = self._convertSavedModel( 184 saved_model_dir=saved_model_dir, input_arrays=["inputB", "inputA"]) 185 186 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 187 self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0", "inputB:0"]) 188 self.assertEqual( 189 self._getArrayShapes(in_tensors), [[1, 16, 16, 3], [1, 16, 16, 3]]) 190 191 def testSubsetInputArrays(self): 192 """Test a SavedModel with a subset of the input array names of the model.""" 193 saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3]) 194 195 # Check case where input shape is given. 196 _, in_tensors, out_tensors = self._convertSavedModel( 197 saved_model_dir=saved_model_dir, 198 input_arrays=["inputA"], 199 input_shapes={"inputA": [1, 16, 16, 3]}) 200 201 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 202 self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"]) 203 self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) 204 205 # Check case where input shape is None. 206 _, in_tensors, out_tensors = self._convertSavedModel( 207 saved_model_dir=saved_model_dir, input_arrays=["inputA"]) 208 209 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 210 self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"]) 211 self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) 212 213 def testMultipleMetaGraphDef(self): 214 """Test saved model with multiple MetaGraphDefs.""" 215 saved_model_dir = os.path.join(self.get_temp_dir(), "savedmodel_two_mgd") 216 builder = saved_model.builder.SavedModelBuilder(saved_model_dir) 217 with session.Session(graph=ops.Graph()) as sess: 218 # MetaGraphDef 1 219 in_tensor = array_ops.placeholder(shape=[1, 28, 28], dtype=dtypes.float32) 220 out_tensor = in_tensor + in_tensor 221 sig_input_tensor = saved_model.utils.build_tensor_info(in_tensor) 222 sig_input_tensor_signature = {"x": sig_input_tensor} 223 sig_output_tensor = saved_model.utils.build_tensor_info(out_tensor) 224 sig_output_tensor_signature = {"y": sig_output_tensor} 225 predict_signature_def = ( 226 saved_model.signature_def_utils.build_signature_def( 227 sig_input_tensor_signature, sig_output_tensor_signature, 228 saved_model.signature_constants.PREDICT_METHOD_NAME)) 229 signature_def_map = { 230 saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 231 predict_signature_def 232 } 233 builder.add_meta_graph_and_variables( 234 sess, 235 tags=[saved_model.tag_constants.SERVING, "additional_test_tag"], 236 signature_def_map=signature_def_map) 237 238 # MetaGraphDef 2 239 builder.add_meta_graph(tags=["tflite"]) 240 builder.save(True) 241 242 # Convert to tflite 243 _, in_tensors, out_tensors = self._convertSavedModel( 244 saved_model_dir=saved_model_dir, 245 tag_set=set([saved_model.tag_constants.SERVING, "additional_test_tag"])) 246 247 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 248 self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) 249 self.assertEqual(self._getArrayShapes(in_tensors), [[1, 28, 28]]) 250 251 252if __name__ == "__main__": 253 test.main() 254