1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TFLite SavedModel conversion test cases. 16 17 - Tests converting simple SavedModel graph to TFLite FlatBuffer. 18 - Tests converting simple SavedModel graph to frozen graph. 19 - Tests converting MNIST SavedModel to TFLite FlatBuffer. 20""" 21 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import os 27from tensorflow.lite.python import convert_saved_model 28from tensorflow.python.client import session 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.framework import test_util 33from tensorflow.python.ops import array_ops 34from tensorflow.python.platform import test 35from tensorflow.python.saved_model import saved_model 36from tensorflow.python.saved_model import signature_constants 37from tensorflow.python.saved_model import tag_constants 38 39 40class TensorFunctionsTest(test_util.TensorFlowTestCase): 41 42 @test_util.run_v1_only("b/120545219") 43 def testGetTensorsValid(self): 44 in_tensor = array_ops.placeholder( 45 shape=[1, 16, 16, 3], dtype=dtypes.float32) 46 _ = in_tensor + in_tensor 47 sess = session.Session() 48 49 tensors = convert_saved_model.get_tensors_from_tensor_names( 50 sess.graph, ["Placeholder"]) 51 self.assertEqual("Placeholder:0", tensors[0].name) 52 53 @test_util.run_v1_only("b/120545219") 54 def testGetTensorsInvalid(self): 55 in_tensor = array_ops.placeholder( 56 shape=[1, 16, 16, 3], dtype=dtypes.float32) 57 _ = in_tensor + in_tensor 58 sess = session.Session() 59 60 with self.assertRaises(ValueError) as error: 61 convert_saved_model.get_tensors_from_tensor_names(sess.graph, 62 ["invalid-input"]) 63 self.assertEqual("Invalid tensors 'invalid-input' were found.", 64 str(error.exception)) 65 66 @test_util.run_v1_only("b/120545219") 67 def testSetTensorShapeValid(self): 68 tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) 69 self.assertEqual([None, 3, 5], tensor.shape.as_list()) 70 71 convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]}) 72 self.assertEqual([5, 3, 5], tensor.shape.as_list()) 73 74 @test_util.run_v1_only("b/120545219") 75 def testSetTensorShapeNoneValid(self): 76 tensor = array_ops.placeholder(dtype=dtypes.float32) 77 self.assertEqual(None, tensor.shape) 78 79 convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]}) 80 self.assertEqual([1, 3, 5], tensor.shape.as_list()) 81 82 @test_util.run_v1_only("b/120545219") 83 def testSetTensorShapeArrayInvalid(self): 84 # Tests set_tensor_shape where the tensor name passed in doesn't exist. 85 tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) 86 self.assertEqual([None, 3, 5], tensor.shape.as_list()) 87 88 with self.assertRaises(ValueError) as error: 89 convert_saved_model.set_tensor_shapes([tensor], 90 {"invalid-input": [5, 3, 5]}) 91 self.assertEqual( 92 "Invalid tensor 'invalid-input' found in tensor shapes map.", 93 str(error.exception)) 94 self.assertEqual([None, 3, 5], tensor.shape.as_list()) 95 96 @test_util.run_deprecated_v1 97 def testSetTensorShapeDimensionInvalid(self): 98 # Tests set_tensor_shape where the shape passed in is incompatiable. 99 tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) 100 self.assertEqual([None, 3, 5], tensor.shape.as_list()) 101 102 with self.assertRaises(ValueError) as error: 103 convert_saved_model.set_tensor_shapes([tensor], 104 {"Placeholder": [1, 5, 5]}) 105 self.assertIn("The shape of tensor 'Placeholder' cannot be changed", 106 str(error.exception)) 107 self.assertEqual([None, 3, 5], tensor.shape.as_list()) 108 109 @test_util.run_v1_only("b/120545219") 110 def testSetTensorShapeEmpty(self): 111 tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) 112 self.assertEqual([None, 3, 5], tensor.shape.as_list()) 113 114 convert_saved_model.set_tensor_shapes([tensor], {}) 115 self.assertEqual([None, 3, 5], tensor.shape.as_list()) 116 117 118class FreezeSavedModelTest(test_util.TensorFlowTestCase): 119 120 def _createSimpleSavedModel(self, shape): 121 """Create a simple SavedModel on the fly.""" 122 saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") 123 with session.Session() as sess: 124 in_tensor = array_ops.placeholder(shape=shape, dtype=dtypes.float32) 125 out_tensor = in_tensor + in_tensor 126 inputs = {"x": in_tensor} 127 outputs = {"y": out_tensor} 128 saved_model.simple_save(sess, saved_model_dir, inputs, outputs) 129 return saved_model_dir 130 131 def _createSavedModelTwoInputArrays(self, shape): 132 """Create a simple SavedModel.""" 133 saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") 134 with session.Session() as sess: 135 in_tensor_1 = array_ops.placeholder( 136 shape=shape, dtype=dtypes.float32, name="inputB") 137 in_tensor_2 = array_ops.placeholder( 138 shape=shape, dtype=dtypes.float32, name="inputA") 139 out_tensor = in_tensor_1 + in_tensor_2 140 inputs = {"x": in_tensor_1, "y": in_tensor_2} 141 outputs = {"z": out_tensor} 142 saved_model.simple_save(sess, saved_model_dir, inputs, outputs) 143 return saved_model_dir 144 145 def _getArrayNames(self, tensors): 146 return [tensor.name for tensor in tensors] 147 148 def _getArrayShapes(self, tensors): 149 dims = [] 150 for tensor in tensors: 151 dim_tensor = [] 152 for dim in tensor.shape: 153 if isinstance(dim, tensor_shape.Dimension): 154 dim_tensor.append(dim.value) 155 else: 156 dim_tensor.append(dim) 157 dims.append(dim_tensor) 158 return dims 159 160 def _convertSavedModel(self, 161 saved_model_dir, 162 input_arrays=None, 163 input_shapes=None, 164 output_arrays=None, 165 tag_set=None, 166 signature_key=None): 167 if tag_set is None: 168 tag_set = set([tag_constants.SERVING]) 169 if signature_key is None: 170 signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 171 graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model( 172 saved_model_dir=saved_model_dir, 173 input_arrays=input_arrays, 174 input_shapes=input_shapes, 175 output_arrays=output_arrays, 176 tag_set=tag_set, 177 signature_key=signature_key) 178 return graph_def, in_tensors, out_tensors 179 180 def testSimpleSavedModel(self): 181 """Test a SavedModel.""" 182 saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) 183 _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir) 184 185 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 186 self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) 187 self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) 188 189 def testSimpleSavedModelWithNoneBatchSizeInShape(self): 190 """Test a SavedModel with None in input tensor's shape.""" 191 saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) 192 _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir) 193 194 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 195 self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) 196 self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]]) 197 198 def testSimpleSavedModelWithInvalidSignatureKey(self): 199 """Test a SavedModel that fails due to an invalid signature_key.""" 200 saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) 201 with self.assertRaises(ValueError) as error: 202 self._convertSavedModel(saved_model_dir, signature_key="invalid-key") 203 self.assertEqual( 204 "No 'invalid-key' in the SavedModel's SignatureDefs. " 205 "Possible values are 'serving_default'.", str(error.exception)) 206 207 def testSimpleSavedModelWithInvalidOutputArray(self): 208 """Test a SavedModel that fails due to invalid output arrays.""" 209 saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) 210 with self.assertRaises(ValueError) as error: 211 self._convertSavedModel(saved_model_dir, output_arrays=["invalid-output"]) 212 self.assertEqual("Invalid tensors 'invalid-output' were found.", 213 str(error.exception)) 214 215 def testSimpleSavedModelWithWrongInputArrays(self): 216 """Test a SavedModel that fails due to invalid input arrays.""" 217 saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) 218 219 # Check invalid input_arrays. 220 with self.assertRaises(ValueError) as error: 221 self._convertSavedModel(saved_model_dir, input_arrays=["invalid-input"]) 222 self.assertEqual("Invalid tensors 'invalid-input' were found.", 223 str(error.exception)) 224 225 # Check valid and invalid input_arrays. 226 with self.assertRaises(ValueError) as error: 227 self._convertSavedModel( 228 saved_model_dir, input_arrays=["Placeholder", "invalid-input"]) 229 self.assertEqual("Invalid tensors 'invalid-input' were found.", 230 str(error.exception)) 231 232 def testSimpleSavedModelWithCorrectArrays(self): 233 """Test a SavedModel with correct input_arrays and output_arrays.""" 234 saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) 235 _, in_tensors, out_tensors = self._convertSavedModel( 236 saved_model_dir=saved_model_dir, 237 input_arrays=["Placeholder"], 238 output_arrays=["add"]) 239 240 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 241 self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) 242 self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]]) 243 244 def testSimpleSavedModelWithCorrectInputArrays(self): 245 """Test a SavedModel with correct input_arrays and input_shapes.""" 246 saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) 247 _, in_tensors, out_tensors = self._convertSavedModel( 248 saved_model_dir=saved_model_dir, 249 input_arrays=["Placeholder"], 250 input_shapes={"Placeholder": [1, 16, 16, 3]}) 251 252 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 253 self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) 254 self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) 255 256 def testTwoInputArrays(self): 257 """Test a simple SavedModel.""" 258 saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3]) 259 260 _, in_tensors, out_tensors = self._convertSavedModel( 261 saved_model_dir=saved_model_dir, input_arrays=["inputB", "inputA"]) 262 263 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 264 self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0", "inputB:0"]) 265 self.assertEqual( 266 self._getArrayShapes(in_tensors), [[1, 16, 16, 3], [1, 16, 16, 3]]) 267 268 def testSubsetInputArrays(self): 269 """Test a SavedModel with a subset of the input array names of the model.""" 270 saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3]) 271 272 # Check case where input shape is given. 273 _, in_tensors, out_tensors = self._convertSavedModel( 274 saved_model_dir=saved_model_dir, 275 input_arrays=["inputA"], 276 input_shapes={"inputA": [1, 16, 16, 3]}) 277 278 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 279 self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"]) 280 self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) 281 282 # Check case where input shape is None. 283 _, in_tensors, out_tensors = self._convertSavedModel( 284 saved_model_dir=saved_model_dir, input_arrays=["inputA"]) 285 286 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 287 self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"]) 288 self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) 289 290 def testMultipleMetaGraphDef(self): 291 """Test saved model with multiple MetaGraphDefs.""" 292 saved_model_dir = os.path.join(self.get_temp_dir(), "savedmodel_two_mgd") 293 builder = saved_model.builder.SavedModelBuilder(saved_model_dir) 294 with session.Session(graph=ops.Graph()) as sess: 295 # MetaGraphDef 1 296 in_tensor = array_ops.placeholder(shape=[1, 28, 28], dtype=dtypes.float32) 297 out_tensor = in_tensor + in_tensor 298 sig_input_tensor = saved_model.utils.build_tensor_info(in_tensor) 299 sig_input_tensor_signature = {"x": sig_input_tensor} 300 sig_output_tensor = saved_model.utils.build_tensor_info(out_tensor) 301 sig_output_tensor_signature = {"y": sig_output_tensor} 302 predict_signature_def = ( 303 saved_model.signature_def_utils.build_signature_def( 304 sig_input_tensor_signature, sig_output_tensor_signature, 305 saved_model.signature_constants.PREDICT_METHOD_NAME)) 306 signature_def_map = { 307 saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 308 predict_signature_def 309 } 310 builder.add_meta_graph_and_variables( 311 sess, 312 tags=[saved_model.tag_constants.SERVING, "additional_test_tag"], 313 signature_def_map=signature_def_map) 314 315 # MetaGraphDef 2 316 builder.add_meta_graph(tags=["tflite"]) 317 builder.save(True) 318 319 # Convert to tflite 320 _, in_tensors, out_tensors = self._convertSavedModel( 321 saved_model_dir=saved_model_dir, 322 tag_set=set([saved_model.tag_constants.SERVING, "additional_test_tag"])) 323 324 self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) 325 self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) 326 self.assertEqual(self._getArrayShapes(in_tensors), [[1, 28, 28]]) 327 328 329if __name__ == "__main__": 330 test.main() 331