1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Import a trackable object from a SavedModel.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import os 23 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_util 27from tensorflow.python.ops import init_ops 28from tensorflow.python.ops import resource_variable_ops 29from tensorflow.python.ops import variables 30from tensorflow.python.saved_model import function_deserialization 31from tensorflow.python.saved_model import load_v1_in_v2 32from tensorflow.python.saved_model import loader_impl 33from tensorflow.python.saved_model import nested_structure_coder 34from tensorflow.python.saved_model import revived_types 35from tensorflow.python.saved_model import utils_impl as saved_model_utils 36from tensorflow.python.training.tracking import base 37from tensorflow.python.training.tracking import graph_view 38from tensorflow.python.training.tracking import tracking 39from tensorflow.python.training.tracking import util 40from tensorflow.python.util import nest 41from tensorflow.python.util.tf_export import tf_export 42 43 44class _Loader(object): 45 """Helper class to load an object-based SavedModel.""" 46 47 def __init__(self, object_graph_proto, saved_model_proto, export_dir): 48 meta_graph = saved_model_proto.meta_graphs[0] 49 self._asset_file_def = meta_graph.asset_file_def 50 self._operation_attributes = { 51 node.name: node.attr for node in meta_graph.graph_def.node} 52 self._proto = object_graph_proto 53 self._export_dir = export_dir 54 self._concrete_functions = ( 55 function_deserialization.load_function_def_library( 56 meta_graph.graph_def.library)) 57 self._load_all() 58 # TODO(b/124045874): There are limitations with functions whose captures 59 # trigger other functions to be executed. For now it is only guaranteed to 60 # work if the captures of a function only trigger functions without 61 # captures. 62 self._setup_functions_structures() 63 self._setup_functions_captures() 64 self._restore_checkpoint() 65 66 for node in self._nodes: 67 if isinstance(node, tracking.TrackableResource): 68 init_op = node._initialize() # pylint: disable=protected-access 69 ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op) 70 71 def _setup_functions_structures(self): 72 """Setup structure for inputs and outputs of restored functions.""" 73 coder = nested_structure_coder.StructureCoder() 74 for name, proto in sorted(self._proto.concrete_functions.items()): 75 concrete_function = self._concrete_functions[name] 76 # By setting the structured_outputs directly, we can rely on this 77 # function_lib.ConcreteFunction object to perform the output repacking 78 # logic. The only limitation of that logic is that it only works 79 # with output that is convertible to Tensors and the conversion 80 # always happens. For example tf.TensorShape([2, 3]) will be 81 # converted to Tensor representing [2, 3]. 82 original_outputs = coder.decode_proto(proto.output_signature) 83 # The original_outputs here had Tensors converted to TensorSpecs, so 84 # the restored function's structured_outputs field will not be 85 # exactly the same. Fortunately the repacking logic cares only about 86 # the structure. 87 # TODO(vbardiovsky): Should we just replicate the structures, with 88 # Nones instead of real objects? 89 concrete_function._func_graph.structured_outputs = original_outputs # pylint: disable=protected-access 90 concrete_function._func_graph.structured_input_signature = ( # pylint: disable=protected-access 91 coder.decode_proto(proto.canonicalized_input_signature)) 92 93 def _setup_functions_captures(self): 94 """Setup captures and variables in restored functions.""" 95 concrete_functions = sorted(self._proto.concrete_functions.items()) 96 for name, proto in concrete_functions: 97 concrete_function = self._concrete_functions[name] 98 bound_inputs = [ 99 self._get_tensor_from_node(node_id) 100 for node_id in proto.bound_inputs] 101 bound_variables = [ 102 self._nodes[node_id] 103 for node_id in proto.bound_inputs 104 if self._proto.nodes[node_id].WhichOneof("kind") == "variable" 105 ] 106 # TODO(andresp): This is only injecting the captured inputs into the 107 # concrete function, note that we did not modify the FuncGraph 108 # itself. 109 concrete_function._captured_inputs = bound_inputs # pylint: disable=protected-access 110 concrete_function._func_graph.variables = bound_variables # pylint: disable=protected-access 111 112 def _get_tensor_from_node(self, node_id): 113 """Resolves a node id into a tensor to be captured for a function.""" 114 with ops.init_scope(): 115 obj = self._nodes[node_id] 116 if resource_variable_ops.is_resource_variable(obj): 117 return obj.handle 118 elif isinstance(obj, tracking.TrackableAsset): 119 return obj.asset_path 120 elif tensor_util.is_tensor(obj): 121 return obj 122 elif isinstance(obj, tracking.TrackableResource): 123 # Note: this executes restored functions in the TrackableResource. 124 return obj.resource_handle 125 raise ValueError("Can't convert node %s to tensor" % (type(obj))) 126 127 def _load_all(self): 128 """Load all saved objects and wire their properties.""" 129 # Maps from node ids to recreated objects 130 nodes = {} 131 # Maps from node ids to setter functions (same signature as setattr) for 132 # setting dependencies. 133 node_setters = {} 134 135 # Figure out which objects are slot variables. These objects are created 136 # with Optimizer.add_slot rather than _recreate_variable. 137 slot_variable_node_ids = set() 138 for proto in self._proto.nodes: 139 for slot_variable_proto in proto.slot_variables: 140 slot_variable_node_ids.add(slot_variable_proto.slot_variable_node_id) 141 142 # Re-create everything except slot variables. 143 for node_id, proto in enumerate(self._proto.nodes): 144 if node_id in slot_variable_node_ids: 145 # Defer recreating slot variables so we can use the public Optimizer 146 # interface. 147 continue 148 node, setter = self._recreate(proto) 149 nodes[node_id] = node 150 node_setters[node_id] = setter 151 152 # Now that we have created the variables being optimized, we have enough 153 # information to re-create slot variables for them. 154 for node_id, proto in enumerate(self._proto.nodes): 155 optimizer_object = nodes[node_id] 156 for slot_variable_proto in proto.slot_variables: 157 optimized_variable = nodes[ 158 slot_variable_proto.original_variable_node_id] 159 slot_variable = optimizer_object.add_slot( 160 var=optimized_variable, 161 slot_name=slot_variable_proto.slot_name) 162 nodes[slot_variable_proto.slot_variable_node_id] = slot_variable 163 node_setters[slot_variable_proto.slot_variable_node_id] = setattr 164 165 self._nodes = [] 166 167 # After creating the objects, construct the edges between the objects. 168 for node_id, object_proto in enumerate(self._proto.nodes): 169 obj = nodes[node_id] 170 setter = node_setters[node_id] 171 self._nodes.append(obj) 172 173 for reference in object_proto.children: 174 setter(obj, reference.local_name, nodes[reference.node_id]) 175 # Note: if an object has an attribute `__call__` add a class method 176 # that allows `obj()` syntax to work. This is done per-instance to 177 # allow `callable` to be used to find out if an object is callable. 178 if reference.local_name == "__call__": 179 setattr(type(obj), "__call__", _call_attribute) 180 181 def _restore_checkpoint(self): 182 """Load state from checkpoint into the deserialized objects.""" 183 variables_path = saved_model_utils.get_variables_path(self._export_dir) 184 # TODO(andresp): Clean use of private methods of TrackableSaver. 185 # pylint: disable=protected-access 186 saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0))) 187 saver._file_prefix_placeholder = constant_op.constant(variables_path) 188 load_status = saver.restore(variables_path) 189 load_status.assert_existing_objects_matched() 190 checkpoint = load_status._checkpoint 191 192 # When running in eager mode, the `restore` call above has already run and 193 # restored the state of trackables, call `position.restore_ops()` will 194 # return an empty list as there is nothing left to do. In graph mode, that 195 # will return the list of ops that must run to restore the object on that 196 # position. We have to wire them in the initializers of the objects so that 197 # they get initialized properly when using common practices (e.g. the ones 198 # used by ManagedSession) without further user action. 199 for object_id, obj in dict(checkpoint.object_by_proto_id).items(): 200 position = base.CheckpointPosition(checkpoint=checkpoint, 201 proto_id=object_id) 202 restore_ops = position.restore_ops() 203 if restore_ops: 204 if resource_variable_ops.is_resource_variable(obj): 205 obj._initializer_op = restore_ops 206 else: 207 raise NotImplementedError( 208 ("Missing functionality to restore state of object " 209 "%r from the checkpoint." % obj)) 210 211 def get(self, node_id): 212 return self._nodes[node_id] 213 214 def _recreate(self, proto): 215 """Creates a Python object from a SavedObject protocol buffer.""" 216 factory = { 217 "user_object": lambda: self._recreate_user_object(proto.user_object), 218 "asset": lambda: self._recreate_asset(proto.asset), 219 "function": lambda: self._recreate_function(proto.function), 220 "bare_concrete_function": functools.partial( 221 self._recreate_bare_concrete_function, 222 proto.bare_concrete_function), 223 "variable": lambda: self._recreate_variable(proto.variable), 224 "constant": lambda: self._recreate_constant(proto.constant), 225 "resource": lambda: self._recreate_resource(proto.resource), 226 } 227 kind = proto.WhichOneof("kind") 228 if kind not in factory: 229 raise ValueError("Unknown SavedObject type: %r" % kind) 230 return factory[kind]() 231 232 def _recreate_user_object(self, proto): 233 """Instantiates a SavedUserObject.""" 234 looked_up = revived_types.deserialize(proto) 235 if looked_up is None: 236 # Note: each user object has its own class. This allows to make each one 237 # individually callable by adding a `__call__` method to the classes of 238 # the objects instances that have a `__call__` property. 239 240 class _UserObject(tracking.AutoTrackable): 241 pass 242 243 return _UserObject(), setattr 244 return looked_up 245 246 def _recreate_asset(self, proto): 247 filename = os.path.join( 248 saved_model_utils.get_assets_dir(self._export_dir), 249 self._asset_file_def[proto.asset_file_def_index].filename) 250 return tracking.TrackableAsset(filename), setattr 251 252 def _recreate_function(self, proto): 253 return function_deserialization.recreate_function( 254 proto, self._concrete_functions), setattr 255 256 def _recreate_bare_concrete_function(self, proto): 257 return function_deserialization.setup_bare_concrete_function( 258 proto, self._concrete_functions), setattr 259 260 def _recreate_variable(self, proto): 261 # TODO(andresp): Can we use the checkpointed value as initializer? 262 dummy_value = init_ops.Zeros(dtype=proto.dtype)(shape=proto.shape) 263 return variables.Variable(dummy_value, trainable=proto.trainable), setattr 264 265 def _recreate_constant(self, proto): 266 tensor_proto = self._operation_attributes[proto.operation]["value"].tensor 267 imported_constant = constant_op.constant( 268 tensor_util.MakeNdarray(tensor_proto)) 269 return imported_constant, setattr 270 271 def _recreate_resource(self, proto): 272 del proto 273 return _RestoredResource(), setattr 274 275 276# TODO(b/124205571,b/124092991): Solve destruction of resources. 277class _RestoredResource(tracking.TrackableResource): 278 """Restored SavedResource.""" 279 280 def _create_resource(self): 281 raise RuntimeError() 282 283 def _initialize(self): 284 raise RuntimeError() 285 286 def _list_functions_for_serialization(self): 287 # Overwrite this method to avoid the implementation of 288 # base class to re-wrap the polymorphic functions into 289 # another layer of `tf.function`. 290 return { 291 "_create_resource": self._create_resource, 292 "_initialize": self._initialize, 293 } 294 295 296def _call_attribute(instance, *args, **kwargs): 297 return instance.__call__(*args, **kwargs) 298 299 300@tf_export("saved_model.load", v1=["saved_model.load_v2"]) 301def load(export_dir, tags=None): 302 """Load a SavedModel from `export_dir`. 303 304 Signatures associated with the SavedModel are available as functions: 305 306 ```python 307 imported = tf.saved_model.load(path) 308 f = imported.signatures["serving_default"] 309 print(f(x=tf.constant([[1.]]))) 310 ``` 311 312 Objects exported with `tf.saved_model.save` additionally have trackable 313 objects and functions assigned to attributes: 314 315 ```python 316 exported = tf.train.Checkpoint(v=tf.Variable(3.)) 317 exported.f = tf.function( 318 lambda x: exported.v * x, 319 input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)]) 320 tf.saved_model.save(exported, path) 321 imported = tf.saved_model.load(path) 322 assert 3. == imported.v.numpy() 323 assert 6. == imported.f(x=tf.constant(2.)).numpy() 324 ``` 325 326 Args: 327 export_dir: The SavedModel directory to load from. 328 tags: A tag or sequence of tags identifying the MetaGraph to load. Optional 329 if the SavedModel contains a single MetaGraph, as for those exported from 330 `tf.saved_model.load`. 331 332 Returns: 333 A trackable object with a `signatures` attribute mapping from signature 334 keys to functions. If the SavedModel was exported by `tf.saved_model.load`, 335 it also points to trackable objects and functions which were attached 336 to the exported object. 337 338 Raises: 339 ValueError: If `tags` don't match a MetaGraph in the SavedModel. 340 """ 341 if tags is not None and not isinstance(tags, set): 342 # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered 343 # sequences for nest.flatten, so we put those through as-is. 344 tags = nest.flatten(tags) 345 saved_model_proto = loader_impl.parse_saved_model(export_dir) 346 if (len(saved_model_proto.meta_graphs) == 1 347 and saved_model_proto.meta_graphs[0].HasField("object_graph_def")): 348 meta_graph_def = saved_model_proto.meta_graphs[0] 349 if (tags is not None 350 and set(tags) != set(meta_graph_def.meta_info_def.tags)): 351 raise ValueError( 352 ("The SavedModel at {} has one MetaGraph with tags {}, but got an " 353 "incompatible argument tags={} to tf.saved_model.load. You may omit " 354 "it, pass 'None', or pass matching tags.") 355 .format(export_dir, meta_graph_def.meta_info_def.tags, tags)) 356 object_graph_proto = meta_graph_def.object_graph_def 357 with ops.init_scope(): 358 loader = _Loader(object_graph_proto, 359 saved_model_proto, 360 export_dir) 361 root = loader.get(0) 362 else: 363 with ops.init_scope(): 364 root = load_v1_in_v2.load(export_dir, tags) 365 return root 366