1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for creating SavedModels.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import os 23import time 24 25import six 26 27from tensorflow.python.platform import gfile 28from tensorflow.python.platform import tf_logging as logging 29from tensorflow.python.saved_model import signature_constants 30from tensorflow.python.saved_model import signature_def_utils 31from tensorflow.python.saved_model import tag_constants 32from tensorflow.python.saved_model.model_utils import export_output as export_output_lib 33from tensorflow.python.saved_model.model_utils import mode_keys 34from tensorflow.python.saved_model.model_utils.mode_keys import KerasModeKeys as ModeKeys 35from tensorflow.python.util import compat 36 37 38# Mapping of the modes to appropriate MetaGraph tags in the SavedModel. 39EXPORT_TAG_MAP = mode_keys.ModeKeyMap(**{ 40 ModeKeys.PREDICT: [tag_constants.SERVING], 41 ModeKeys.TRAIN: [tag_constants.TRAINING], 42 ModeKeys.TEST: [tag_constants.EVAL]}) 43 44# For every exported mode, a SignatureDef map should be created using the 45# functions `export_outputs_for_mode` and `build_all_signature_defs`. By 46# default, this map will contain a single Signature that defines the input 47# tensors and output predictions, losses, and/or metrics (depending on the mode) 48# The default keys used in the SignatureDef map are defined below. 49SIGNATURE_KEY_MAP = mode_keys.ModeKeyMap(**{ 50 ModeKeys.PREDICT: signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, 51 ModeKeys.TRAIN: signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY, 52 ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY}) 53 54# Default names used in the SignatureDef input map, which maps strings to 55# TensorInfo protos. 56SINGLE_FEATURE_DEFAULT_NAME = 'feature' 57SINGLE_RECEIVER_DEFAULT_NAME = 'input' 58SINGLE_LABEL_DEFAULT_NAME = 'label' 59 60### Below utilities are specific to SavedModel exports. 61 62 63def build_all_signature_defs(receiver_tensors, 64 export_outputs, 65 receiver_tensors_alternatives=None, 66 serving_only=True): 67 """Build `SignatureDef`s for all export outputs. 68 69 Args: 70 receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying 71 input nodes where this receiver expects to be fed by default. Typically, 72 this is a single placeholder expecting serialized `tf.Example` protos. 73 export_outputs: a dict of ExportOutput instances, each of which has 74 an as_signature_def instance method that will be called to retrieve 75 the signature_def for all export output tensors. 76 receiver_tensors_alternatives: a dict of string to additional 77 groups of receiver tensors, each of which may be a `Tensor` or a dict of 78 string to `Tensor`. These named receiver tensor alternatives generate 79 additional serving signatures, which may be used to feed inputs at 80 different points within the input receiver subgraph. A typical usage is 81 to allow feeding raw feature `Tensor`s *downstream* of the 82 tf.io.parse_example() op. Defaults to None. 83 serving_only: boolean; if true, resulting signature defs will only include 84 valid serving signatures. If false, all requested signatures will be 85 returned. 86 87 Returns: 88 signature_def representing all passed args. 89 90 Raises: 91 ValueError: if export_outputs is not a dict 92 """ 93 if not isinstance(receiver_tensors, dict): 94 receiver_tensors = {SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} 95 if export_outputs is None or not isinstance(export_outputs, dict): 96 raise ValueError('export_outputs must be a dict and not' 97 '{}'.format(type(export_outputs))) 98 99 signature_def_map = {} 100 excluded_signatures = {} 101 for output_key, export_output in export_outputs.items(): 102 signature_name = '{}'.format(output_key or 'None') 103 try: 104 signature = export_output.as_signature_def(receiver_tensors) 105 signature_def_map[signature_name] = signature 106 except ValueError as e: 107 excluded_signatures[signature_name] = str(e) 108 109 if receiver_tensors_alternatives: 110 for receiver_name, receiver_tensors_alt in ( 111 six.iteritems(receiver_tensors_alternatives)): 112 if not isinstance(receiver_tensors_alt, dict): 113 receiver_tensors_alt = { 114 SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors_alt 115 } 116 for output_key, export_output in export_outputs.items(): 117 signature_name = '{}:{}'.format(receiver_name or 'None', output_key or 118 'None') 119 try: 120 signature = export_output.as_signature_def(receiver_tensors_alt) 121 signature_def_map[signature_name] = signature 122 except ValueError as e: 123 excluded_signatures[signature_name] = str(e) 124 125 _log_signature_report(signature_def_map, excluded_signatures) 126 127 # The above calls to export_output_lib.as_signature_def should return only 128 # valid signatures; if there is a validity problem, they raise a ValueError, 129 # in which case we exclude that signature from signature_def_map above. 130 # The is_valid_signature check ensures that the signatures produced are 131 # valid for serving, and acts as an additional sanity check for export 132 # signatures produced for serving. We skip this check for training and eval 133 # signatures, which are not intended for serving. 134 if serving_only: 135 signature_def_map = { 136 k: v 137 for k, v in signature_def_map.items() 138 if signature_def_utils.is_valid_signature(v) 139 } 140 return signature_def_map 141 142 143_FRIENDLY_METHOD_NAMES = { 144 signature_constants.CLASSIFY_METHOD_NAME: 'Classify', 145 signature_constants.REGRESS_METHOD_NAME: 'Regress', 146 signature_constants.PREDICT_METHOD_NAME: 'Predict', 147 signature_constants.SUPERVISED_TRAIN_METHOD_NAME: 'Train', 148 signature_constants.SUPERVISED_EVAL_METHOD_NAME: 'Eval', 149} 150 151 152def _log_signature_report(signature_def_map, excluded_signatures): 153 """Log a report of which signatures were produced.""" 154 sig_names_by_method_name = collections.defaultdict(list) 155 156 # We'll collect whatever method_names are present, but also we want to make 157 # sure to output a line for each of the three standard methods even if they 158 # have no signatures. 159 for method_name in _FRIENDLY_METHOD_NAMES: 160 sig_names_by_method_name[method_name] = [] 161 162 for signature_name, sig in signature_def_map.items(): 163 sig_names_by_method_name[sig.method_name].append(signature_name) 164 165 # TODO(b/67733540): consider printing the full signatures, not just names 166 for method_name, sig_names in sig_names_by_method_name.items(): 167 if method_name in _FRIENDLY_METHOD_NAMES: 168 method_name = _FRIENDLY_METHOD_NAMES[method_name] 169 logging.info('Signatures INCLUDED in export for {}: {}'.format( 170 method_name, sig_names if sig_names else 'None')) 171 172 if excluded_signatures: 173 logging.info('Signatures EXCLUDED from export because they cannot be ' 174 'be served via TensorFlow Serving APIs:') 175 for signature_name, message in excluded_signatures.items(): 176 logging.info('\'{}\' : {}'.format(signature_name, message)) 177 178 if not signature_def_map: 179 logging.warn('Export includes no signatures!') 180 elif (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY not in 181 signature_def_map): 182 logging.warn('Export includes no default signature!') 183 184 185# When we create a timestamped directory, there is a small chance that the 186# directory already exists because another process is also creating these 187# directories. In this case we just wait one second to get a new timestamp and 188# try again. If this fails several times in a row, then something is seriously 189# wrong. 190MAX_DIRECTORY_CREATION_ATTEMPTS = 10 191 192 193def get_timestamped_export_dir(export_dir_base): 194 """Builds a path to a new subdirectory within the base directory. 195 196 Each export is written into a new subdirectory named using the 197 current time. This guarantees monotonically increasing version 198 numbers even across multiple runs of the pipeline. 199 The timestamp used is the number of seconds since epoch UTC. 200 201 Args: 202 export_dir_base: A string containing a directory to write the exported 203 graph and checkpoints. 204 Returns: 205 The full path of the new subdirectory (which is not actually created yet). 206 207 Raises: 208 RuntimeError: if repeated attempts fail to obtain a unique timestamped 209 directory name. 210 """ 211 attempts = 0 212 while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS: 213 timestamp = int(time.time()) 214 215 result_dir = os.path.join( 216 compat.as_bytes(export_dir_base), compat.as_bytes(str(timestamp))) 217 if not gfile.Exists(result_dir): 218 # Collisions are still possible (though extremely unlikely): this 219 # directory is not actually created yet, but it will be almost 220 # instantly on return from this function. 221 return result_dir 222 time.sleep(1) 223 attempts += 1 224 logging.warn('Directory {} already exists; retrying (attempt {}/{})'.format( 225 compat.as_str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)) 226 raise RuntimeError('Failed to obtain a unique export directory name after ' 227 '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS)) 228 229 230def get_temp_export_dir(timestamped_export_dir): 231 """Builds a directory name based on the argument but starting with 'temp-'. 232 233 This relies on the fact that TensorFlow Serving ignores subdirectories of 234 the base directory that can't be parsed as integers. 235 236 Args: 237 timestamped_export_dir: the name of the eventual export directory, e.g. 238 /foo/bar/<timestamp> 239 240 Returns: 241 A sister directory prefixed with 'temp-', e.g. /foo/bar/temp-<timestamp>. 242 """ 243 (dirname, basename) = os.path.split(timestamped_export_dir) 244 temp_export_dir = os.path.join( 245 compat.as_bytes(dirname), 246 compat.as_bytes('temp-{}'.format(six.ensure_text(basename)))) 247 return temp_export_dir 248 249 250def export_outputs_for_mode( 251 mode, serving_export_outputs=None, predictions=None, loss=None, 252 metrics=None): 253 """Util function for constructing a `ExportOutput` dict given a mode. 254 255 The returned dict can be directly passed to `build_all_signature_defs` helper 256 function as the `export_outputs` argument, used for generating a SignatureDef 257 map. 258 259 Args: 260 mode: A `ModeKeys` specifying the mode. 261 serving_export_outputs: Describes the output signatures to be exported to 262 `SavedModel` and used during serving. Should be a dict or None. 263 predictions: A dict of Tensors or single Tensor representing model 264 predictions. This argument is only used if serving_export_outputs is not 265 set. 266 loss: A dict of Tensors or single Tensor representing calculated loss. 267 metrics: A dict of (metric_value, update_op) tuples, or a single tuple. 268 metric_value must be a Tensor, and update_op must be a Tensor or Op 269 270 Returns: 271 Dictionary mapping the a key to an `tf.estimator.export.ExportOutput` object 272 The key is the expected SignatureDef key for the mode. 273 274 Raises: 275 ValueError: if an appropriate ExportOutput cannot be found for the mode. 276 """ 277 if mode not in SIGNATURE_KEY_MAP: 278 raise ValueError( 279 'Export output type not found for mode: {}. Expected one of: {}.\n' 280 'One likely error is that V1 Estimator Modekeys were somehow passed to ' 281 'this function. Please ensure that you are using the new ModeKeys.' 282 .format(mode, SIGNATURE_KEY_MAP.keys())) 283 signature_key = SIGNATURE_KEY_MAP[mode] 284 if mode_keys.is_predict(mode): 285 return get_export_outputs(serving_export_outputs, predictions) 286 elif mode_keys.is_train(mode): 287 return {signature_key: export_output_lib.TrainOutput( 288 loss=loss, predictions=predictions, metrics=metrics)} 289 else: 290 return {signature_key: export_output_lib.EvalOutput( 291 loss=loss, predictions=predictions, metrics=metrics)} 292 293 294def get_export_outputs(export_outputs, predictions): 295 """Validate export_outputs or create default export_outputs. 296 297 Args: 298 export_outputs: Describes the output signatures to be exported to 299 `SavedModel` and used during serving. Should be a dict or None. 300 predictions: Predictions `Tensor` or dict of `Tensor`. 301 302 Returns: 303 Valid export_outputs dict 304 305 Raises: 306 TypeError: if export_outputs is not a dict or its values are not 307 ExportOutput instances. 308 """ 309 if export_outputs is None: 310 default_output = export_output_lib.PredictOutput(predictions) 311 export_outputs = { 312 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: default_output} 313 314 if not isinstance(export_outputs, dict): 315 raise TypeError('export_outputs must be dict, given: {}'.format( 316 export_outputs)) 317 for v in six.itervalues(export_outputs): 318 if not isinstance(v, export_output_lib.ExportOutput): 319 raise TypeError( 320 'Values in export_outputs must be ExportOutput objects. ' 321 'Given: {}'.format(export_outputs)) 322 323 _maybe_add_default_serving_output(export_outputs) 324 325 return export_outputs 326 327 328def _maybe_add_default_serving_output(export_outputs): 329 """Add a default serving output to the export_outputs if not present. 330 331 Args: 332 export_outputs: Describes the output signatures to be exported to 333 `SavedModel` and used during serving. Should be a dict. 334 335 Returns: 336 export_outputs dict with default serving signature added if necessary 337 338 Raises: 339 ValueError: if multiple export_outputs were provided without a default 340 serving key. 341 """ 342 if len(export_outputs) == 1: 343 (key, value), = export_outputs.items() 344 if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 345 export_outputs[ 346 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value 347 if len(export_outputs) > 1: 348 if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 349 not in export_outputs): 350 raise ValueError( 351 'Multiple export_outputs were provided, but none of them is ' 352 'specified as the default. Do this by naming one of them with ' 353 'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.') 354 355 return export_outputs 356