# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Prototype decorator for defining legacy-graph-mode functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import weakref

from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


class VariableHolder(object):
  """Holds variables for a python function.

  Variables are recorded by their scoped name in a
  `data_structures.Mapping`, exposed through the `variables` property.
  """

  def __init__(self, fn=None, share_variables=False):
    """Creates a holder.

    Args:
      fn: Optional python callable invoked by `__call__` under this holder's
        variable creator scope.
      share_variables: If True, a variable requested with a scoped name that
        was already created through this holder is reused instead of being
        created again.
    """
    self._fn = fn

    self._share_variables = share_variables
    self._variables_by_name = data_structures.Mapping()

  @property
  def variables(self):
    """Mapping from scoped variable name to the variable object."""
    return self._variables_by_name

  def variable_creator_scope(self, next_creator, **kwargs):
    """Creates variables & adds them to collections to match legacy code.

    Intended to be installed with `variable_scope.variable_creator_scope`.

    Args:
      next_creator: The next variable creator in the chain.
      **kwargs: Variable constructor arguments. `collections` is consumed here
        instead of being forwarded to `next_creator`, and `name` is replaced
        by the fully scoped name.

    Returns:
      The new variable, or the previously recorded variable with the same
      scoped name when `share_variables` is True. Either way, the variable is
      added to `collections` (default `GLOBAL_VARIABLES`), plus
      `TRAINABLE_VARIABLES` when the variable is trainable.
    """
    collections = kwargs.pop("collections", None)
    v = None

    # Get expected variable name.
    with ops.name_scope(
        kwargs.get("name", None), "Variable", skip_on_eager=False) as name:
      variable_name = ops.name_from_scope_name(name)
      kwargs["name"] = name

    if self._share_variables:
      v = self._variables_by_name.get(variable_name, None)

    if v is None:
      v = next_creator(**kwargs)
      self._variables_by_name[variable_name] = v

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      # Copy so a caller-provided collections sequence is not mutated.
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]

    ops.add_to_collections(collections, v)

    return v

  def __call__(self, *args, **kwargs):
    return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)

  def call_with_variable_creator_scope(self, fn):
    """Returns a wrapper calling `fn` under this holder's creator scope.

    Args:
      fn: The python callable to wrap.

    Returns:
      A callable that forwards its arguments to `fn` while
      `self.variable_creator_scope` is installed as a variable creator.
    """

    def wrapped(*args, **kwargs):
      with variable_scope.variable_creator_scope(self.variable_creator_scope):
        return fn(*args, **kwargs)

    return wrapped


def _get_element_from_tensor_info(tensor_info, graph):
  """Simplified copy of the deprecated `get_tensor_from_tensor_info`.

  Args:
    tensor_info: A `meta_graph_pb2.TensorInfo` proto.
    graph: The graph in which to look up the referenced element(s).

  Returns:
    A graph element (tensor or operation) for the "name" encoding, a
    `SparseTensor` for "coo_sparse", or a composite tensor rebuilt from its
    type spec and components for "composite_tensor".

  Raises:
    ValueError: If the TensorInfo uses any other encoding.
  """
  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    # We may get operations here in some cases. TensorInfo is a bit of a
    # misnomer if so.
    return graph.as_graph_element(tensor_info.name)
  elif encoding == "coo_sparse":
    return sparse_tensor.SparseTensor(
        graph.get_tensor_by_name(tensor_info.coo_sparse.indices_tensor_name),
        graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name),
        graph.get_tensor_by_name(
            tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == "composite_tensor":
    struct_coder = nested_structure_coder.StructureCoder()
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    spec = struct_coder.decode_proto(spec_proto)
    components = [graph.get_tensor_by_name(component.name) for component in
                  tensor_info.composite_tensor.components]
    return spec._from_components(components)  # pylint: disable=protected-access
  else:
    raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)


