1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""SavedModel builder implementation.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import os 23 24from google.protobuf.any_pb2 import Any 25 26from tensorflow.core.framework import types_pb2 27from tensorflow.core.protobuf import meta_graph_pb2 28from tensorflow.core.protobuf import saved_model_pb2 29from tensorflow.core.protobuf import saver_pb2 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.lib.io import file_io 33from tensorflow.python.ops import variables 34from tensorflow.python.platform import tf_logging 35from tensorflow.python.saved_model import constants 36from tensorflow.python.saved_model import signature_def_utils 37from tensorflow.python.saved_model import utils_impl as saved_model_utils 38from tensorflow.python.training import saver as tf_saver 39from tensorflow.python.util import compat 40from tensorflow.python.util.deprecation import deprecated_args 41from tensorflow.python.util.tf_export import tf_export 42 43 44# Base class for the SavedModelBuilder that is only used by Tensorflow 45# internally. Please use tf.compat.v1.saved_model.SavedModelBuilder instead. 
class _SavedModelBuilder(object):
  """Builds the `SavedModel` protocol buffer and saves variables and assets.

  The `SavedModelBuilder` class provides functionality to build a `SavedModel`
  protocol buffer. Specifically, this allows multiple meta graphs to be saved as
  part of a single language-neutral `SavedModel`, while sharing variables and
  assets.

  To build a SavedModel, the first meta graph must be saved with variables.
  Subsequent meta graphs will simply be saved with their graph definitions. If
  assets need to be saved and written or copied to disk, they can be provided
  when the meta graph def is added. If multiple meta graph defs are associated
  with an asset of the same name, only the first version is retained.

  Each meta graph added to the SavedModel must be annotated with tags. The tags
  provide a means to identify the specific meta graph to load and restore, along
  with the shared set of variables and assets.

  Typical usage for the `SavedModelBuilder`:
  ```python
  ...
  builder = tf.saved_model.Builder(export_dir)

  with tf.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph_and_variables(sess,
                                    ["foo-tag"],
                                    signature_def_map=foo_signatures,
                                    assets_list=foo_assets)
  ...

  with tf.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph(["bar-tag", "baz-tag"])
  ...

  builder.save()
  ```

  Note: This function will only be available through the v1 compatibility
  library as tf.compat.v1.saved_model.builder.SavedModelBuilder or
  tf.compat.v1.saved_model.Builder. Tensorflow 2.0 will introduce a new
  object-based method of creating SavedModels.
  """

  def __init__(self, export_dir):
    """Creates the builder and prepares the export directory.

    Args:
      export_dir: Directory the SavedModel will be written to. If it does not
        exist it is created; if it exists it must be empty.

    Raises:
      AssertionError: If `export_dir` already exists and is not empty.
    """
    self._saved_model = saved_model_pb2.SavedModel()
    self._saved_model.saved_model_schema_version = (
        constants.SAVED_MODEL_SCHEMA_VERSION)

    self._export_dir = export_dir
    if file_io.file_exists(export_dir):
      if file_io.list_directory(export_dir):
        raise AssertionError(
            "Export directory already exists, and isn't empty. Please choose "
            "a different export directory, or delete all the contents of the "
            "specified directory: %s" % export_dir)
    else:
      file_io.recursive_create_dir(self._export_dir)

    # Boolean to track whether variables and assets corresponding to the
    # SavedModel have been saved. Specifically, the first meta graph to be added
    # MUST use the add_meta_graph_and_variables() API. Subsequent add operations
    # on the SavedModel MUST use the add_meta_graph() API which does not save
    # weights.
    self._has_saved_variables = False

  def _save_and_write_assets(self, meta_graph_def, assets_list=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      meta_graph_def: The meta graph def to which the assets will be added.
      assets_list: The list where the asset paths are setup.
    """
    # Creates a function that adds assets into the meta graph def.
    write_fn = functools.partial(_add_asset_to_metagraph, meta_graph_def)
    asset_filename_map = _maybe_save_assets(write_fn, assets_list)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)

  def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
    """Tags the meta graph def and adds it to the SavedModel.

    Tags the meta graph def with the supplied tags, adds signature defs to it if
    provided and appends the meta graph def to the SavedModel proto.

    Args:
      meta_graph_def: The meta graph def to add to the SavedModel.
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def.
    """
    meta_graph_def.meta_info_def.tags.extend(tags)

    if signature_def_map is not None:
      for key, signature_def in signature_def_map.items():
        meta_graph_def.signature_def[key].CopyFrom(signature_def)

    proto_meta_graph_def = self._saved_model.meta_graphs.add()
    proto_meta_graph_def.CopyFrom(meta_graph_def)

  def _validate_tensor_info(self, tensor_info):
    """Validates the `TensorInfo` proto.

    Checks if the `encoding` (`name` or `coo_sparse`) and `dtype` fields exist
    and are non-empty.

    Args:
      tensor_info: `TensorInfo` protocol buffer to validate.

    Raises:
      AssertionError: If the `name` or `dtype` fields of the supplied
          `TensorInfo` proto are not populated.
    """
    if tensor_info is None:
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have the name "
          "and dtype fields set.")
    if tensor_info.WhichOneof("encoding") is None:
      # TODO(soergel) validate each of the fields of coo_sparse
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have one of "
          "the 'encoding' fields (e.g., name or coo_sparse) set: %s"
          % tensor_info)
    # `dtype` is a plain integer enum value; compare by value, not identity.
    if tensor_info.dtype == types_pb2.DT_INVALID:
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have the dtype "
          "field set: %s" % tensor_info)

  def _validate_signature_def_map(self, signature_def_map):
    """Validates the `SignatureDef` entries in the signature def map.

    Validation of entries in the signature def map includes ensuring that the
    `name` and `dtype` fields of the TensorInfo protos of the `inputs` and
    `outputs` of each `SignatureDef` are populated. Also ensures that reserved
    SignatureDef keys for the initialization and train ops are not used.

    Args:
      signature_def_map: The map of signature defs to be validated.

    Raises:
      AssertionError: If a TensorInfo is not valid.
      KeyError: If a reserved signature key is used in the map.
    """
    for signature_def_key in signature_def_map:
      signature_def = signature_def_map[signature_def_key]
      inputs = signature_def.inputs
      outputs = signature_def.outputs
      for inputs_key in inputs:
        self._validate_tensor_info(inputs[inputs_key])
      for outputs_key in outputs:
        self._validate_tensor_info(outputs[outputs_key])
    if constants.INIT_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          "SignatureDef map key \"{}\" is reserved for initialization. Please "
          "use a different key.".format(constants.INIT_OP_SIGNATURE_KEY))
    if constants.TRAIN_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          "SignatureDef map key \"{}\" is reserved for the train op. Please "
          "use a different key.".format(constants.TRAIN_OP_SIGNATURE_KEY))

  def _maybe_create_saver(self, saver=None):
    """Creates a sharded saver if one does not already exist.

    Args:
      saver: An existing saver to reuse, or a falsy value (normally None).

    Returns:
      `saver` if it was provided, otherwise a new sharded V2 `Saver` over all
      saveable objects in the current scope (empty graphs are allowed).
    """
    if not saver:
      # Initialize a saver to generate a sharded output for all saveables in the
      # current scope.
      saver = tf_saver.Saver(
          variables._all_saveable_objects(),  # pylint: disable=protected-access
          sharded=True,
          write_version=saver_pb2.SaverDef.V2,
          allow_empty=True)
    return saver

  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_list=None,
                     clear_devices=False,
                     init_op=None,
                     train_op=None,
                     saver=None):
    """Adds the current meta graph to the SavedModel.

    Creates a Saver in the current scope and uses the Saver to export the meta
    graph def. Invoking this API requires the `add_meta_graph_and_variables()`
    API to have been invoked before.

    Args:
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def. The caller's map is not modified.
      assets_list: Assets to be saved with SavedModel. Note
          that this list should be a subset of the assets saved as part of
          the first meta graph in the SavedModel.
      clear_devices: Set to true if the device info on the default graph should
        be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      saver: An instance of tf.train.Saver that will be used to export the
        metagraph. If None, a sharded Saver that restores all variables will
        be used.

    Raises:
      AssertionError: If the variables for the SavedModel have not been saved
        yet, or if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a key reserved for the init or
        train op.
    """
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated. Work on a shallow copy: the reserved init/train op
    # entries added below must not leak into the caller's dict, otherwise
    # reusing that dict for another meta graph would fail validation.
    signature_def_map = dict(signature_def_map or {})
    self._validate_signature_def_map(signature_def_map)

    # Create a SignatureDef pointing to the graph initialization op, which will
    # be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=True)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_list=None,
                                   clear_devices=False,
                                   init_op=None,
                                   train_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    # pylint: disable=line-too-long
    """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature defs to add to the meta graph
        def. The caller's map is not modified.
      assets_list: Assets to be saved with SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      saver: An instance of tf.train.Saver that will be used to export the
        metagraph and save variables. If None, a sharded Saver that restores
        all variables will be used.

    Raises:
      AssertionError: If variables have already been saved by this builder, or
        if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a key reserved for the init or
        train op.
    """
    # pylint: enable=line-too-long
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated. Work on a shallow copy: the reserved init/train op
    # entries added below must not leak into the caller's dict, otherwise
    # reusing that dict for a later add_meta_graph() call would fail
    # validation.
    signature_def_map = dict(signature_def_map or {})
    self._validate_signature_def_map(signature_def_map)

    # Create a SignatureDef pointing to the graph initialization op, which will
    # be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)

    saved_model_utils.get_or_create_variables_dir(self._export_dir)
    variables_path = saved_model_utils.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  def save(self, as_text=False):
    """Writes a `SavedModel` protocol buffer to disk.

    The function writes the SavedModel protocol buffer to the export directory
    in serialized format.

    Args:
      as_text: Writes the SavedModel protocol buffer in text format to
        disk. Protocol buffers in text format are useful for debugging, but
        parsing fails when it encounters an unknown field and so is not forward
        compatible. This means changes to TensorFlow may prevent deployment of
        new text format SavedModels to existing serving binaries. Do not deploy
        `as_text` SavedModels to production.

    Returns:
      The path to which the SavedModel protocol buffer was written.
    """
    if not file_io.file_exists(self._export_dir):
      file_io.recursive_create_dir(self._export_dir)

    if as_text:
      path = os.path.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
      file_io.write_string_to_file(path, str(self._saved_model))
    else:
      path = os.path.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
      file_io.write_string_to_file(path, self._saved_model.SerializeToString())
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))

    return path


@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"])  # pylint: disable=missing-docstring
class SavedModelBuilder(_SavedModelBuilder):
  __doc__ = _SavedModelBuilder.__doc__.replace("assets_list",
                                               "assets_collection")

  def __init__(self, export_dir):
    super(SavedModelBuilder, self).__init__(export_dir=export_dir)

  def _add_collections(self, assets_collection, main_op, train_op):
    """Add asset and op collections to be saved."""
    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(assets_collection)

    self._maybe_add_main_op(main_op)

    self._add_train_op(train_op)

  def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    # Add assets to the collection with key `constants.ASSETS_KEY`, in the
    # graph.
    asset_filename_map = _maybe_save_assets(_add_asset_to_collection,
                                            assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)

  def _maybe_add_main_op(self, main_op):
    """Adds main op to the SavedModel.

    Args:
      main_op: Main op to run as part of graph initialization. If None, no main
          op will be added to the graph.

    Raises:
      TypeError: if main op is provided but is not of type `Operation`.
      ValueError: if the Graph already contains an init op.
    """
    if main_op is None:
      return

    if not isinstance(main_op, ops.Operation):
      raise TypeError("main_op needs to be an Operation: %r" % main_op)

    # Validate that no other init ops have been added to this graph already.
    # We check main_op and legacy_init_op for thoroughness and explicitness.
    for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
      if ops.get_collection(init_op_key):
        raise ValueError(
            "Graph already contains one or more main ops under the "
            "collection {}.".format(init_op_key))

    ops.add_to_collection(constants.MAIN_OP_KEY, main_op)

  def _add_train_op(self, train_op):
    """Add train op to the SavedModel.

    Note that this functionality is in development, and liable to be
    moved elsewhere.

    Args:
      train_op: Op or group of ops that are used for training. These are stored
        as a collection with key TRAIN_OP_KEY, but not executed.

    Raises:
      TypeError: If `train_op` is neither a `Tensor` nor an `Operation`.
    """
    if train_op is not None:
      if (not isinstance(train_op, ops.Tensor) and
          not isinstance(train_op, ops.Operation)):
        raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
      ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_collection=None,
                     legacy_init_op=None,
                     clear_devices=False,
                     main_op=None,
                     strip_default_attrs=False,
                     saver=None):
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless.
    # Use an explicit None check rather than `or`: truth-testing a Tensor
    # raises an unrelated TypeError before _maybe_add_main_op can report the
    # real problem.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_collection=None,
                                   legacy_init_op=None,
                                   clear_devices=False,
                                   main_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless.
    # Use an explicit None check rather than `or`: truth-testing a Tensor
    # raises an unrelated TypeError before _maybe_add_main_op can report the
    # real problem.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    saved_model_utils.get_or_create_variables_dir(self._export_dir)
    variables_path = saved_model_utils.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  # NOTE(review): the docstrings below are copied from _SavedModelBuilder, so
  # they describe `init_op`/`train_op` even though these v1 methods instead
  # accept `legacy_init_op`, `main_op` and `strip_default_attrs`.
  add_meta_graph.__doc__ = _SavedModelBuilder.add_meta_graph.__doc__.replace(
      "assets_list", "assets_collection")
  add_meta_graph_and_variables.__doc__ = \
      _SavedModelBuilder.add_meta_graph_and_variables.__doc__.replace(
          "assets_list", "assets_collection")


def _maybe_save_assets(write_fn, assets_to_add=None):
  """Saves assets to the meta graph.

  Args:
    write_fn: A function callback that writes asset into meta graph.
    assets_to_add: The list where the asset paths are setup.

  Returns:
    A dict of asset basenames for saving to the original full path to the asset.

  Raises:
    ValueError: Indicating an invalid filepath tensor.
  """
  # Map of target file names to original filenames
  asset_filename_map = {}

  if assets_to_add is None:
    tf_logging.info("No assets to save.")
    return asset_filename_map

  # Iterate over the supplied assets, build the `AssetFile` proto and add them
  # to the meta graph.
  for asset_tensor in assets_to_add:
    asset_source_filepath = _asset_path_from_tensor(asset_tensor)
    if not asset_source_filepath:
      raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)

    asset_filename = get_asset_filename_to_add(
        asset_source_filepath, asset_filename_map)

    # Call the passed-in function that builds AssetFileDef proto and adds it
    # to either the collection or asset_file_def field of the meta graph.
    # Note that this should be done even when the file is a duplicate of an
    # already-added file, as the tensor reference should still exist.
    write_fn(asset_filename, asset_tensor)

    # In the cases where we are adding a duplicate, this will result in the
    # last of the filepaths being the one used for copying the file to the
    # SavedModel. Since the files in question are the same, it doesn't matter
    # either way.
    asset_filename_map[asset_filename] = asset_source_filepath

  tf_logging.info("Assets added to graph.")
  return asset_filename_map


def get_asset_filename_to_add(asset_filepath, asset_filename_map):
  """Get a unique basename to add to the SavedModel if this file is unseen.

  Assets come from users as full paths, and we save them out to the
  SavedModel as basenames. In some cases, the basenames collide. Here,
  we dedupe asset basenames by first checking if the file is the same,
  and, if different, generate and return an index-suffixed basename
  that can be used to add the asset to the SavedModel.

  Args:
    asset_filepath: the full path to the asset that is being saved
    asset_filename_map: a dict of filenames used for saving the asset in
      the SavedModel to full paths from which the filenames were derived.

  Returns:
    Uniquified filename string if the file is not a duplicate, or the original
    filename if the file has already been seen and saved.
  """
  asset_filename = os.path.basename(asset_filepath)

  if asset_filename not in asset_filename_map:
    # This is an unseen asset. Safe to add.
    return asset_filename

  other_asset_filepath = asset_filename_map[asset_filename]
  if other_asset_filepath == asset_filepath:
    # This is the same file, stored twice in the list. No need
    # to make unique.
    return asset_filename

  # Else, asset_filename is in the map, and the filepath is different. Dedupe.
  if not file_io.filecmp(asset_filepath, other_asset_filepath):
    # Files are different; dedupe filenames.
    return _get_unique_asset_filename(asset_filename, asset_filename_map)

  # Files are the same; don't make unique.
  return asset_filename


def _get_unique_asset_filename(asset_filename, asset_filename_map):
  """Returns a suffixed variant of `asset_filename` not yet used in the map.

  Tries `<asset_filename>_1`, `<asset_filename>_2`, ... (built as bytes via
  `compat.as_bytes`) and returns the first one that is not a key of
  `asset_filename_map`. If `asset_filename` itself is not a key, it is
  returned unchanged.

  Args:
    asset_filename: The colliding asset basename.
    asset_filename_map: Dict whose keys are the basenames already in use.

  Returns:
    A basename that is not a key of `asset_filename_map`.
  """
  i = 1
  unique_filename = asset_filename
  while unique_filename in asset_filename_map:
    unique_filename = compat.as_bytes("_").join(
        [compat.as_bytes(asset_filename), compat.as_bytes(str(i))])
    i += 1
  return unique_filename


def _asset_path_from_tensor(path_tensor):
  """Returns the filepath value stored in constant `path_tensor`.

  Args:
    path_tensor: Tensor of a file-path.

  Returns:
    The string value i.e. path of the tensor, if valid.

  Raises:
    TypeError: If the tensor does not match the expected op type, dtype or
      value (a scalar string constant).
  """
  if not isinstance(path_tensor, ops.Tensor):
    raise TypeError("Asset path tensor must be a Tensor.")
  if path_tensor.op.type != "Const":
    raise TypeError("Asset path tensor must be of type constant.")
  if path_tensor.dtype != dtypes.string:
    raise TypeError("Asset path tensor must be of dtype string.")
  str_values = path_tensor.op.get_attr("value").string_val
  if len(str_values) != 1:
    raise TypeError("Asset path tensor must be a scalar.")
  return str_values[0]


def _add_asset_to_metagraph(meta_graph_def, asset_filename, asset_tensor):
  """Builds an asset proto and adds it to the meta graph def.

  Args:
    meta_graph_def: The meta graph def to which the asset will be added.
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor used to populate the tensor info of the asset
      proto.
  """
  asset_proto = meta_graph_def.asset_file_def.add()
  asset_proto.filename = asset_filename
  asset_proto.tensor_info.name = asset_tensor.name


def copy_assets_to_destination_dir(asset_filename_map, destination_dir):
  """Copy all assets from source path to destination path.

  Args:
    asset_filename_map: Dict mapping destination basenames to the source file
      paths they should be copied from.
    destination_dir: SavedModel export directory; files are copied into its
      assets subdirectory, which is created if needed.
  """
  assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
      destination_dir)

  # Copy each asset from source path to destination path.
  for asset_basename, asset_source_filepath in asset_filename_map.items():
    asset_destination_filepath = os.path.join(
        compat.as_bytes(assets_destination_dir),
        compat.as_bytes(asset_basename))

    # Only copy the asset file to the destination if it does not already
    # exist. This is to ensure that an asset with the same name defined as
    # part of multiple graphs is only copied the first time.
    if not file_io.file_exists(asset_destination_filepath):
      file_io.copy(asset_source_filepath, asset_destination_filepath)

  tf_logging.info("Assets written to: %s",
                  compat.as_text(assets_destination_dir))


def _add_asset_to_collection(asset_filename, asset_tensor):
  """Builds an asset proto and adds it to the asset collection of the graph.

  Args:
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor used to populate the tensor info of the
        asset proto.
  """
  asset_proto = meta_graph_pb2.AssetFileDef()
  asset_proto.filename = asset_filename
  asset_proto.tensor_info.name = asset_tensor.name

  asset_any_proto = Any()
  asset_any_proto.Pack(asset_proto)
  ops.add_to_collection(constants.ASSETS_KEY, asset_any_proto)


def _add_op_to_signature_def_map(signature_def_map, op, key):
  """Stores an op-only SignatureDef for `op` under `key`, if `op` is given.

  Args:
    signature_def_map: Dict of SignatureDefs; modified in place.
    op: The op to wrap, or None to leave the map unchanged.
    key: The (reserved) signature key to store the SignatureDef under.
  """
  if op is not None:
    signature_def_map[key] = signature_def_utils.op_signature_def(op, key)