1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for SignatureDef utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.framework import types_pb2
22from tensorflow.core.protobuf import meta_graph_pb2
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.platform import test
29from tensorflow.python.saved_model import signature_constants
30from tensorflow.python.saved_model import signature_def_utils_impl
31from tensorflow.python.saved_model import utils
32
33
# TensorInfo protos shared by the signature-validity tests; the same objects are
# deliberately reused in several contexts purely for test convenience. Shapes
# are left unset because the validator does not check them.
_STRING = meta_graph_pb2.TensorInfo(
    dtype=dtypes.string.as_datatype_enum, name="foobar")


_FLOAT = meta_graph_pb2.TensorInfo(
    dtype=dtypes.float32.as_datatype_enum, name="foobar")
46
47
def _make_signature(inputs, outputs, name=None):
  """Builds a SignatureDef from dicts of tensors.

  Every tensor is converted with `utils.build_tensor_info` before the
  resulting dicts are handed to `signature_def_utils_impl.build_signature_def`.

  Args:
    inputs: Dict mapping input keys to tensors.
    outputs: Dict mapping output keys to tensors.
    name: Forwarded as the third positional argument of `build_signature_def`,
      i.e. the SignatureDef's method name; may be None.

  Returns:
    The SignatureDef proto produced by `build_signature_def`.
  """

  def _infos_for(tensors_by_key):
    # Keep each key, replacing its tensor with the matching TensorInfo proto.
    return {key: utils.build_tensor_info(t)
            for key, t in tensors_by_key.items()}

  input_infos = _infos_for(inputs)
  output_infos = _infos_for(outputs)
  return signature_def_utils_impl.build_signature_def(
      input_infos, output_infos, name)