def _lift_single_variable(old_variable, graph, variable_holder):
  """Lifts `old_variable` out of the `FuncGraph` `graph`.

  Creates an `UninitializedVariable` mirroring `old_variable`, registers its
  handle as the capture feeding `old_variable.handle` inside `graph`, records
  it in `variable_holder` by name, and makes `graph` track/watch it.

  Returns:
    The newly created variable.
  """
  new_variable = resource_variable_ops.UninitializedVariable(
      shape=old_variable.shape,
      dtype=old_variable.dtype,
      name=old_variable.op.name,
      trainable=old_variable.trainable,
      extra_handle_data=old_variable.handle)
  new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
  graph.add_capture(new_variable.handle, old_variable.handle)
  # Now that we've added the new variable to graph.captures,
  # graph.capture will use that cached value and do some post-processing
  # on the capture like recording it on the tape.
  graph.capture(new_variable.handle)
  # pylint: disable=protected-access
  variable_name = new_variable.name.split(":")[0]
  variable_holder._variables_by_name[variable_name] = new_variable
  graph._weak_variables.append(weakref.ref(new_variable))
  # pylint: enable=protected-access
  graph.watch_variable(new_variable)
  return new_variable


def _lift_unlifted_variables(graph, variable_holder):
  """Finds resource variables and lifts them into the outer context.

  When we import a GraphDef inside a wrap_function, no Python graph building
  code runs. This means we get VarHandleOps which create variable resources,
  but no corresponding Python objects. Leaving them like this works but gives
  the user no way to interact with or modify the variables outside the graph.

  This method searches for variables and lifts them out as regular variable
  objects when possible, indicating to the FuncGraph that they are captures.

  Args:
    graph: The FuncGraph to lift variables from.
    variable_holder: A VariableHolder to record the lifted variables in.
  """
  with graph.as_default():
    global_collection_variables = ops.get_collection(
        ops.GraphKeys.GLOBAL_VARIABLES)
    local_collection_variables = ops.get_collection(
        ops.GraphKeys.LOCAL_VARIABLES)
    existing_captures = {id(c) for c in graph.internal_captures}
    lifted_variables = {}

    def _should_lift_variable(v):
      # Only graph-mode resource variables built inside a function, whose
      # handle is not already a capture of `graph`.
      return ((v._in_graph_mode  # pylint: disable=protected-access
               and v.graph.building_function)
              and isinstance(v, resource_variable_ops.BaseResourceVariable)
              and id(v.handle) not in existing_captures)

    for old_variable in global_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))

    for old_variable in local_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))
        if new_variable._in_graph_mode:  # pylint: disable=protected-access
          outer_graph = new_variable.graph
          # Variables are added to the global collection by default. In this
          # case we only want the variable in the local collection, so we'll pop
          # it out.
          global_collection = outer_graph.get_collection_ref(
              ops.GraphKeys.GLOBAL_VARIABLES)
          global_collection.remove(new_variable)
          outer_graph.add_to_collection(
              ops.GraphKeys.LOCAL_VARIABLES, new_variable)

    # Update the FuncGraph's collections, partly for the user and partly so this
    # function is idempotent when it runs again in prune() calls.
    for collection_name in [
        ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
    ]:
      mutable_collection = ops.get_collection_ref(collection_name)
      for index, current in enumerate(mutable_collection):
        mutable_collection[index] = lifted_variables.get(id(current), current)
        if not resource_variable_ops.is_resource_variable(
            mutable_collection[index]):
          logging.log_first_n(
              logging.WARN,
              "Unable to create a python object for variable {} because it is "
              "a reference variable. It may not be visible to training APIs. "
              "If this is a problem, consider rebuilding the SavedModel after "
              "running tf.compat.v1.enable_resource_variables().".format(
                  mutable_collection[index]),
              5)


# TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction):
  """Wraps a tf V1 piece of code in a function."""

  def __init__(self, fn_graph, variable_holder, attrs=None, signature=None):
    self._variable_holder = variable_holder
    _lift_unlifted_variables(fn_graph, variable_holder)
    # We call __init__ after lifting variables so that the function's signature
    # properly reflects the new captured inputs.
    for f in fn_graph.as_graph_def().library.function:
      context.context().add_function_def(f)
    self._signature = signature
    super(WrappedFunction, self).__init__(fn_graph, attrs=attrs)

  def _call_impl(self, args, kwargs, cancellation_manager=None):
    """Calls the function.

    When `_arg_keywords` is None (the case for `wrap_function` results), only
    positional arguments are accepted; each argument whose matching
    `signature` entry is a `DenseSpec` is converted to a tensor of that dtype
    before the flat call. Otherwise defers to `ConcreteFunction._call_impl`.
    """
    if self._arg_keywords is None:
      if kwargs:
        raise NotImplementedError(
            "Keyword arguments not supported when calling a "
            "wrap_function-decorated function.")
      if self._signature is not None:
        args = list(args)
        for i, arg in enumerate(args):
          if isinstance(self._signature[i], tensor_spec.DenseSpec):
            args[i] = ops.convert_to_tensor(arg, self._signature[i].dtype)
      return self._call_flat(args, self.captured_inputs)
    else:
      return super(WrappedFunction, self)._call_impl(
          args, kwargs, cancellation_manager)

  def prune(self, feeds, fetches, name=None, input_signature=None):
    """Extract a subgraph of this function's underlying graph.

    Wraps the subgraph in a new `WrappedFunction` object.

    Args:
      feeds: Input tensors to the subgraph to extract, as `Tensor` objects.
      fetches: Possibly-nested Python data structure containing information
        about outputs of the target subgraph. Each entry can either be a
        `Tensor` object (for data outputs), an `Operation` object (for control
        outputs), or a `TensorInfo` proto. Any additional shape/dtype
        information provided in a `TensorInfo` and not present in the original
        graph will be added to the returned subgraph.
      name: (optional) Name to give to the underlying `FuncGraph` of the
        returned object. If no name is provided, the graph's name will be
        `"pruned"`.
      input_signature: (optional) possibly-nested Python data structure
        containing `TensorSpec` objects, with which to populate the returned
        function's `FuncGraph`'s `structured_input_signature` field.

    Returns:
      A new `WrappedFunction` object containing a copy of the portion of this
      object's graph that goes from `feeds` to `fetches`.

    Raises:
      ValueError: If a feed is not a tensor, or if a feed or fetch does not
        belong to this function's graph.
    """
    # TODO(b/129646028): Add support for CompositeTensors.
    name = name or "pruned"
    flat_feeds = nest.flatten(feeds, expand_composites=True)
    flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds]
    for f in flat_feeds:
      if not isinstance(f, ops.Tensor):
        raise ValueError("Feeds must be tensors.")

    # Ignoring all feeds that are captures allows prune to be called
    # using wrapped_func.inputs even when it uses variables
    internal_captures = {id(c) for c in self.graph.internal_captures}
    flat_feeds = [f for f in flat_feeds if id(f) not in internal_captures]

    operation_fetches = []
    tensor_fetches = []
    tensor_infos = []

    def _fetch_preprocessing_callback(fetch):
      """Extract out lists of ops, tensors, and tensor type info.

      Turns TensorInfos into Tensors in the original `fetches` structure.
      Also extracts ops from `fetches`.

      Args:
        fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
          string identifying a Tensor or Operation.

      Returns:
        `fetch` converted to a Tensor.
      """
      if isinstance(fetch, ops.Operation):
        operation_fetches.append(fetch)
        return fetch
      elif isinstance(fetch, meta_graph_pb2.TensorInfo):
        tensor_infos.append(fetch)
        decoded = _get_element_from_tensor_info(fetch, self._func_graph)
        if (tensor_util.is_tf_type(decoded) or
            isinstance(decoded, composite_tensor.CompositeTensor)):
          tensor_fetches.append(decoded)
        else:
          operation_fetches.append(decoded)
        return decoded
      elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)):
        tensor_fetches.append(fetch)
        return fetch
      else:
        graph_element = self.graph.as_graph_element(fetch)
        return _fetch_preprocessing_callback(graph_element)

    fetches = nest.map_structure(_fetch_preprocessing_callback, fetches)

    # Expand composite tensors into their component dense Tensors.
    tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)

    for f in (flat_feeds + tensor_fetches + operation_fetches):
      if f.graph is not self._func_graph:
        raise ValueError("Can only prune function whose feeds and fetches "
                         "are from this graph (%s). Input %s is from graph %s" %
                         (self._func_graph, f, f.graph))
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph(name)
    lift_map = lift_to_graph.lift_to_graph(
        operation_fetches + tensor_fetches,
        pruned_graph,
        sources=flat_feeds + self.graph.internal_captures,
        base_graph=self._func_graph)

    # Note that we add the component tensors of any composite tensors to the
    # returned function's outputs list; the list must contain these component
    # tensors, or the function's sparse outputs won't work properly.
    pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
    pruned_graph.control_outputs.extend(
        [lift_map[operation] for operation in operation_fetches])
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    for external_capture, internal_capture in self.graph.captures:
      pruned_graph.add_capture(external_capture, lift_map[internal_capture])
    for ti in tensor_infos:
      if ti.WhichOneof("encoding") == "name":  # Dense tensors only
        t = pruned_graph.as_graph_element(ti.name)
        if tensor_util.is_tf_type(t):
          t.set_shape(tensor_shape.TensorShape(ti.tensor_shape))
    # pylint: disable=protected-access
    for f in self.graph._functions.values():
      pruned_graph._add_function(f)
    # pylint: enable=protected-access

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      """callback for `nest.map_structure()`"""
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    # expand_composites=True here causes composite tensors to be expanded
    # into their component dense Tensors, mapped to the new graph, and then
    # reconstituted into their original composite form.
    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches, expand_composites=True)
    pruned_graph.structured_input_signature = input_signature
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    # TODO(kathywu): Enable keyword arguments if an input signature is specified
    pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
    return pruned_fn


