1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SignatureDef utility functions implementation."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21
22from tensorflow.core.framework import types_pb2
23from tensorflow.core.protobuf import meta_graph_pb2
24from tensorflow.python.framework import errors
25from tensorflow.python.framework import ops
26from tensorflow.python.saved_model import signature_constants
27from tensorflow.python.saved_model import utils_impl as utils
28from tensorflow.python.util import deprecation
29from tensorflow.python.util.tf_export import tf_export
30
31
@tf_export(
    v1=[
        'saved_model.build_signature_def',
        'saved_model.signature_def_utils.build_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.build_signature_def')
def build_signature_def(inputs=None, outputs=None, method_name=None):
  """Assembles a SignatureDef protocol buffer from the supplied pieces.

  Every argument is optional; an argument left as `None` leaves the
  corresponding field of the returned proto untouched.

  Args:
    inputs: Mapping of string key to tensor info proto. Each entry is copied
        into the `inputs` map of the result.
    outputs: Mapping of string key to tensor info proto. Each entry is copied
        into the `outputs` map of the result.
    method_name: Method name string to store on the result.

  Returns:
    A newly created `meta_graph_pb2.SignatureDef`.
  """
  result = meta_graph_pb2.SignatureDef()
  # Pair each destination map on the proto with the caller-supplied mapping so
  # inputs and outputs share one copy loop (inputs are still copied first).
  for destination, source in ((result.inputs, inputs),
                              (result.outputs, outputs)):
    if source is None:
      continue
    for name in source:
      destination[name].CopyFrom(source[name])
  if method_name is not None:
    result.method_name = method_name
  return result
62
63
@tf_export(
    v1=[
        'saved_model.regression_signature_def',
        'saved_model.signature_def_utils.regression_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.regression_signature_def')
def regression_signature_def(examples, predictions):
  """Builds a regression SignatureDef from examples and predictions.

  The resulting signature targets the TensorFlow Serving Regress API
  (tensorflow_serving/apis/prediction_service.proto), so the input and output
  dtypes are restricted to the ones checked below.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    predictions: A float `Tensor`.

  Returns:
    A SignatureDef whose single input is keyed by
    `signature_constants.REGRESS_INPUTS`, whose single output is keyed by
    `signature_constants.REGRESS_OUTPUTS`, and whose method name is
    `signature_constants.REGRESS_METHOD_NAME`.

  Raises:
    ValueError: If `examples` is `None`, is not an `ops.Tensor`, or its tensor
      info dtype is not `DT_STRING`; or if `predictions` is `None` or its
      tensor info dtype is not `DT_FLOAT`.
  """
  # Argument presence/type checks come first, before any tensor info is built.
  if examples is None:
    raise ValueError('Regression examples cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Regression examples must be a string Tensor.')
  if predictions is None:
    raise ValueError('Regression predictions cannot be None.')

  examples_info = utils.build_tensor_info(examples)
  if examples_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Regression examples must be a string Tensor.')

  predictions_info = utils.build_tensor_info(predictions)
  if predictions_info.dtype != types_pb2.DT_FLOAT:
    raise ValueError('Regression output must be a float Tensor.')

  return build_signature_def(
      inputs={signature_constants.REGRESS_INPUTS: examples_info},
      outputs={signature_constants.REGRESS_OUTPUTS: predictions_info},
      method_name=signature_constants.REGRESS_METHOD_NAME)
110
111
@tf_export(
    v1=[
        'saved_model.classification_signature_def',
        'saved_model.signature_def_utils.classification_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.classification_signature_def')
def classification_signature_def(examples, classes, scores):
  """Builds a classification SignatureDef from examples, classes and scores.

  The resulting signature targets the TensorFlow Serving Classify API
  (tensorflow_serving/apis/prediction_service.proto), so the input and output
  dtypes are restricted to the ones checked below.

  Args:
    examples: A string `Tensor`, expected to accept serialized tf.Examples.
    classes: A string `Tensor`, or `None`.  Note that the
      ClassificationResponse message requires that class labels are strings,
      not integers or anything else.
    scores: A float `Tensor`, or `None`. At least one of `classes` and
      `scores` must be provided.

  Returns:
    A SignatureDef with one input keyed by
    `signature_constants.CLASSIFY_INPUTS`, an output for each of `classes` and
    `scores` that was supplied, and method name
    `signature_constants.CLASSIFY_METHOD_NAME`.

  Raises:
    ValueError: If `examples` is `None`, is not an `ops.Tensor`, or its tensor
      info dtype is not `DT_STRING`; if `classes` and `scores` are both
      `None`; if `classes` is given and its dtype is not `DT_STRING`; or if
      `scores` is given and its dtype is not `DT_FLOAT`.
  """
  if examples is None:
    raise ValueError('Classification examples cannot be None.')
  if not isinstance(examples, ops.Tensor):
    raise ValueError('Classification examples must be a string Tensor.')
  if classes is None and scores is None:
    raise ValueError('Classification classes and scores cannot both be None.')

  examples_info = utils.build_tensor_info(examples)
  if examples_info.dtype != types_pb2.DT_STRING:
    raise ValueError('Classification examples must be a string Tensor.')

  # One row per optional output: (tensor, required dtype, error text, key).
  # Classes are processed before scores, so a bad `classes` is reported first.
  output_specs = (
      (classes, types_pb2.DT_STRING,
       'Classification classes must be a string Tensor.',
       signature_constants.CLASSIFY_OUTPUT_CLASSES),
      (scores, types_pb2.DT_FLOAT,
       'Classification scores must be a float Tensor.',
       signature_constants.CLASSIFY_OUTPUT_SCORES),
  )
  collected_outputs = {}
  for tensor, required_dtype, error_text, output_key in output_specs:
    if tensor is None:
      continue
    info = utils.build_tensor_info(tensor)
    if info.dtype != required_dtype:
      raise ValueError(error_text)
    collected_outputs[output_key] = info

  return build_signature_def(
      inputs={signature_constants.CLASSIFY_INPUTS: examples_info},
      outputs=collected_outputs,
      method_name=signature_constants.CLASSIFY_METHOD_NAME)
169
170
@tf_export(
    v1=[
        'saved_model.predict_signature_def',
        'saved_model.signature_def_utils.predict_signature_def'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.predict_signature_def')
def predict_signature_def(inputs, outputs):
  """Builds a prediction SignatureDef from input and output tensors.

  The resulting signature targets the TensorFlow Serving Predict API
  (tensorflow_serving/apis/prediction_service.proto). No dtype checks are
  applied to the inputs or outputs.

  Args:
    inputs: dict of string to `Tensor`.
    outputs: dict of string to `Tensor`.

  Returns:
    A SignatureDef with method name `signature_constants.PREDICT_METHOD_NAME`
    whose inputs and outputs hold the tensor info built for each entry.

  Raises:
    ValueError: If `inputs` or `outputs` is `None` or empty.
  """
  # `not x` is true for both None and an empty dict.
  if not inputs:
    raise ValueError('Prediction inputs cannot be None or empty.')
  if not outputs:
    raise ValueError('Prediction outputs cannot be None or empty.')

  def _as_tensor_infos(tensors):
    """Maps every tensor in `tensors` to its tensor info, keeping the keys."""
    return {name: utils.build_tensor_info(value)
            for name, value in tensors.items()}

  return build_signature_def(
      _as_tensor_infos(inputs), _as_tensor_infos(outputs),
      signature_constants.PREDICT_METHOD_NAME)
210
211
def supervised_train_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates a train-flavored SignatureDef.

  Thin wrapper around `_supervised_signature_def` that fixes the method name
  to `signature_constants.SUPERVISED_TRAIN_METHOD_NAME`.

  Args:
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    The SignatureDef produced by `_supervised_signature_def`.
  """
  train_method = signature_constants.SUPERVISED_TRAIN_METHOD_NAME
  return _supervised_signature_def(
      train_method, inputs, loss=loss, predictions=predictions,
      metrics=metrics)
217
218
def supervised_eval_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Creates an eval-flavored SignatureDef.

  Thin wrapper around `_supervised_signature_def` that fixes the method name
  to `signature_constants.SUPERVISED_EVAL_METHOD_NAME`.

  Args:
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    The SignatureDef produced by `_supervised_signature_def`.
  """
  eval_method = signature_constants.SUPERVISED_EVAL_METHOD_NAME
  return _supervised_signature_def(
      eval_method, inputs, loss=loss, predictions=predictions,
      metrics=metrics)
224
225
def _supervised_signature_def(
    method_name, inputs, loss=None, predictions=None,
    metrics=None):
  """Builds a SignatureDef for a supervised process such as train or eval.

  The signature describes the inputs of the process and, as outputs, whatever
  loss, predictions and metrics were supplied. Only `inputs` is mandatory.
  The three output groups are merged into one output map in the order loss,
  predictions, metrics, so a key repeated in a later group replaces the entry
  from an earlier one.

  Args:
    method_name: Method name of the SignatureDef as a string.
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    A train- or eval-flavored signature_def.

  Raises:
    ValueError: If `inputs` is `None` or empty.
  """
  # `not inputs` covers both None and an empty dict.
  if not inputs:
    raise ValueError('{} inputs cannot be None or empty.'.format(method_name))

  input_infos = {}
  for name, tensor in inputs.items():
    input_infos[name] = utils.build_tensor_info(tensor)

  output_infos = {}
  for group in (loss, predictions, metrics):
    if group is None:
      continue
    for name, tensor in group.items():
      output_infos[name] = utils.build_tensor_info(tensor)

  return build_signature_def(input_infos, output_infos, method_name)
266
267
@tf_export(
    v1=[
        'saved_model.is_valid_signature',
        'saved_model.signature_def_utils.is_valid_signature'
    ])
@deprecation.deprecated_endpoints(
    'saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
  """Determine whether a SignatureDef can be served by TensorFlow Serving.

  Args:
    signature_def: A SignatureDef proto, or `None`.

  Returns:
    `True` if `signature_def` passes the classification, regression or predict
    validity check; `False` otherwise, including when it is `None`.
  """
  if signature_def is None:
    return False
  # Checked in this order; any() stops at the first validator that accepts.
  validators = (
      _is_valid_classification_signature,
      _is_valid_regression_signature,
      _is_valid_predict_signature,
  )
  return any(validator(signature_def) for validator in validators)
282
283
def _is_valid_predict_signature(signature_def):
  """Determine whether the argument is a servable 'predict' SignatureDef.

  Args:
    signature_def: A SignatureDef proto.

  Returns:
    `True` if the method name is `signature_constants.PREDICT_METHOD_NAME` and
    both the inputs and the outputs contain at least one key.
  """
  if signature_def.method_name != signature_constants.PREDICT_METHOD_NAME:
    return False
  # Both maps must be non-empty; outputs are only inspected if inputs pass.
  return bool(signature_def.inputs.keys() and signature_def.outputs.keys())
293
294
def _is_valid_regression_signature(signature_def):
  """Determine whether the argument is a servable 'regress' SignatureDef.

  Args:
    signature_def: A SignatureDef proto.

  Returns:
    `True` only if the method name is
    `signature_constants.REGRESS_METHOD_NAME`, the sole input key is
    `signature_constants.REGRESS_INPUTS` with dtype `DT_STRING`, and the sole
    output key is `signature_constants.REGRESS_OUTPUTS` with dtype `DT_FLOAT`.
  """
  if signature_def.method_name != signature_constants.REGRESS_METHOD_NAME:
    return False

  input_key = signature_constants.REGRESS_INPUTS
  output_key = signature_constants.REGRESS_OUTPUTS
  inputs = signature_def.inputs
  outputs = signature_def.outputs

  # The key-set checks run before each indexed lookup, so the lookups below
  # only ever touch a key that is known to be present.
  if set(inputs.keys()) != {input_key}:
    return False
  if inputs[input_key].dtype != types_pb2.DT_STRING:
    return False

  if set(outputs.keys()) != {output_key}:
    return False
  return outputs[output_key].dtype == types_pb2.DT_FLOAT
315
316
def _is_valid_classification_signature(signature_def):
  """Determine whether the argument is a servable 'classify' SignatureDef.

  Args:
    signature_def: A SignatureDef proto.

  Returns:
    `True` only if the method name is
    `signature_constants.CLASSIFY_METHOD_NAME`, the sole input key is
    `signature_constants.CLASSIFY_INPUTS` with dtype `DT_STRING`, and the
    outputs are a non-empty subset of the classes/scores keys in which classes
    (if present) is `DT_STRING` and scores (if present) is `DT_FLOAT`.
  """
  if signature_def.method_name != signature_constants.CLASSIFY_METHOD_NAME:
    return False

  inputs = signature_def.inputs
  input_key = signature_constants.CLASSIFY_INPUTS
  if set(inputs.keys()) != {input_key}:
    return False
  if inputs[input_key].dtype != types_pb2.DT_STRING:
    return False

  # Each permitted output key paired with the dtype it must carry.
  required_dtypes = (
      (signature_constants.CLASSIFY_OUTPUT_CLASSES, types_pb2.DT_STRING),
      (signature_constants.CLASSIFY_OUTPUT_SCORES, types_pb2.DT_FLOAT),
  )
  outputs = signature_def.outputs
  present_keys = set(outputs.keys())
  if not present_keys:
    return False
  if not present_keys.issubset(key for key, _ in required_dtypes):
    return False

  # Membership is tested before indexing so absent keys are never looked up.
  for key, required_dtype in required_dtypes:
    if key in outputs and outputs[key].dtype != required_dtype:
      return False
  return True
348
349
def op_signature_def(op, key):
  """Builds a SignatureDef whose only output refers to the given op.

  `op` is not strictly required to be an Op object and may be a Tensor; for
  Tensors, build_signature_def() is the recommended function.

  Args:
    op: An Op (or possibly Tensor).
    key: Key to graph element in the SignatureDef outputs.

  Returns:
    A SignatureDef with a single output, stored under `key`, and no inputs or
    method name.
  """
  # build_tensor_info_from_op creates a TensorInfo from the element's name.
  op_info = utils.build_tensor_info_from_op(op)
  single_output = {key: op_info}
  return build_signature_def(outputs=single_output)
366
367
def load_op_from_signature_def(signature_def, key, import_scope=None):
  """Load an Op from a SignatureDef created by op_signature_def().

  Args:
    signature_def: a SignatureDef proto
    key: string key to op in the SignatureDef outputs.
    import_scope: Scope used to import the op

  Returns:
    Op (or possibly Tensor) in the graph with the same name as saved in the
      SignatureDef.

  Raises:
    NotFoundError: If `key` is not present in the SignatureDef outputs, or if
      the op could not be found in the graph.
  """
  # Indexing a message-valued protobuf map with a missing key inserts a default
  # entry, which would silently modify the caller's SignatureDef. Test
  # membership first and report the missing key without touching the proto.
  if key not in signature_def.outputs:
    raise errors.NotFoundError(
        None, None,
        'The key "{0}" was not found in the outputs of the '
        'SignatureDef.'.format(key))

  tensor_info = signature_def.outputs[key]
  try:
    # The init and train ops are not strictly enforced to be operations, so
    # retrieve any graph element (can be either op or tensor).
    return utils.get_element_from_tensor_info(
        tensor_info, import_scope=import_scope)
  except KeyError:
    # Surface the lookup failure as NotFoundError, the type callers expect.
    raise errors.NotFoundError(
        None, None,
        'The {0} could not be found in the graph. Please make sure the '
        'SavedModel was created by the internal _SavedModelBuilder. If you '
        'are using the public API, please make sure the SignatureDef in the '
        'SavedModel does not contain the key "{0}".'.format(key))
396