1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for SignatureDef utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.framework import types_pb2
22from tensorflow.core.protobuf import meta_graph_pb2
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import test_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.platform import test
29from tensorflow.python.saved_model import signature_constants
30from tensorflow.python.saved_model import signature_def_utils_impl
31from tensorflow.python.saved_model import utils
32
33
# Shared TensorInfo fixtures for the is_valid_signature tests below. Only the
# name and dtype are populated: the validator doesn't check shapes, so shapes
# are left unset. Both fixtures use the same tensor name on purpose, since the
# same infos are reused in several signatures purely for testing.
_STRING = meta_graph_pb2.TensorInfo(
    dtype=dtypes.string.as_datatype_enum, name="foobar")


_FLOAT = meta_graph_pb2.TensorInfo(
    dtype=dtypes.float32.as_datatype_enum, name="foobar")
46
47
def _make_signature(inputs, outputs, name=None):
  """Builds a SignatureDef from dicts of tensors.

  Args:
    inputs: Dict mapping input keys to tensors.
    outputs: Dict mapping output keys to tensors.
    name: Forwarded positionally as the `method_name` argument of
      `signature_def_utils_impl.build_signature_def`.

  Returns:
    The SignatureDef returned by `build_signature_def`.
  """

  def _infos_by_key(tensors_by_key):
    # Convert every tensor into its TensorInfo proto, keeping the same keys.
    return {
        key: utils.build_tensor_info(value)
        for key, value in tensors_by_key.items()
    }

  return signature_def_utils_impl.build_signature_def(
      _infos_by_key(inputs), _infos_by_key(outputs), name)