def _filter_returned_ops(fn):
  """Filters out any ops returned by `fn`.

  Args:
    fn: a function

  Returns:
    A tuple of (
      Wrapped function that returns `None` in place of any ops,
      dict that maps the index in the flat output structure to the returned op
    )
    The dict is filled in when the wrapped function is called.
  """
  returned_ops = {}

  def wrap_and_filter_returned_ops(*args, **kwargs):
    outputs = fn(*args, **kwargs)
    flat_outputs = nest.flatten(outputs)
    # Replacing items by index does not change the list length, so it is safe
    # while enumerating.
    for n, output in enumerate(flat_outputs):
      if isinstance(output, ops.Operation):
        returned_ops[n] = output
        flat_outputs[n] = None
    return nest.pack_sequence_as(outputs, flat_outputs)

  return wrap_and_filter_returned_ops, returned_ops


class WrappedGraph(object):
  """Class for wrapping multiple TF 1.X functions in a single graph.

  Maintains a dictionary mapping names to wrapped functions. See
  `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions.

  Functions wrapped using this class have access to variables and collections
  created in other wrapped functions, using the standard TF 1.X API (
  `tf.compat.v1.get_variable` or
  `tf.compat.v1.get_default_graph().get_collection(...)`)

  Outside a function, variables and collections may be accessed using the
  `variables` and `graph` properties. `variables` is a mapping from variable
  name to variable object.

  Example:

  ```
  def add_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v + x

  def increment_var_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v.assign_add(x)

  g = WrappedGraph()
  add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)])
  increment_var = g.wrap_function(increment_var_v1,
                                  [tf.TensorSpec([], tf.int32)])

  assert len(g.variables) == 1
  v = list(g.variables.values())[0]
  assert v.numpy() == 0
  increment_var(tf.constant(5))
  assert v.numpy() == 5

  ```
  """

  def __init__(self, variable_holder=None, **kwargs):
    self._variable_holder = (
        variable_holder or VariableHolder(share_variables=True))

    name = kwargs.pop("name", "wrapped_function_graph")
    # Always start with empty collections, unless otherwise specified. Setting
    # `collections=None` will copy the collections from the outer graph.
    collections = kwargs.pop("collections", {})
    self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs)

    self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
    self._functions = {}

  @property
  def functions(self):
    """Dictionary mapping function names to the wrapped functions."""
    return self._functions

  @property
  def variables(self):
    """Mapping from variable name to the variables shared by this graph."""
    return self._variable_holder.variables

  def wrap_function(self, fn, signature, name=None):
    """Wraps a TF 1.X function and returns an eager-compatible function.

    All functions wrapped in the same `WrappedGraph` will have access to the
    same graph (`tf.compat.v1.get_default_graph` to get the graph object
    within a function, or `WrappedGraph.graph` to get the graph outside a
    function). Variables created within the function will be added to the
    `variables` mapping (keyed by variable name).

    Function inputs: All inputs to the function must be tensors (nested ok),
    with their shapes and dtypes defined in the `signature` argument.

    Function outputs:

      * The 1.X function may return tensors, variables, and ops. The wrapped
        eager-compatible function will always return tensors in the same nested
        structure.
      * Variables are replaced with a tensor containing the latest read values.
      * Returned ops are executed, and replaced with None.
      * The order of op execution and variable reads in the return is
        nondeterministic. For example:

        ```
        def update_var(x):
          v = tf.Variable(0)
          op = tf.compat.v1.assign(v, x).op
          return v, op

        g = WrappedGraph()
        fn = g.wrap_function(update_var, [tf.TensorSpec([], tf.int32)])
        read_value, _ = fn(tf.constant(3))
        print(read_value.numpy())  # could be 0 or 3
        print(list(g.variables.values())[0].numpy())  # always 3
        ```

    To ensure that ops in the function are executed (e.g. ops added to the
    `tf.GraphKeys.UPDATE_OPS` collection), include them in the function returns.

    Args:
      fn: a 1.X tensorflow function.
      signature: a possibly nested sequence of `TensorSpecs` specifying the
        shapes and dtypes of the arguments.
      name: an optional string name for the function. The function will be saved
        with key `name` in the `functions` dictionary.

    Returns:
      An eager-compatible function.
    """
    return self._wrap_function(fn, signature=signature, name=name)

  def _wrap_function(self,
                     fn,
                     args=None,
                     kwargs=None,
                     signature=None,
                     name=None):
    """Internal wrap function method with extended func_graph arguments."""
    fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
        self._variable_holder.call_with_variable_creator_scope(fn))

    func_graph.func_graph_from_py_func(
        None,  # Name is unused.
        fn_with_filter_and_scope,
        args=args,
        kwargs=kwargs,
        signature=signature,
        add_control_dependencies=False,
        func_graph=self.graph)

    # This code relies on questionable behavior from `func_graph_from_py_func`.
    # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
    # and structured outputs are overwritten. Pretty sure this is a bug,
    # because structured outputs doesn't match up with the outputs...
    #
    # The captures are assumed to be the trailing entries of `inputs`. Slice
    # with an explicit length: `inputs[:-len(captures)]` evaluates to
    # `inputs[:-0] == []` when there are no captures yet, which would silently
    # drop every real input.
    num_fn_inputs = len(self.graph.inputs) - len(self.graph.captures)
    fn_inputs = self.graph.inputs[:num_fn_inputs]

    # Return filtered ops to the flattened outputs.
    flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
    for index, op in returned_ops.items():
      flat_fn_outputs[index] = op
    fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                       flat_fn_outputs)

    name = name or fn.__name__
    wrapped_function = self._wrapped_function.prune(
        fn_inputs, fn_outputs, name, self.graph.structured_input_signature)
    self._functions[name] = wrapped_function
    return wrapped_function


