1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for SignatureDef utils.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.core.framework import types_pb2 22from tensorflow.core.protobuf import meta_graph_pb2 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import math_ops 28from tensorflow.python.platform import test 29from tensorflow.python.saved_model import signature_constants 30from tensorflow.python.saved_model import signature_def_utils_impl 31from tensorflow.python.saved_model import utils 32 33 34# We'll reuse the same tensor_infos in multiple contexts just for the tests. 35# The validator doesn't check shapes so we just omit them. 
_STRING = meta_graph_pb2.TensorInfo(
    name="foobar",
    dtype=dtypes.string.as_datatype_enum
)


_FLOAT = meta_graph_pb2.TensorInfo(
    name="foobar",
    dtype=dtypes.float32.as_datatype_enum
)


def _make_signature(inputs, outputs, name=None):
  """Builds a SignatureDef from dicts that map keys to tensors.

  Args:
    inputs: dict mapping input keys to tensors; each tensor is converted with
      `utils.build_tensor_info`.
    outputs: dict mapping output keys to tensors; each tensor is converted
      with `utils.build_tensor_info`.
    name: optional method name forwarded to `build_signature_def`.

  Returns:
    The SignatureDef produced by `signature_def_utils_impl.build_signature_def`.
  """
  input_info = {
      input_name: utils.build_tensor_info(tensor)
      for input_name, tensor in inputs.items()
  }
  output_info = {
      output_name: utils.build_tensor_info(tensor)
      for output_name, tensor in outputs.items()
  }
  return signature_def_utils_impl.build_signature_def(input_info, output_info,
                                                      name)