59
60
class SignatureDefUtilsTest(test.TestCase):
  """Tests for building and validating SignatureDefs."""

  def _assertTensorInfo(self, tensor_info, name, dtype, rank=None):
    """Asserts the name, dtype and, optionally, the rank of a TensorInfo.

    Args:
      tensor_info: A `TensorInfo` proto read from a SignatureDef.
      name: The expected tensor name, e.g. "x:0".
      dtype: The expected `types_pb2` dtype enum value.
      rank: If not None, the expected number of entries in
        `tensor_info.tensor_shape.dim`.
    """
    self.assertEqual(name, tensor_info.name)
    self.assertEqual(dtype, tensor_info.dtype)
    if rank is not None:
      self.assertEqual(rank, len(tensor_info.tensor_shape.dim))

  @test_util.run_deprecated_v1
  def testBuildSignatureDef(self):
    # `x` is given shape [1]; `y` gets no shape, so its TensorInfo is expected
    # to record no dimensions.
    x = array_ops.placeholder(dtypes.float32, 1, name="x")
    y = array_ops.placeholder(dtypes.float32, name="y")
    inputs = {"foo-input": utils.build_tensor_info(x)}
    outputs = {"foo-output": utils.build_tensor_info(y)}

    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, "foo-method-name")
    self.assertEqual("foo-method-name", signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    x_tensor_info_actual = signature_def.inputs["foo-input"]
    self._assertTensorInfo(x_tensor_info_actual, "x:0", types_pb2.DT_FLOAT, 1)
    self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size)

    # Check outputs in signature def.
    self.assertEqual(1, len(signature_def.outputs))
    self._assertTensorInfo(signature_def.outputs["foo-output"], "y:0",
                           types_pb2.DT_FLOAT, 0)

  @test_util.run_deprecated_v1
  def testRegressionSignatureDef(self):
    input1 = constant_op.constant("a", name="input-1")
    output1 = constant_op.constant(2.2, name="output-1")
    signature_def = signature_def_utils_impl.regression_signature_def(
        input1, output1)

    self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    self._assertTensorInfo(
        signature_def.inputs[signature_constants.REGRESS_INPUTS],
        "input-1:0", types_pb2.DT_STRING, 0)

    # Check outputs in signature def.
    self.assertEqual(1, len(signature_def.outputs))
    self._assertTensorInfo(
        signature_def.outputs[signature_constants.REGRESS_OUTPUTS],
        "output-1:0", types_pb2.DT_FLOAT, 0)

  @test_util.run_deprecated_v1
  def testClassificationSignatureDef(self):
    input1 = constant_op.constant("a", name="input-1")
    output1 = constant_op.constant("b", name="output-1")
    output2 = constant_op.constant(3.3, name="output-2")
    signature_def = signature_def_utils_impl.classification_signature_def(
        input1, output1, output2)

    self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    self._assertTensorInfo(
        signature_def.inputs[signature_constants.CLASSIFY_INPUTS],
        "input-1:0", types_pb2.DT_STRING, 0)

    # Check outputs in signature def: string classes and float scores.
    self.assertEqual(2, len(signature_def.outputs))
    self._assertTensorInfo(
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES],
        "output-1:0", types_pb2.DT_STRING, 0)
    self._assertTensorInfo(
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES],
        "output-2:0", types_pb2.DT_FLOAT, 0)

  @test_util.run_deprecated_v1
  def testPredictionSignatureDef(self):
    input1 = constant_op.constant("a", name="input-1")
    input2 = constant_op.constant("b", name="input-2")
    output1 = constant_op.constant("c", name="output-1")
    output2 = constant_op.constant("d", name="output-2")
    signature_def = signature_def_utils_impl.predict_signature_def(
        {"input-1": input1, "input-2": input2},
        {"output-1": output1, "output-2": output2})

    self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def. Every key maps to a tensor created with
    # the same name, so the expected tensor name is "<key>:0".
    self.assertEqual(2, len(signature_def.inputs))
    for key in ("input-1", "input-2"):
      self._assertTensorInfo(signature_def.inputs[key], key + ":0",
                             types_pb2.DT_STRING, 0)

    # Check outputs in signature def.
    self.assertEqual(2, len(signature_def.outputs))
    for key in ("output-1", "output-2"):
      self._assertTensorInfo(signature_def.outputs[key], key + ":0",
                             types_pb2.DT_STRING, 0)

  @test_util.run_deprecated_v1
  def testTrainSignatureDef(self):
    self._testSupervisedSignatureDef(
        signature_def_utils_impl.supervised_train_signature_def,
        signature_constants.SUPERVISED_TRAIN_METHOD_NAME)

  @test_util.run_deprecated_v1
  def testEvalSignatureDef(self):
    self._testSupervisedSignatureDef(
        signature_def_utils_impl.supervised_eval_signature_def,
        signature_constants.SUPERVISED_EVAL_METHOD_NAME)

  def _testSupervisedSignatureDef(self, fn_to_test, method_name):
    """Checks a supervised (train/eval) SignatureDef builder end to end.

    Args:
      fn_to_test: Builder called as `fn_to_test(inputs, loss, predictions,
        metrics)`, e.g. `supervised_train_signature_def`.
      method_name: The method name the resulting SignatureDef must carry.
    """
    inputs = {
        "input-1": constant_op.constant("a", name="input-1"),
        "input-2": constant_op.constant("b", name="input-2"),
    }
    loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
    predictions = {
        "classes": constant_op.constant([100], name="classes"),
    }
    metrics_val = constant_op.constant(100.0, name="metrics_val")
    metrics = {
        "metrics/value": metrics_val,
        "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"),
    }

    signature_def = fn_to_test(inputs, loss, predictions, metrics)

    self.assertEqual(method_name, signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(2, len(signature_def.inputs))
    for key in ("input-1", "input-2"):
      self._assertTensorInfo(signature_def.inputs[key], key + ":0",
                             types_pb2.DT_STRING, 0)

    # Check outputs in signature def: the loss, prediction and metric entries
    # all appear in `outputs` under their original keys.
    self.assertEqual(4, len(signature_def.outputs))
    self._assertTensorInfo(signature_def.outputs["loss-1"], "loss-1:0",
                           types_pb2.DT_FLOAT)
    self._assertTensorInfo(signature_def.outputs["classes"], "classes:0",
                           types_pb2.DT_INT32, 1)
    self._assertTensorInfo(signature_def.outputs["metrics/value"],
                           "metrics_val:0", types_pb2.DT_FLOAT)
    # The update op is a distinct output; verify its own TensorInfo rather than
    # re-reading "metrics/value".
    self._assertTensorInfo(signature_def.outputs["metrics/update_op"],
                           "metrics_op:0", types_pb2.DT_FLOAT)

  @test_util.run_deprecated_v1
  def testTrainSignatureDefMissingInputs(self):
    self._testSupervisedSignatureDefMissingInputs(
        signature_def_utils_impl.supervised_train_signature_def,
        signature_constants.SUPERVISED_TRAIN_METHOD_NAME)

  @test_util.run_deprecated_v1
  def testEvalSignatureDefMissingInputs(self):
    self._testSupervisedSignatureDefMissingInputs(
        signature_def_utils_impl.supervised_eval_signature_def,
        signature_constants.SUPERVISED_EVAL_METHOD_NAME)

  def _testSupervisedSignatureDefMissingInputs(self, fn_to_test, method_name):
    """Checks how a supervised builder handles missing arguments.

    Empty `inputs` must raise ValueError. Omitting `predictions` and/or
    `metrics` is accepted and yields correspondingly fewer outputs.

    Args:
      fn_to_test: A supervised SignatureDef builder accepting `inputs` plus
        `loss`, `predictions` and `metrics` keyword arguments.
      method_name: The method name the resulting SignatureDef must carry.
    """
    inputs = {
        "input-1": constant_op.constant("a", name="input-1"),
        "input-2": constant_op.constant("b", name="input-2"),
    }
    loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
    predictions = {
        "classes": constant_op.constant([100], name="classes"),
    }
    metrics_val = constant_op.constant(100, name="metrics_val")
    metrics = {
        "metrics/value": metrics_val,
        "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"),
    }

    with self.assertRaises(ValueError):
      fn_to_test({}, loss=loss, predictions=predictions, metrics=metrics)

    # Loss only: a single output.
    signature_def = fn_to_test(inputs, loss=loss)
    self.assertEqual(method_name, signature_def.method_name)
    self.assertEqual(1, len(signature_def.outputs))

    # Loss plus both metric tensors: three outputs.
    signature_def = fn_to_test(inputs, metrics=metrics, loss=loss)
    self.assertEqual(method_name, signature_def.method_name)
    self.assertEqual(3, len(signature_def.outputs))

  def _assertValidSignature(self, inputs, outputs, method_name):
    """Asserts that the SignatureDef built from the arguments is valid."""
    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    self.assertTrue(
        signature_def_utils_impl.is_valid_signature(signature_def))

  def _assertInvalidSignature(self, inputs, outputs, method_name):
    """Asserts that the SignatureDef built from the arguments is invalid."""
    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    self.assertFalse(
        signature_def_utils_impl.is_valid_signature(signature_def))

  def testValidSignaturesAreAccepted(self):
    self._assertValidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"classes": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertValidSignature(
        {"foo": _STRING, "bar": _FLOAT},
        {"baz": _STRING, "qux": _FLOAT},
        signature_constants.PREDICT_METHOD_NAME)

  def testInvalidMethodNameSignatureIsRejected(self):
    # WRONG METHOD
    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        "WRONG method name")

  def testInvalidClassificationSignaturesAreRejected(self):
    # CLASSIFY: wrong types
    self._assertInvalidSignature(
        {"inputs": _FLOAT},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _FLOAT, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

    # CLASSIFY: wrong keys
    self._assertInvalidSignature(
        {},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs_WRONG": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes_WRONG": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT, "extra_WRONG": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

  def testInvalidRegressionSignaturesAreRejected(self):
    # REGRESS: wrong types
    self._assertInvalidSignature(
        {"inputs": _FLOAT},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs": _STRING},
        signature_constants.REGRESS_METHOD_NAME)

    # REGRESS: wrong keys
    self._assertInvalidSignature(
        {},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs_WRONG": _STRING},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs_WRONG": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs": _FLOAT, "extra_WRONG": _STRING},
        signature_constants.REGRESS_METHOD_NAME)

  def testInvalidPredictSignaturesAreRejected(self):
    # PREDICT: wrong keys
    self._assertInvalidSignature(
        {},
        {"baz": _STRING, "qux": _FLOAT},
        signature_constants.PREDICT_METHOD_NAME)

    self._assertInvalidSignature(
        {"foo": _STRING, "bar": _FLOAT},
        {},
        signature_constants.PREDICT_METHOD_NAME)

  @test_util.run_v1_only("b/120545219")
  def testOpSignatureDef(self):
    # op_signature_def should record the op's name under `key` in outputs.
    key = "adding_1_and_2_key"
    add_op = math_ops.add(1, 2, name="adding_1_and_2")
    signature_def = signature_def_utils_impl.op_signature_def(add_op, key)
    self.assertIn(key, signature_def.outputs)
    self.assertEqual(add_op.name, signature_def.outputs[key].name)

  @test_util.run_v1_only("b/120545219")
  def testLoadOpFromSignatureDef(self):
    # load_op_from_signature_def should return the op stored by
    # op_signature_def under the same key.
    key = "adding_1_and_2_key"
    add_op = math_ops.add(1, 2, name="adding_1_and_2")
    signature_def = signature_def_utils_impl.op_signature_def(add_op, key)

    self.assertEqual(
        add_op,
        signature_def_utils_impl.load_op_from_signature_def(signature_def, key))
443
444
if __name__ == "__main__":
  # Entry point when this file is executed directly; delegates to `test.main`
  # from tensorflow.python.platform.
  test.main()
447