@tf_export(v1=["wrap_function"])
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function fn into a graph function.

  The python function `fn` will be called once with symbolic arguments specified
  in the `signature`, traced, and turned into a graph function. Any variables
  created by `fn` will be owned by the object returned by `wrap_function`. The
  resulting graph function can be called with tensors which match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Both `tf.compat.v1.wrap_function` and `tf.function` create a callable
  TensorFlow graph. But while `tf.function` runs all stateful operations
  (e.g. `tf.print`) and sequences operations to provide the same semantics as
  eager execution, `wrap_function` is closer to the behavior of `session.run` in
  TensorFlow 1.x. It will not run any operations unless they are required to
  compute the function's outputs, either through a data dependency or a control
  dependency. Nor will it sequence operations.

  Unlike `tf.function`, `wrap_function` will only trace the Python function
  once. As with placeholders in TF 1.x, shapes and dtypes must be provided to
  `wrap_function`'s `signature` argument.

  Since it is only traced once, variables and state may be created inside the
  function and owned by the function wrapper object.

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the wrapped
      function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  holder = VariableHolder(fn)
  func_graph_name = "wrapped_function"
  if name is not None:
    func_graph_name = "wrapped_function_" + name
  return WrappedFunction(
      func_graph.func_graph_from_py_func(
          func_graph_name,
          holder,
          args=None,
          kwargs=None,
          signature=signature,
          add_control_dependencies=False,
          collections={}),
      variable_holder=holder,
      signature=signature)


def function_from_graph_def(graph_def, inputs, outputs):
  """Creates a ConcreteFunction from a GraphDef.

  Args:
    graph_def: A GraphDef to make a function out of.
    inputs: A Tensor name or nested structure of names in `graph_def` which
      should be inputs to the function.
    outputs: A Tensor name or nested structure of names in `graph_def` which
      should be outputs of the function.

  Returns:
    A ConcreteFunction.
  """

  def _imports_graph_def():
    importer.import_graph_def(graph_def, name="")

  wrapped_import = wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph
  return wrapped_import.prune(
      nest.map_structure(import_graph.as_graph_element, inputs),
      nest.map_structure(import_graph.as_graph_element, outputs))