class SignatureDefUtilsTest(test.TestCase):
  """Tests for building, validating and loading SignatureDefs."""

  def testBuildSignatureDef(self):
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info).
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32, 1, name="x")
      x_tensor_info = utils.build_tensor_info(x)
      inputs = {}
      inputs["foo-input"] = x_tensor_info

      y = array_ops.placeholder(dtypes.float32, name="y")
      y_tensor_info = utils.build_tensor_info(y)
      outputs = {}
      outputs["foo-output"] = y_tensor_info

    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, "foo-method-name")
    self.assertEqual("foo-method-name", signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    x_tensor_info_actual = signature_def.inputs["foo-input"]
    self.assertEqual("x:0", x_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info_actual.dtype)
    self.assertEqual(1, len(x_tensor_info_actual.tensor_shape.dim))
    self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size)

    # Check outputs in signature def.
    self.assertEqual(1, len(signature_def.outputs))
    y_tensor_info_actual = signature_def.outputs["foo-output"]
    self.assertEqual("y:0", y_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype)
    self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim))

  def testRegressionSignatureDef(self):
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info).
    with ops.Graph().as_default():
      input1 = constant_op.constant("a", name="input-1")
      output1 = constant_op.constant(2.2, name="output-1")
      signature_def = signature_def_utils_impl.regression_signature_def(
          input1, output1)

    self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    x_tensor_info_actual = (
        signature_def.inputs[signature_constants.REGRESS_INPUTS])
    self.assertEqual("input-1:0", x_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, x_tensor_info_actual.dtype)
    self.assertEqual(0, len(x_tensor_info_actual.tensor_shape.dim))

    # Check outputs in signature def.
    self.assertEqual(1, len(signature_def.outputs))
    y_tensor_info_actual = (
        signature_def.outputs[signature_constants.REGRESS_OUTPUTS])
    self.assertEqual("output-1:0", y_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype)
    self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim))

  def testClassificationSignatureDef(self):
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info).
    with ops.Graph().as_default():
      input1 = constant_op.constant("a", name="input-1")
      output1 = constant_op.constant("b", name="output-1")
      output2 = constant_op.constant(3.3, name="output-2")
      signature_def = signature_def_utils_impl.classification_signature_def(
          input1, output1, output2)

    self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    x_tensor_info_actual = (
        signature_def.inputs[signature_constants.CLASSIFY_INPUTS])
    self.assertEqual("input-1:0", x_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, x_tensor_info_actual.dtype)
    self.assertEqual(0, len(x_tensor_info_actual.tensor_shape.dim))

    # Check outputs in signature def.
    self.assertEqual(2, len(signature_def.outputs))
    classes_tensor_info_actual = (
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES])
    self.assertEqual("output-1:0", classes_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, classes_tensor_info_actual.dtype)
    self.assertEqual(0, len(classes_tensor_info_actual.tensor_shape.dim))
    scores_tensor_info_actual = (
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES])
    self.assertEqual("output-2:0", scores_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_FLOAT, scores_tensor_info_actual.dtype)
    self.assertEqual(0, len(scores_tensor_info_actual.tensor_shape.dim))

  def testPredictionSignatureDef(self):
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info).
    with ops.Graph().as_default():
      input1 = constant_op.constant("a", name="input-1")
      input2 = constant_op.constant("b", name="input-2")
      output1 = constant_op.constant("c", name="output-1")
      output2 = constant_op.constant("d", name="output-2")
      signature_def = signature_def_utils_impl.predict_signature_def(
          {
              "input-1": input1,
              "input-2": input2
          }, {
              "output-1": output1,
              "output-2": output2
          })

    self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(2, len(signature_def.inputs))
    input1_tensor_info_actual = signature_def.inputs["input-1"]
    self.assertEqual("input-1:0", input1_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype)
    self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim))
    input2_tensor_info_actual = signature_def.inputs["input-2"]
    self.assertEqual("input-2:0", input2_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype)
    self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim))

    # Check outputs in signature def.
    self.assertEqual(2, len(signature_def.outputs))
    output1_tensor_info_actual = signature_def.outputs["output-1"]
    self.assertEqual("output-1:0", output1_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, output1_tensor_info_actual.dtype)
    self.assertEqual(0, len(output1_tensor_info_actual.tensor_shape.dim))
    output2_tensor_info_actual = signature_def.outputs["output-2"]
    self.assertEqual("output-2:0", output2_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype)
    self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim))

  def testTrainSignatureDef(self):
    self._testSupervisedSignatureDef(
        signature_def_utils_impl.supervised_train_signature_def,
        signature_constants.SUPERVISED_TRAIN_METHOD_NAME)

  def testEvalSignatureDef(self):
    self._testSupervisedSignatureDef(
        signature_def_utils_impl.supervised_eval_signature_def,
        signature_constants.SUPERVISED_EVAL_METHOD_NAME)

  def _testSupervisedSignatureDef(self, fn_to_test, method_name):
    """Checks a supervised SignatureDef builder given full inputs/outputs.

    Args:
      fn_to_test: builder called as `fn_to_test(inputs, loss, predictions,
        metrics)`.
      method_name: the method name the resulting SignatureDef must carry.
    """
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info).
    with ops.Graph().as_default():
      inputs = {
          "input-1": constant_op.constant("a", name="input-1"),
          "input-2": constant_op.constant("b", name="input-2"),
      }
      loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
      predictions = {
          "classes": constant_op.constant([100], name="classes"),
      }
      metrics_val = constant_op.constant(100.0, name="metrics_val")
      metrics = {
          "metrics/value":
              metrics_val,
          "metrics/update_op":
              array_ops.identity(metrics_val, name="metrics_op"),
      }

      signature_def = fn_to_test(inputs, loss, predictions, metrics)

    self.assertEqual(method_name, signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(2, len(signature_def.inputs))
    input1_tensor_info_actual = signature_def.inputs["input-1"]
    self.assertEqual("input-1:0", input1_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype)
    self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim))
    input2_tensor_info_actual = signature_def.inputs["input-2"]
    self.assertEqual("input-2:0", input2_tensor_info_actual.name)
    self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype)
    self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim))

    # Check outputs in signature def.
    self.assertEqual(4, len(signature_def.outputs))
    self.assertEqual("loss-1:0", signature_def.outputs["loss-1"].name)
    self.assertEqual(types_pb2.DT_FLOAT, signature_def.outputs["loss-1"].dtype)

    self.assertEqual("classes:0", signature_def.outputs["classes"].name)
    # A Python int list converts to an int32 tensor.
    self.assertEqual(types_pb2.DT_INT32, signature_def.outputs["classes"].dtype)
    self.assertEqual(1, len(signature_def.outputs["classes"].tensor_shape.dim))

    self.assertEqual(
        "metrics_val:0", signature_def.outputs["metrics/value"].name)
    self.assertEqual(
        types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype)

    self.assertEqual(
        "metrics_op:0", signature_def.outputs["metrics/update_op"].name)
    # Previously this re-checked "metrics/value"; verify the update op itself.
    self.assertEqual(
        types_pb2.DT_FLOAT, signature_def.outputs["metrics/update_op"].dtype)

  def testTrainSignatureDefMissingInputs(self):
    self._testSupervisedSignatureDefMissingInputs(
        signature_def_utils_impl.supervised_train_signature_def,
        signature_constants.SUPERVISED_TRAIN_METHOD_NAME)

  def testEvalSignatureDefMissingInputs(self):
    self._testSupervisedSignatureDefMissingInputs(
        signature_def_utils_impl.supervised_eval_signature_def,
        signature_constants.SUPERVISED_EVAL_METHOD_NAME)

  def _testSupervisedSignatureDefMissingInputs(self, fn_to_test, method_name):
    """Checks a supervised builder with empty inputs and optional outputs.

    Args:
      fn_to_test: builder accepting `inputs` plus keyword `loss`,
        `predictions` and `metrics`.
      method_name: the method name the resulting SignatureDef must carry.
    """
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info).
    with ops.Graph().as_default():
      inputs = {
          "input-1": constant_op.constant("a", name="input-1"),
          "input-2": constant_op.constant("b", name="input-2"),
      }
      loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
      predictions = {
          "classes": constant_op.constant([100], name="classes"),
      }
      metrics_val = constant_op.constant(100, name="metrics_val")
      metrics = {
          "metrics/value":
              metrics_val,
          "metrics/update_op":
              array_ops.identity(metrics_val, name="metrics_op"),
      }

      # Empty inputs must be rejected; the result is never produced.
      with self.assertRaises(ValueError):
        fn_to_test({},
                   loss=loss,
                   predictions=predictions,
                   metrics=metrics)

      signature_def = fn_to_test(inputs, loss=loss)
      self.assertEqual(method_name, signature_def.method_name)
      self.assertEqual(1, len(signature_def.outputs))

      signature_def = fn_to_test(inputs, metrics=metrics, loss=loss)
      self.assertEqual(method_name, signature_def.method_name)
      self.assertEqual(3, len(signature_def.outputs))

  def _assertValidSignature(self, inputs, outputs, method_name):
    """Asserts that the SignatureDef built from the arguments is valid."""
    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    self.assertTrue(
        signature_def_utils_impl.is_valid_signature(signature_def))

  def _assertInvalidSignature(self, inputs, outputs, method_name):
    """Asserts that the SignatureDef built from the arguments is invalid."""
    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    self.assertFalse(
        signature_def_utils_impl.is_valid_signature(signature_def))

  def testValidSignaturesAreAccepted(self):
    self._assertValidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"classes": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertValidSignature(
        {"foo": _STRING, "bar": _FLOAT},
        {"baz": _STRING, "qux": _FLOAT},
        signature_constants.PREDICT_METHOD_NAME)

  def testInvalidMethodNameSignatureIsRejected(self):
    # WRONG METHOD
    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        "WRONG method name")

  def testInvalidClassificationSignaturesAreRejected(self):
    # CLASSIFY: wrong types
    self._assertInvalidSignature(
        {"inputs": _FLOAT},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _FLOAT, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

    # CLASSIFY: wrong keys
    self._assertInvalidSignature(
        {},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs_WRONG": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes_WRONG": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT, "extra_WRONG": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

  def testInvalidRegressionSignaturesAreRejected(self):
    # REGRESS: wrong types
    self._assertInvalidSignature(
        {"inputs": _FLOAT},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs": _STRING},
        signature_constants.REGRESS_METHOD_NAME)

    # REGRESS: wrong keys
    self._assertInvalidSignature(
        {},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs_WRONG": _STRING},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs_WRONG": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs": _FLOAT, "extra_WRONG": _STRING},
        signature_constants.REGRESS_METHOD_NAME)

  def testInvalidPredictSignaturesAreRejected(self):
    # PREDICT: wrong keys
    self._assertInvalidSignature(
        {},
        {"baz": _STRING, "qux": _FLOAT},
        signature_constants.PREDICT_METHOD_NAME)

    self._assertInvalidSignature(
        {"foo": _STRING, "bar": _FLOAT},
        {},
        signature_constants.PREDICT_METHOD_NAME)

  def testOpSignatureDef(self):
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info_from_op).
    with ops.Graph().as_default():
      key = "adding_1_and_2_key"
      add_op = math_ops.add(1, 2, name="adding_1_and_2")
      signature_def = signature_def_utils_impl.op_signature_def(add_op, key)

    self.assertIn(key, signature_def.outputs)
    self.assertEqual(add_op.name, signature_def.outputs[key].name)

  def testLoadOpFromSignatureDef(self):
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info_from_op).
    with ops.Graph().as_default():
      key = "adding_1_and_2_key"
      add_op = math_ops.add(1, 2, name="adding_1_and_2")
      signature_def = signature_def_utils_impl.op_signature_def(add_op, key)
      self.assertEqual(
          add_op,
          signature_def_utils_impl.load_op_from_signature_def(
              signature_def, key))


if __name__ == "__main__":
  test.main()