59
60
class SignatureDefUtilsTest(test.TestCase):
  """Tests for SignatureDef builders and `is_valid_signature`."""

  def _assertTensorInfo(self, tensor_info, expected_name, expected_dtype,
                        expected_rank):
    """Asserts that a TensorInfo proto has the expected name, dtype and rank.

    Args:
      tensor_info: `TensorInfo` proto taken from a SignatureDef.
      expected_name: Expected tensor name, e.g. "x:0".
      expected_dtype: Expected `types_pb2` dtype enum value.
      expected_rank: Expected number of entries in `tensor_shape.dim`.
    """
    self.assertEqual(expected_name, tensor_info.name)
    self.assertEqual(expected_dtype, tensor_info.dtype)
    self.assertEqual(expected_rank, len(tensor_info.tensor_shape.dim))

  def testBuildSignatureDef(self):
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info).
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32, 1, name="x")
      inputs = {"foo-input": utils.build_tensor_info(x)}

      y = array_ops.placeholder(dtypes.float32, name="y")
      outputs = {"foo-output": utils.build_tensor_info(y)}

    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, "foo-method-name")
    self.assertEqual("foo-method-name", signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    x_tensor_info_actual = signature_def.inputs["foo-input"]
    self._assertTensorInfo(x_tensor_info_actual, "x:0", types_pb2.DT_FLOAT, 1)
    self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size)

    # Check outputs in signature def. "y" was created without a shape, so its
    # TensorInfo carries no dimensions.
    self.assertEqual(1, len(signature_def.outputs))
    self._assertTensorInfo(signature_def.outputs["foo-output"], "y:0",
                           types_pb2.DT_FLOAT, 0)

  def testRegressionSignatureDef(self):
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info).
    with ops.Graph().as_default():
      input1 = constant_op.constant("a", name="input-1")
      output1 = constant_op.constant(2.2, name="output-1")
      signature_def = signature_def_utils_impl.regression_signature_def(
          input1, output1)

    self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    self._assertTensorInfo(
        signature_def.inputs[signature_constants.REGRESS_INPUTS],
        "input-1:0", types_pb2.DT_STRING, 0)

    # Check outputs in signature def.
    self.assertEqual(1, len(signature_def.outputs))
    self._assertTensorInfo(
        signature_def.outputs[signature_constants.REGRESS_OUTPUTS],
        "output-1:0", types_pb2.DT_FLOAT, 0)

  def testClassificationSignatureDef(self):
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info).
    with ops.Graph().as_default():
      input1 = constant_op.constant("a", name="input-1")
      output1 = constant_op.constant("b", name="output-1")
      output2 = constant_op.constant(3.3, name="output-2")
      signature_def = signature_def_utils_impl.classification_signature_def(
          input1, output1, output2)

    self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(signature_def.inputs))
    self._assertTensorInfo(
        signature_def.inputs[signature_constants.CLASSIFY_INPUTS],
        "input-1:0", types_pb2.DT_STRING, 0)

    # Check outputs in signature def.
    self.assertEqual(2, len(signature_def.outputs))
    self._assertTensorInfo(
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES],
        "output-1:0", types_pb2.DT_STRING, 0)
    self._assertTensorInfo(
        signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES],
        "output-2:0", types_pb2.DT_FLOAT, 0)

  def testPredictionSignatureDef(self):
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info).
    with ops.Graph().as_default():
      input1 = constant_op.constant("a", name="input-1")
      input2 = constant_op.constant("b", name="input-2")
      output1 = constant_op.constant("c", name="output-1")
      output2 = constant_op.constant("d", name="output-2")
      signature_def = signature_def_utils_impl.predict_signature_def(
          {
              "input-1": input1,
              "input-2": input2
          }, {
              "output-1": output1,
              "output-2": output2
          })

    self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
                     signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(2, len(signature_def.inputs))
    self._assertTensorInfo(signature_def.inputs["input-1"], "input-1:0",
                           types_pb2.DT_STRING, 0)
    self._assertTensorInfo(signature_def.inputs["input-2"], "input-2:0",
                           types_pb2.DT_STRING, 0)

    # Check outputs in signature def.
    self.assertEqual(2, len(signature_def.outputs))
    self._assertTensorInfo(signature_def.outputs["output-1"], "output-1:0",
                           types_pb2.DT_STRING, 0)
    self._assertTensorInfo(signature_def.outputs["output-2"], "output-2:0",
                           types_pb2.DT_STRING, 0)

  def testTrainSignatureDef(self):
    self._testSupervisedSignatureDef(
        signature_def_utils_impl.supervised_train_signature_def,
        signature_constants.SUPERVISED_TRAIN_METHOD_NAME)

  def testEvalSignatureDef(self):
    self._testSupervisedSignatureDef(
        signature_def_utils_impl.supervised_eval_signature_def,
        signature_constants.SUPERVISED_EVAL_METHOD_NAME)

  def _testSupervisedSignatureDef(self, fn_to_test, method_name):
    """Checks a supervised SignatureDef built from inputs, loss, preds, metrics.

    Args:
      fn_to_test: Builder called as
        `fn_to_test(inputs, loss, predictions, metrics)`.
      method_name: Expected `method_name` of the resulting SignatureDef.
    """
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info).
    with ops.Graph().as_default():
      inputs = {
          "input-1": constant_op.constant("a", name="input-1"),
          "input-2": constant_op.constant("b", name="input-2"),
      }
      loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
      predictions = {
          "classes": constant_op.constant([100], name="classes"),
      }
      metrics_val = constant_op.constant(100.0, name="metrics_val")
      metrics = {
          "metrics/value":
              metrics_val,
          "metrics/update_op":
              array_ops.identity(metrics_val, name="metrics_op"),
      }

      signature_def = fn_to_test(inputs, loss, predictions, metrics)

    self.assertEqual(method_name, signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(2, len(signature_def.inputs))
    self._assertTensorInfo(signature_def.inputs["input-1"], "input-1:0",
                           types_pb2.DT_STRING, 0)
    self._assertTensorInfo(signature_def.inputs["input-2"], "input-2:0",
                           types_pb2.DT_STRING, 0)

    # Check outputs in signature def: one loss, one prediction and the two
    # metric tensors.
    self.assertEqual(4, len(signature_def.outputs))
    self.assertEqual("loss-1:0", signature_def.outputs["loss-1"].name)
    self.assertEqual(types_pb2.DT_FLOAT, signature_def.outputs["loss-1"].dtype)

    # The prediction is built from the Python int list [100]: a rank-1 int32.
    self._assertTensorInfo(signature_def.outputs["classes"], "classes:0",
                           types_pb2.DT_INT32, 1)

    self.assertEqual(
        "metrics_val:0", signature_def.outputs["metrics/value"].name)
    self.assertEqual(
        types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype)

    # The update op is an identity of the float metrics_val tensor, so its
    # dtype must be checked on its own entry (not on "metrics/value").
    self.assertEqual(
        "metrics_op:0", signature_def.outputs["metrics/update_op"].name)
    self.assertEqual(
        types_pb2.DT_FLOAT, signature_def.outputs["metrics/update_op"].dtype)

  def testTrainSignatureDefMissingInputs(self):
    self._testSupervisedSignatureDefMissingInputs(
        signature_def_utils_impl.supervised_train_signature_def,
        signature_constants.SUPERVISED_TRAIN_METHOD_NAME)

  def testEvalSignatureDefMissingInputs(self):
    self._testSupervisedSignatureDefMissingInputs(
        signature_def_utils_impl.supervised_eval_signature_def,
        signature_constants.SUPERVISED_EVAL_METHOD_NAME)

  def _testSupervisedSignatureDefMissingInputs(self, fn_to_test, method_name):
    """Checks supervised builders with empty inputs and omitted output groups.

    Args:
      fn_to_test: Builder accepting `inputs` plus keyword `loss`,
        `predictions` and `metrics` arguments.
      method_name: Expected `method_name` of the resulting SignatureDefs.
    """
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info).
    with ops.Graph().as_default():
      inputs = {
          "input-1": constant_op.constant("a", name="input-1"),
          "input-2": constant_op.constant("b", name="input-2"),
      }
      loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
      predictions = {
          "classes": constant_op.constant([100], name="classes"),
      }
      metrics_val = constant_op.constant(100, name="metrics_val")
      metrics = {
          "metrics/value":
              metrics_val,
          "metrics/update_op":
              array_ops.identity(metrics_val, name="metrics_op"),
      }

      # An empty inputs dict must be rejected even when every output group
      # is supplied.
      with self.assertRaises(ValueError):
        fn_to_test({},
                   loss=loss,
                   predictions=predictions,
                   metrics=metrics)

      # Omitted output groups contribute no outputs.
      signature_def = fn_to_test(inputs, loss=loss)
      self.assertEqual(method_name, signature_def.method_name)
      self.assertEqual(1, len(signature_def.outputs))
      self.assertIn("loss-1", signature_def.outputs)

      signature_def = fn_to_test(inputs, metrics=metrics, loss=loss)
      self.assertEqual(method_name, signature_def.method_name)
      self.assertEqual(3, len(signature_def.outputs))
      self.assertEqual({"loss-1", "metrics/value", "metrics/update_op"},
                       set(signature_def.outputs))

  def _assertValidSignature(self, inputs, outputs, method_name):
    """Asserts that a SignatureDef built from TensorInfo dicts is valid."""
    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    self.assertTrue(
        signature_def_utils_impl.is_valid_signature(signature_def))

  def _assertInvalidSignature(self, inputs, outputs, method_name):
    """Asserts that a SignatureDef built from TensorInfo dicts is invalid."""
    signature_def = signature_def_utils_impl.build_signature_def(
        inputs, outputs, method_name)
    self.assertFalse(
        signature_def_utils_impl.is_valid_signature(signature_def))

  def testValidSignaturesAreAccepted(self):
    self._assertValidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"classes": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertValidSignature(
        {"inputs": _STRING},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertValidSignature(
        {"foo": _STRING, "bar": _FLOAT},
        {"baz": _STRING, "qux": _FLOAT},
        signature_constants.PREDICT_METHOD_NAME)

  def testInvalidMethodNameSignatureIsRejected(self):
    # WRONG METHOD
    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        "WRONG method name")

  def testInvalidClassificationSignaturesAreRejected(self):
    # CLASSIFY: wrong types
    self._assertInvalidSignature(
        {"inputs": _FLOAT},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _FLOAT, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

    # CLASSIFY: wrong keys
    self._assertInvalidSignature(
        {},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs_WRONG": _STRING},
        {"classes": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes_WRONG": _STRING, "scores": _FLOAT},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {},
        signature_constants.CLASSIFY_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"classes": _STRING, "scores": _FLOAT, "extra_WRONG": _STRING},
        signature_constants.CLASSIFY_METHOD_NAME)

  def testInvalidRegressionSignaturesAreRejected(self):
    # REGRESS: wrong types
    self._assertInvalidSignature(
        {"inputs": _FLOAT},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs": _STRING},
        signature_constants.REGRESS_METHOD_NAME)

    # REGRESS: wrong keys
    self._assertInvalidSignature(
        {},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs_WRONG": _STRING},
        {"outputs": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs_WRONG": _FLOAT},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {},
        signature_constants.REGRESS_METHOD_NAME)

    self._assertInvalidSignature(
        {"inputs": _STRING},
        {"outputs": _FLOAT, "extra_WRONG": _STRING},
        signature_constants.REGRESS_METHOD_NAME)

  def testInvalidPredictSignaturesAreRejected(self):
    # PREDICT: wrong keys
    self._assertInvalidSignature(
        {},
        {"baz": _STRING, "qux": _FLOAT},
        signature_constants.PREDICT_METHOD_NAME)

    self._assertInvalidSignature(
        {"foo": _STRING, "bar": _FLOAT},
        {},
        signature_constants.PREDICT_METHOD_NAME)

  def testOpSignatureDef(self):
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info_from_op).
    with ops.Graph().as_default():
      key = "adding_1_and_2_key"
      add_op = math_ops.add(1, 2, name="adding_1_and_2")
      signature_def = signature_def_utils_impl.op_signature_def(add_op, key)

    self.assertIn(key, signature_def.outputs)
    self.assertEqual(add_op.name, signature_def.outputs[key].name)

  def testLoadOpFromSignatureDef(self):
    # Force the test to run in graph mode.
    # This tests a deprecated v1 API that uses functionality that does not work
    # with eager tensors (namely build_tensor_info_from_op).
    with ops.Graph().as_default():
      key = "adding_1_and_2_key"
      add_op = math_ops.add(1, 2, name="adding_1_and_2")
      signature_def = signature_def_utils_impl.op_signature_def(add_op, key)
      # Loading must happen inside the same graph that owns add_op.
      self.assertEqual(
          add_op,
          signature_def_utils_impl.load_op_from_signature_def(
              signature_def, key))
475
476
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test runner.
  test.main()
479