# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import a trackable object from a SavedModel."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os

from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import load_options
from tensorflow.python.saved_model import load_v1_in_v2
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def _unused_handle():
  """Returns a placeholder as a handle that is not supposed to be accessed."""
  error_message = ("Trying to access a placeholder that is not supposed to be "
                   "executed. This means you are executing a graph generated "
                   "from the cross-replica context in an in-replica context.")

  assert_op = control_flow_ops.Assert(
      array_ops.placeholder_with_default(False, shape=()),
      [error_message])

  with ops.control_dependencies([assert_op]):
    return array_ops.placeholder(dtype=dtypes.resource)


class _WrapperFunction(function.ConcreteFunction):
  """Wraps a concrete function to handle different distributed contexts.

  The reason for wrapping a concrete function is that the `_captured_inputs`
  field differs between the in-replica context and the cross-replica context.
  When `load()` is called from within a tf.distribute.strategy scope, the
  captured inputs are distributed variables. When these distributed variables
  are used to call the function, we need different approaches depending on
  whether the call is in-replica or not. When it is in-replica, we naturally
  use the corresponding component of the distributed variable; when it is not
  in-replica, calling the function means constructing a graph that is not
  actually going to be used. A typical use case is constructing a functional
  model. In that case, we return a placeholder with a control dependency to
  ensure that it is never accessed.
  """

  def __init__(self, concrete_function):
    # Shallow copy the concrete_function.
    self.__dict__.update(vars(concrete_function))

  def _call_flat(self, args, captured_inputs, cancellation_manager=None):

    def get_handle(x):
      return x.handle if distribute_utils.is_distributed_variable(x) else x

    def get_unused_handle(x):
      return (_unused_handle()
              if distribute_utils.is_distributed_variable(x) else x)

    if (ds_context.get_replica_context() is not None or
        values_util.is_saving_non_distributed()):
      # If we're in the replica context or are saving a non-distributed version
      # of the model, we resolve the captured variables to the corresponding
      # resource handles. In both situations we call var.handle, but it behaves
      # differently. In the replica context, var.handle resolves to the
      # replica-local variable handle if the variable is replicated. When
      # saving a non-distributed version of the model, var.handle resolves to
      # the primary variable handle, since we only save one copy of a
      # replicated variable.
      captured_inputs = list(map(get_handle, captured_inputs))
    else:  # cross-replica context
      captured_inputs = list(map(get_unused_handle, captured_inputs))
    return super(_WrapperFunction, self)._call_flat(args, captured_inputs,
                                                    cancellation_manager)
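

# A minimal, hypothetical sketch of the two call contexts that
# `_WrapperFunction` distinguishes. Documentation only; this function is never
# called by this module, and the model path and the attribute `f` are
# assumptions.
def _example_wrapper_function_contexts():  # pragma: no cover
  import tensorflow as tf  # Imported lazily; this sketch is never executed.

  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    # Captured variables become distributed variables under the scope.
    loaded = tf.saved_model.load("/tmp/model")
  # In-replica: each replica resolves captures to its local component handles.
  strategy.run(loaded.f, args=(tf.constant(1.),))
  # Cross-replica: captures resolve to unused placeholders; the traced graph
  # exists only for structure (e.g. building a functional model) and is not
  # meant to be executed.
  loaded.f(tf.constant(1.))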


class Loader(object):
  """Helper class to load an object-based SavedModel."""

  def __init__(self, object_graph_proto, saved_model_proto, export_dir,
               ckpt_options, filters):
    meta_graph = saved_model_proto.meta_graphs[0]
    self._asset_file_def = meta_graph.asset_file_def
    self._operation_attributes = {
        node.name: node.attr for node in meta_graph.graph_def.node}
    self._proto = object_graph_proto
    self._export_dir = export_dir
    self._concrete_functions = (
        function_deserialization.load_function_def_library(
            meta_graph.graph_def.library))
    self._checkpoint_options = ckpt_options

    # Stores the user-defined node_filters argument.
    self._node_filters = filters
    # Stores a map from string paths to integer node ids.
    self._node_path_to_id = self._convert_node_paths_to_ints()
    self._loaded_nodes = {}
    if isinstance(filters, dict):
      # If node_filters is a dict, then the values may contain already created
      # trackable objects. In this case, create a dictionary mapping node IDs
      # to the already created nodes. This dict will be updated in
      # `_retrieve_all_filtered_nodes` with tracked dependencies.
      for node_path, node in filters.items():
        if isinstance(node, tuple):
          self._loaded_nodes[self._node_path_to_id[node_path]] = node
        else:
          self._loaded_nodes[self._node_path_to_id[node_path]] = (node, setattr)

    # Get a list of all integer node ids to load, or None if all nodes should
    # be loaded. This list includes ids of child nodes.
    self._filtered_nodes = self._retrieve_all_filtered_nodes()

    for name, concrete_function in self._concrete_functions.items():
      # Wrap all the concrete functions so that they can handle both the
      # in-replica and cross-replica cases.
      self._concrete_functions[name] = _WrapperFunction(concrete_function)

    self._load_all()
    self._restore_checkpoint()

    for node in self._nodes:
      if isinstance(node, tracking.CapturableResource):
        init_op = node._initialize()  # pylint: disable=protected-access
        if not context.executing_eagerly():
          ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)

  def _convert_node_paths_to_ints(self):
    """Maps all string node paths in node_filters to the int node ids."""
    if self._node_filters is None:
      return None
    path_to_int = {}
    for node_id in self._node_filters:
      int_node_id = None
      if isinstance(node_id, str):
        node_path = node_id.split(".")
        if node_path[0] != "root":
          raise ValueError(
              "When passing string identifiers to node_filters, the first name"
              " must be root.")
        int_node_id = 0
        for n, name in enumerate(node_path[1:]):
          int_node_id = self._find_node_child(
              int_node_id, name, ".".join(node_path[:n+2]))
        path_to_int[node_id] = int_node_id
      else:
        raise TypeError("Elements in node_filters must be strings.")
    return path_to_int
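
  # Worked example (hypothetical object graph) of the path resolution above:
  # given a filter like "root.child_layer.v", the path is split on "." and
  # walked one child reference at a time starting from node 0 ("root"):
  #
  #   "root"               -> node id 0
  #   "root.child_layer"   -> child of node 0 with local_name "child_layer"
  #   "root.child_layer.v" -> child of that node with local_name "v"
  #
  # Each step calls `_find_node_child`, which scans the children of the
  # current node in the SavedObjectGraph proto for a matching local_name.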

  def _retrieve_all_filtered_nodes(self):
    """Traverses the object graph to get the IDs of all nodes to load.

    As a side effect, if node_filters is a dictionary that contains already-
    created objects, then the dependencies tracked by those objects will be
    added to node_filters.

    Returns:
      List of all nodes to load, or None if all nodes should be loaded.
    """
    if self._node_filters is None:
      return None  # All nodes should be loaded.

    all_filtered_nodes = set()
    nodes_to_visit = list(self._node_filters)

    while nodes_to_visit:
      node_path = nodes_to_visit.pop(0)
      node_id = self._node_path_to_id[node_path]
      if node_id in all_filtered_nodes:
        continue
      all_filtered_nodes.add(node_id)

      node, setter = self._loaded_nodes.get(node_id, (None, None))
      if node is not None:
        if not isinstance(node, base.Trackable):
          raise TypeError(
              "Error when processing dictionary values passed to "
              "nodes_to_load. Object at {} is expected to be a checkpointable "
              "TensorFlow object (e.g. tf.Variable, tf.Module or Keras layer)."
              .format(node_path))
        node._maybe_initialize_trackable()  # pylint: disable=protected-access

      for reference in self._proto.nodes[node_id].children:
        child_object, _ = self._loaded_nodes.get(
            reference.node_id, (None, None))

        # See if the node already tracks the child reference, in which case
        # add the child to the loaded_nodes dict.
        if child_object is None and node is not None:
          child_object = node._lookup_dependency(reference.local_name)  # pylint: disable=protected-access
          if isinstance(child_object, data_structures.TrackableDataStructure):
            # Make setattr a noop to avoid overwriting already existing data
            # structures.
            setter = lambda *args: None

        self._loaded_nodes[reference.node_id] = (child_object, setter)

        child_path = "{}.{}".format(node_path, reference.local_name)
        self._node_path_to_id[child_path] = reference.node_id
        nodes_to_visit.append(child_path)

    if 0 in all_filtered_nodes:
      return None
    return all_filtered_nodes

  def _find_node_child(self, node_id, child_name, path):
    for reference in self._proto.nodes[node_id].children:
      if reference.local_name == child_name:
        return reference.node_id
    raise ValueError("Unable to find node {}.".format(path))

  def _load_all(self):
    """Loads all nodes and functions from the SavedModel and their edges."""
    self._load_nodes()
    self._load_edges()
    # TODO(b/124045874): There are limitations with functions whose captures
    # trigger other functions to be executed. For now it is only guaranteed to
    # work if the captures of a function only trigger functions without
    # captures.
    self._setup_functions_structures()
    self._setup_functions_captures()

    self._create_saveable_object_factories()

  def _create_saveable_object_factories(self):
    for node_id, proto in self._iter_all_nodes():
      node = self.get(node_id)
      node._self_saveable_object_factories = {}  # pylint: disable=protected-access
      for name, saveable_object_proto in proto.saveable_objects.items():
        node._self_saveable_object_factories[name] = (  # pylint: disable=protected-access
            saveable_object_util.restored_saved_object_factory(
                self.get(saveable_object_proto.save_function),
                self.get(saveable_object_proto.restore_function)))

  def _load_edges(self):
    """Adds edges from objects to other objects and functions."""
    for node_id, object_proto in self._iter_all_nodes():
      self._add_object_graph_edges(object_proto, node_id)

    # If the root object isn't loaded, then create edges from the root for
    # checkpoint compatibility.
    if self._filtered_nodes is not None and 0 not in self._filtered_nodes:
      root = self.get(0)
      for node_path in self._node_filters:
        loaded_node = self._nodes[self._node_path_to_id[node_path]]
        path = node_path.split(".")
        current_node = root
        for name in path[1:-1]:
          if not hasattr(current_node, name):
            setattr(current_node, name, self._recreate_base_user_object()[0])
          current_node = getattr(current_node, name)
        if not hasattr(current_node, path[-1]):
          setattr(current_node, path[-1], loaded_node)

  def _add_object_graph_edges(self, proto, node_id):
    """Adds edges from an object to its children."""
    obj = self._nodes[node_id]
    setter = self._node_setters[node_id]

    for reference in proto.children:
      setter(obj, reference.local_name, self._nodes[reference.node_id])
      # Note: if an object has an attribute `__call__`, add a class method
      # that allows `obj()` syntax to work. This is done per-instance to
      # allow `callable` to be used to find out if an object is callable.
      # (See the sketch after `_call_attribute` below.)
      if reference.local_name == "__call__" and not callable(obj):
        setattr(type(obj), "__call__", _call_attribute)

  def _setup_functions_structures(self):
    """Sets up the structured inputs and outputs of restored functions."""
    coder = nested_structure_coder.StructureCoder()
    for name, proto in sorted(self._proto.concrete_functions.items()):
      concrete_function = self._concrete_functions[name]
      # By setting the structured_outputs directly, we can rely on this
      # function_lib.ConcreteFunction object to perform the output repacking
      # logic. The only limitation of that logic is that it only works
      # with output that is convertible to Tensors and the conversion
      # always happens. For example tf.TensorShape([2, 3]) will be
      # converted to a Tensor representing [2, 3].
      original_outputs = coder.decode_proto(proto.output_signature)
      # The original_outputs here had Tensors converted to TensorSpecs, so
      # the restored function's structured_outputs field will not be
      # exactly the same. Fortunately the repacking logic cares only about
      # the structure; and the unpacking logic cares only about structure
      # and types.
      concrete_function._func_graph.structured_outputs = original_outputs  # pylint: disable=protected-access
      concrete_function._func_graph.structured_input_signature = (  # pylint: disable=protected-access
          coder.decode_proto(proto.canonicalized_input_signature))
      concrete_function._initialize_function_spec()  # pylint: disable=protected-access

  def _setup_functions_captures(self):
    """Sets up captures and variables in restored functions."""
    concrete_functions = sorted(self._proto.concrete_functions.items())
    for name, proto in concrete_functions:
      concrete_function = self._concrete_functions[name]
      bound_inputs = [
          self._get_tensor_from_node(node_id, name)
          for node_id in proto.bound_inputs]
      bound_variables = [
          self._nodes[node_id]
          for node_id in proto.bound_inputs
          if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
      ]
      # TODO(andresp): This is only injecting the captured inputs into the
      # concrete function, note that we did not modify the FuncGraph
      # itself.
      concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
      concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
      if bound_inputs:
        for bound_input, internal_capture in zip(
            bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
          if distribute_utils.is_distributed_variable(bound_input):
            concrete_function.graph.capture_distributed_variable(
                bound_input, internal_capture)
          else:
            concrete_function.graph.replace_capture(bound_input,
                                                    internal_capture)
            if internal_capture.dtype == dtypes.resource:
              if resource_variable_ops.is_resource_variable(bound_input):
                try:
                  handle = bound_input.handle
                except ValueError:
                  # For mirrored variables we'll copy handle data for
                  # components as they get captured.
                  pass
                else:
                  custom_gradient.copy_handle_data(handle, internal_capture)
              else:
                custom_gradient.copy_handle_data(bound_input, internal_capture)
            # Setting "captures" first means "capture" won't create a new
            # placeholder for this input.
            concrete_function.graph.capture(bound_input)

  def _get_tensor_from_node(self, node_id, fn_name):
    """Resolves a node id into a tensor to be captured for a function."""
    if self._node_filters is not None and self._nodes[node_id] is None:
      raise ValueError(
          "Error when processing nodes_to_load. Function \"{}\" requires "
          "inputs/variables that are not loaded when nodes_to_load={}"
          .format(fn_name, self._node_filters))

    with ops.init_scope():
      obj = self._nodes[node_id]
      if distribute_utils.is_distributed_variable(obj):
        return obj
      elif resource_variable_ops.is_resource_variable(obj):
        return obj.handle
      elif isinstance(obj, tracking.Asset):
        return obj.asset_path
      elif tensor_util.is_tf_type(obj):
        return obj
      elif isinstance(obj, tracking.CapturableResource):
        # Note: this executes restored functions in the CapturableResource.
        return obj.resource_handle
      raise ValueError("Can't convert node %s to tensor" % (type(obj)))

  def _initialize_loaded_nodes(self):
    nodes = {}
    node_setters = {}
    for node_id, (node, setter) in self._loaded_nodes.items():
      nodes[node_id] = node
      node_setters[node_id] = setter
    return nodes, node_setters

  def _iter_all_nodes(self):
    if self._filtered_nodes is None:
      return enumerate(self._proto.nodes)
    else:
      return [(node_id, self._proto.nodes[node_id])
              for node_id in self._filtered_nodes]

  def _load_nodes(self):
    """Load all saved objects."""
    # `nodes` maps from node ids to recreated objects.
    # `node_setters` maps from node ids to setter functions
    # (same signature as setattr) for setting dependencies.
    nodes, node_setters = self._initialize_loaded_nodes()

    # Figure out which objects are slot variables. These objects are created
    # with Optimizer.add_slot rather than _recreate_variable.
    slot_variable_node_ids = set()

    for _, proto in self._iter_all_nodes():
      for slot_variable_proto in proto.slot_variables:
        slot_variable_node_ids.add(slot_variable_proto.slot_variable_node_id)

    # Re-create everything except slot variables.
    for node_id, proto in self._iter_all_nodes():
      if node_id in slot_variable_node_ids or nodes.get(node_id) is not None:
        # Defer recreating slot variables so we can use the public Optimizer
        # interface.
        continue
      node, setter = self._recreate(proto, node_id)
      nodes[node_id] = node
      node_setters[node_id] = setter

    # Now that we have created the variables being optimized, we have enough
    # information to re-create slot variables for them.
    for node_id, proto in self._iter_all_nodes():
      optimizer_object = nodes[node_id]
      for slot_variable_proto in proto.slot_variables:
        optimized_variable = nodes[
            slot_variable_proto.original_variable_node_id]
        slot_variable = optimizer_object.add_slot(
            var=optimized_variable,
            slot_name=slot_variable_proto.slot_name)
        nodes[slot_variable_proto.slot_variable_node_id] = slot_variable
        node_setters[slot_variable_proto.slot_variable_node_id] = setattr

    # If the root object is not loaded, add a dummy root object for checkpoint
    # compatibility.
    if 0 not in nodes:
      nodes[0] = self._recreate_base_user_object()[0]

    self._nodes = [nodes.get(node_id)
                   for node_id in range(len(self._proto.nodes))]
    self._node_setters = node_setters
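
  # Sketch of the two-pass recreation above (hypothetical object graph): given
  # an optimizer node whose slot_variables entry references a variable `v`,
  # the first pass recreates `v` and the optimizer but skips the node id
  # recorded as slot_variable_node_id; the second pass then calls
  # `optimizer_object.add_slot(var=v, slot_name=...)` so the slot variable is
  # created through the public Optimizer interface and stored at the node id
  # the proto expects.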

  @property
  def _expect_partial_checkpoint(self):
    """Whether to expect that some objects aren't loaded.

    This should be set to True in subclasses of the Loader class which generate
    a trackable object with an object graph that is different from the graph
    in the SavedModel. Setting this property to True suppresses the warnings
    that are printed out when there are unused parts of the checkpoint or
    object.

    Returns:
      boolean
    """
    return False

  def _restore_checkpoint(self):
    """Load state from checkpoint into the deserialized objects."""
    variables_path = saved_model_utils.get_variables_path(self._export_dir)
    # TODO(andresp): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
    with ops.device("CPU"):
      saver._file_prefix_placeholder = constant_op.constant(variables_path)
    if self._expect_partial_checkpoint:
      load_status = saver.restore(variables_path,
                                  self._checkpoint_options).expect_partial()
    else:
      load_status = saver.restore(variables_path, self._checkpoint_options)
    load_status.assert_existing_objects_matched()
    checkpoint = load_status._checkpoint

    if not context.executing_eagerly():
      # When running in eager mode, the `restore` call above has already run
      # and restored the state of trackables, and calling
      # `position.restore_ops()` would re-run the restore. In graph mode, it
      # returns a cached list of ops that must run to restore the object on
      # that position. We have to wire them into the initializers of the
      # objects so that they get initialized properly when using common
      # practices (e.g. the ones used by ManagedSession) without further user
      # action.
      for object_id, obj in dict(checkpoint.object_by_proto_id).items():
        position = base.CheckpointPosition(checkpoint=checkpoint,
                                           proto_id=object_id)
        restore_ops = position.restore_ops()
        if restore_ops:
          if resource_variable_ops.is_resource_variable(obj):
            if len(restore_ops) == 1:
              obj._initializer_op = restore_ops[0]
            else:
              obj._initializer_op = control_flow_ops.group(*restore_ops)
          elif isinstance(obj, lookup_ops.LookupInterface):
            # We don't need to check for eager execution here, since this code
            # path should only be taken if we are restoring in graph mode.
            ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
          else:
            raise NotImplementedError(
                ("Missing functionality to restore state of object "
                 "%r from the checkpoint." % obj))

  def adjust_debug_info_func_names(self, debug_info):
    """Rewrites func names in the debug info using the concrete func names."""
    output_debug_info = graph_debug_info_pb2.GraphDebugInfo()
    output_debug_info.files[:] = debug_info.files
    for key in debug_info.traces:
      node, func = key.split("@")
      new_func = ""
      if func in self._concrete_functions:
        new_func = self._concrete_functions[func].function_def.signature.name
      output_debug_info.traces[node + "@" + new_func].CopyFrom(
          debug_info.traces[key])
    return output_debug_info

  def get(self, node_id):
    if isinstance(node_id, str):
      node_id = self._node_path_to_id[node_id]
    return self._nodes[node_id]

  def _recreate(self, proto, node_id):
    """Creates a Python object from a SavedObject protocol buffer."""
    factory = {
        "user_object": (
            lambda: self._recreate_user_object(proto.user_object, node_id)),
        "asset": lambda: self._recreate_asset(proto.asset),
        "function": lambda: self._recreate_function(proto.function),
        "bare_concrete_function": functools.partial(
            self._recreate_bare_concrete_function,
            proto.bare_concrete_function),
        "variable": lambda: self._recreate_variable(proto.variable),
        "constant": lambda: self._recreate_constant(proto.constant),
        "resource": lambda: self._recreate_resource(proto.resource),
    }
    kind = proto.WhichOneof("kind")
    if kind not in factory:
      raise ValueError("Unknown SavedObject type: %r" % kind)
    return factory[kind]()

  def _recreate_user_object(self, proto, node_id):
    """Instantiates a SavedUserObject."""
    looked_up = revived_types.deserialize(proto)
    if looked_up is None:
      return self._recreate_base_user_object(proto, node_id)
    return looked_up

  def _recreate_base_user_object(self, proto=None, node_id=None):
    del proto, node_id
    # Note: each user object has its own class. This allows making each one
    # individually callable by adding a `__call__` method to the class of
    # each object instance that has a `__call__` property.

    class _UserObject(tracking.AutoTrackable):
      pass

    return _UserObject(), setattr

  def _recreate_asset(self, proto):
    filename = os.path.join(
        saved_model_utils.get_assets_dir(self._export_dir),
        self._asset_file_def[proto.asset_file_def_index].filename)
    asset = tracking.Asset(filename)
    if not context.executing_eagerly():
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset.asset_path)
    return asset, setattr

  def _recreate_function(self, proto):
    return function_deserialization.recreate_function(
        proto, self._concrete_functions), setattr

  def _recreate_bare_concrete_function(self, proto):
    return function_deserialization.setup_bare_concrete_function(
        proto, self._concrete_functions), setattr

  def _recreate_variable(self, proto):
    name = proto.name if proto.name else None
    if name is not None:
      dbg_name = name
    else:
      dbg_name = "<variable loaded from saved model>"
    synchronization, aggregation, trainable = (
        variables.validate_synchronization_aggregation_trainable(
            proto.synchronization, proto.aggregation, proto.trainable,
            name=dbg_name))

    def uninitialized_variable_creator(next_creator, **kwargs):
      """A variable creator that creates uninitialized variables."""
      del next_creator
      return resource_variable_ops.UninitializedVariable(**kwargs)

    # Create a variable_creator_scope that creates uninitialized variables with
    # a lower priority such that a potential distributed variable_creator_scope
    # can take precedence.
    with ops.get_default_graph()._variable_creator_scope(  # pylint: disable=protected-access
        uninitialized_variable_creator,
        priority=50):
      return variables.Variable(
          shape=proto.shape,
          dtype=proto.dtype,
          name=name,
          trainable=trainable,
          synchronization=synchronization,
          aggregation=aggregation), setattr

  def _recreate_constant(self, proto):
    tensor_proto = self._operation_attributes[proto.operation]["value"].tensor
    ndarray = tensor_util.MakeNdarray(tensor_proto)
    if dtypes.as_dtype(tensor_proto.dtype) == dtypes.string:
      with ops.device("CPU"):
        imported_constant = constant_op.constant(ndarray)
    else:
      imported_constant = constant_op.constant(ndarray)
    return imported_constant, setattr

  def _recreate_resource(self, proto):
    return _RestoredResource(device=proto.device), setattr


# TODO(b/124205571,b/124092991): Solve destruction of resources.
class _RestoredResource(tracking.TrackableResource):
  """Restored SavedResource."""

  def __init__(self, device=""):
    super(_RestoredResource, self).__init__(device=device)
    self._destroy_resource_fn = None

  def _create_resource(self):
    raise RuntimeError()

  def _initialize(self):
    raise RuntimeError()

  @property
  def _destroy_resource(self):
    return self._destroy_resource_fn

  @_destroy_resource.setter
  def _destroy_resource(self, destroy_resource_fn):
    self._resource_deleter = tracking.CapturableResourceDeleter(
        destroy_resource_fn)
    self._destroy_resource_fn = destroy_resource_fn

  def _list_functions_for_serialization(self, unused_serialization_cache):
    # Override this method to keep the base-class implementation from
    # re-wrapping the polymorphic functions in another layer of `tf.function`.
    functions = {
        "_create_resource": self._create_resource,
        "_initialize": self._initialize,
    }
    if self._destroy_resource:
      functions.update(_destroy_resource=self._destroy_resource)
    return functions


def _call_attribute(instance, *args, **kwargs):
  return instance.__call__(*args, **kwargs)
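

# A small, self-contained sketch (hypothetical class `_Revived`) of why
# `_add_object_graph_edges` attaches `_call_attribute` to the *class* rather
# than the instance: Python resolves `obj()` through `type(obj).__call__`, so
# an instance attribute alone does not make an object callable. Documentation
# only; this function is never called by this module.
def _example_call_attribute_wiring():  # pragma: no cover
  class _Revived(object):  # Stand-in for a per-instance revived class.
    pass

  obj = _Revived()
  obj.__call__ = lambda: "called"  # A restored function, as an instance attr.
  assert not callable(obj)  # `callable` consults the type, not the instance.
  setattr(_Revived, "__call__", _call_attribute)
  assert callable(obj) and obj() == "called"
  # Because each revived user object gets its own class (see
  # `_recreate_base_user_object`), the setattr affects only that object.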


@tf_export("__internal__.saved_model.load_partial", v1=[])
def load_partial(export_dir, filters, tags=None, options=None):
  """Partially loads a SavedModel (saved from V2).

  Similar to `tf.saved_model.load`, but with an additional argument that
  lets you specify which nodes to load.
  `tf.saved_model.load_partial(export_dir, ["root"])` and
  `tf.saved_model.load(export_dir)` are equivalent.

  Note: This only works for SavedModels saved with TensorFlow V2 from
  `tf.saved_model.save` or Keras. It will not load SavedModels saved from
  the Estimator API.

  In TensorFlow V2, SavedModel stores the **object graph** of the saved object.
  The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`, Keras
  layers, etc.) and edges that are the names of the attributes connecting the
  objects.

  *Example 1*

  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.child_layer', 'root.child_layer.v'])
  >>> loaded['root.child_layer'].v.numpy()
  5.0
  >>> loaded['root.child_layer'].v is loaded['root.child_layer.v']
  True

  *Example 2*

  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Create a variable
  >>> new_variable = tf.Variable(0.)
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   {'root.child_layer': None, 'root.child_layer.v': new_variable})
  >>> loaded['root.child_layer'].v.numpy()
  5.0
  >>> new_variable.numpy()
  5.0

  **Loading under different distribution strategies**

  You can load different parts of the model under different distribution
  strategies. Note that this is very experimental, so use it with care.

  >>> model = tf.Module()
  >>> model.layer_1 = tf.Module()
  >>> model.layer_1.v = tf.Variable(5.)
  >>> model.layer_2 = tf.Module()
  >>> model.layer_2.v = tf.Variable(7.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Load with no strategy
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.layer_1'])
  >>> loaded['root.layer_1'].v
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
  >>> strategy = tf.distribute.MirroredStrategy()
  >>> with strategy.scope():
  ...   loaded2 = tf.__internal__.saved_model.load_partial(
  ...     '/tmp/model',
  ...     ['root.layer_2'])
  >>> loaded2['root.layer_2'].v
  MirroredVariable:{
    0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
  }

  Args:
    export_dir: The SavedModel directory to load from.
    filters: A list or dictionary where each element or key is a string
      path to nodes that should be loaded. Node paths consist of all the child
      attribute names to reach that node in the form: `root.{attribute_name}`.
      The loader will load all of the specified nodes and their recursive
      descendants. When this option is defined, the loader will return a
      dictionary mapping the node paths to the loaded objects.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: A `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A dictionary mapping node paths from the filter to loaded objects.
  """
  return load_internal(export_dir, tags, options, filters=filters)


@tf_export("saved_model.load", v1=["saved_model.load_v2"])
def load(export_dir, tags=None, options=None):
  """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have trackable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

  _Loading Keras models_

  Keras models are trackable, so they can be saved to SavedModel. The object
  returned by `tf.saved_model.load` is not a Keras object (i.e. it doesn't have
  `.fit`, `.predict`, etc. methods). A few attributes and functions are still
  available: `.variables`, `.trainable_variables` and `.__call__`.

  ```python
  model = tf.keras.Model(...)
  tf.saved_model.save(model, path)
  imported = tf.saved_model.load(path)
  outputs = imported(inputs)
  ```

  Use `tf.keras.models.load_model` to restore the Keras model.

  _Importing SavedModels from TensorFlow 1.x_

  SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
  graph instead of `tf.function` objects. These SavedModels will be loaded with
  the following attributes:

  * `.signatures`: A dictionary mapping signature names to functions.
  * `.prune(feeds, fetches)`: A method which allows you to extract
    functions for new subgraphs. This is equivalent to importing the SavedModel
    and naming feeds and fetches in a Session from TensorFlow 1.x.

    ```python
    imported = tf.saved_model.load(path_to_v1_saved_model)
    pruned = imported.prune("x:0", "out:0")
    pruned(tf.ones([]))
    ```

    See `tf.compat.v1.wrap_function` for details.
  * `.variables`: A list of imported variables.
  * `.graph`: The whole imported graph.
  * `.restore(save_path)`: A function that restores variables from a checkpoint
    saved from `tf.compat.v1.Saver`.

  _Consuming SavedModels asynchronously_

  When consuming SavedModels asynchronously (the producer is a separate
  process), the SavedModel directory will appear before all files have been
  written, and `tf.saved_model.load` will fail if pointed at an incomplete
  SavedModel. Rather than checking for the directory, check for
  "saved_model_dir/saved_model.pb". This file is written atomically as the last
  `tf.saved_model.save` file operation.

  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: A `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A trackable object with a `signatures` attribute mapping from signature
    keys to functions. If the SavedModel was exported by `tf.saved_model.save`,
    it also points to the trackable objects, functions, and debug info with
    which it was saved.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
  return load_internal(export_dir, tags, options)["root"]
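

# A minimal sketch of the asynchronous-consumption advice in the `load`
# docstring above: poll for `saved_model.pb` instead of for the directory.
# Documentation only; this function is never called by this module, and the
# default path and polling interval are assumptions.
def _example_wait_then_load(saved_model_dir="/tmp/model"):  # pragma: no cover
  import time
  import tensorflow as tf  # Imported lazily; this sketch is never executed.

  # `saved_model.pb` is written atomically as the last file operation of
  # `tf.saved_model.save`, so its presence implies a complete SavedModel.
  while not os.path.exists(os.path.join(saved_model_dir, "saved_model.pb")):
    time.sleep(1.0)
  return tf.saved_model.load(saved_model_dir)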


def load_internal(export_dir, tags=None, options=None, loader_cls=Loader,
                  filters=None):
  """Loader implementation."""
  options = options or load_options.LoadOptions()
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto, debug_info = (
      loader_impl.parse_saved_model_with_debug_info(export_dir))

  if (len(saved_model_proto.meta_graphs) == 1 and
      saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    meta_graph_def = saved_model_proto.meta_graphs[0]
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may "
           "omit it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def

    ckpt_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=options.experimental_io_device)
    with ops.init_scope():
      try:
        loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
                            ckpt_options, filters)
      except errors.NotFoundError as err:
        raise FileNotFoundError(
            str(err) + "\n If trying to load on a different device from the "
            "computational device, consider setting the "
            "`experimental_io_device` option on tf.saved_model.LoadOptions "
            "to the io_device such as '/job:localhost'."
        )
      root = loader.get(0)
      if isinstance(loader, Loader):
        root.graph_debug_info = loader.adjust_debug_info_func_names(debug_info)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version)
  else:
    if filters:
      raise ValueError("SavedModels saved from TensorFlow V1 or Estimator "
                       "(any version) cannot be loaded with node filters.")
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
      root.graph_debug_info = debug_info

  if filters:
    return {node_id: loader.get(node_id) for node_id in filters}
  else:
    return {"root": root}