# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Prototype decorator for defining legacy-graph-mode functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import weakref

from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


class VariableHolder(object):
  """Holds variables for a python function."""

  def __init__(self, fn=None, share_variables=False):
    self._fn = fn

    self._share_variables = share_variables
    self._variables_by_name = data_structures.Mapping()

  @property
  def variables(self):
    return self._variables_by_name

  def variable_creator_scope(self, next_creator, **kwargs):
    """Creates variables & adds them to collections to match legacy code."""
    collections = kwargs.pop("collections", None)
    v = None

    # Get expected variable name.
    with ops.name_scope(
        kwargs.get("name", None), "Variable", skip_on_eager=False) as name:
      variable_name = ops.name_from_scope_name(name)
      kwargs["name"] = name

    if self._share_variables:
      v = self._variables_by_name.get(variable_name, None)

    if v is None:
      v = next_creator(**kwargs)
      self._variables_by_name[variable_name] = v

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]

    ops.add_to_collections(collections, v)

    return v

  def __call__(self, *args, **kwargs):
    return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)

  def call_with_variable_creator_scope(self, fn):

    def wrapped(*args, **kwargs):
      with variable_scope.variable_creator_scope(self.variable_creator_scope):
        return fn(*args, **kwargs)

    return wrapped


def _get_element_from_tensor_info(tensor_info, graph):
  """Simplified copy of the deprecated `get_tensor_from_tensor_info`."""
  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    # We may get operations here in some cases. TensorInfo is a bit of a
    # misnomer if so.
    return graph.as_graph_element(tensor_info.name)
  elif encoding == "coo_sparse":
    return sparse_tensor.SparseTensor(
        graph.get_tensor_by_name(tensor_info.coo_sparse.indices_tensor_name),
        graph.get_tensor_by_name(tensor_info.coo_sparse.values_tensor_name),
        graph.get_tensor_by_name(
            tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == "composite_tensor":
    struct_coder = nested_structure_coder.StructureCoder()
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    spec = struct_coder.decode_proto(spec_proto)
    components = [graph.get_tensor_by_name(component.name) for component in
                  tensor_info.composite_tensor.components]
    return spec._from_components(components)  # pylint: disable=protected-access
  else:
    raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)


def _lift_single_variable(old_variable, graph, variable_holder):
  """Lifts `old_variable` out of the `FuncGraph` `graph`."""
  new_variable = resource_variable_ops.UninitializedVariable(
      shape=old_variable.shape,
      dtype=old_variable.dtype,
      name=old_variable.op.name,
      trainable=old_variable.trainable,
      extra_handle_data=old_variable.handle)
  new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
  graph.add_capture(new_variable.handle, old_variable.handle)
  # Now that we've added the new variable to graph.captures,
  # graph.capture will use that cached value and do some post-processing
  # on the capture like recording it on the tape.
  graph.capture(new_variable.handle)
  # pylint: disable=protected-access
  variable_name = new_variable.name.split(":")[0]
  variable_holder._variables_by_name[variable_name] = new_variable
  graph._weak_variables.append(weakref.ref(new_variable))
  # pylint: enable=protected-access
  graph.watch_variable(new_variable)
  return new_variable


def _lift_unlifted_variables(graph, variable_holder):
  """Finds resource variables and lifts them into the outer context.

  When we import a GraphDef inside a wrap_function, no Python graph building
  code runs. This means we get VarHandleOps which create variable resources,
  but no corresponding Python objects. Leaving them like this works but gives
  the user no way to interact with or modify the variables outside the graph.

  This method searches for variables and lifts them out as regular variable
  objects when possible, indicating to the FuncGraph that they are captures.

  Args:
    graph: The FuncGraph to lift variables from.
    variable_holder: A VariableHolder to record the lifted variables in.
  """
  with graph.as_default():
    global_collection_variables = ops.get_collection(
        ops.GraphKeys.GLOBAL_VARIABLES)
    local_collection_variables = ops.get_collection(
        ops.GraphKeys.LOCAL_VARIABLES)
    existing_captures = {id(c) for c in graph.internal_captures}
    lifted_variables = {}

    def _should_lift_variable(v):
      return ((v._in_graph_mode  # pylint: disable=protected-access
               and v.graph.building_function)
              and isinstance(v, resource_variable_ops.BaseResourceVariable)
              and id(v.handle) not in existing_captures)

    for old_variable in global_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))

    for old_variable in local_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))
        if new_variable._in_graph_mode:  # pylint: disable=protected-access
          outer_graph = new_variable.graph
          # Variables are added to the global collection by default. In this
          # case we only want the variable in the local collection, so we'll pop
          # it out.
          global_collection = outer_graph.get_collection_ref(
              ops.GraphKeys.GLOBAL_VARIABLES)
          global_collection.remove(new_variable)
          outer_graph.add_to_collection(
              ops.GraphKeys.LOCAL_VARIABLES, new_variable)

    # Update the FuncGraph's collections, partly for the user and partly so this
    # function is idempotent when it runs again in prune() calls.
    for collection_name in [
        ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
    ]:
      mutable_collection = ops.get_collection_ref(collection_name)
      for index, current in enumerate(mutable_collection):
        mutable_collection[index] = lifted_variables.get(id(current), current)
        if not resource_variable_ops.is_resource_variable(
            mutable_collection[index]):
          logging.log_first_n(
              logging.WARN,
              "Unable to create a python object for variable {} because it is "
              "a reference variable. It may not be visible to training APIs. "
              "If this is a problem, consider rebuilding the SavedModel after "
              "running tf.compat.v1.enable_resource_variables().".format(
                  mutable_collection[index]),
              5)


# TODO(allenl): make this trackable
class WrappedFunction(function.ConcreteFunction):
  """Wraps a tf V1 piece of code in a function."""

  def __init__(self, fn_graph, variable_holder, attrs=None, signature=None):
    self._variable_holder = variable_holder
    _lift_unlifted_variables(fn_graph, variable_holder)
    # We call __init__ after lifting variables so that the function's signature
    # properly reflects the new captured inputs.
    for f in fn_graph.as_graph_def().library.function:
      context.context().add_function_def(f)
    self._signature = signature
    super(WrappedFunction, self).__init__(fn_graph, attrs=attrs)

  def _call_impl(self, args, kwargs, cancellation_manager=None):
    if self._arg_keywords is None:
      if kwargs:
        raise NotImplementedError(
            "Keyword arguments not supported when calling a "
            "wrap_function-decorated function.")
      if self._signature is not None:
        args = list(args)
        for i, arg in enumerate(args):
          if isinstance(self._signature[i], tensor_spec.DenseSpec):
            args[i] = ops.convert_to_tensor(arg, self._signature[i].dtype)
      return self._call_flat(args, self.captured_inputs)
    else:
      return super(WrappedFunction, self)._call_impl(
          args, kwargs, cancellation_manager)

  def prune(self, feeds, fetches, name=None, input_signature=None):
    """Extract a subgraph of this function's underlying graph.

    Wraps the subgraph in a new `WrappedFunction` object.

    Args:
      feeds: Input tensors to the subgraph to extract, as `Tensor` objects.
      fetches: Possibly-nested Python data structure containing information
        about outputs of the target subgraph. Each entry can either be a
        `Tensor` object (for data outputs), an `Operation` object (for control
        outputs), or a `TensorInfo` proto. Any additional shape/dtype
        information provided in a `TensorInfo` and not present in the original
        graph will be added to the returned subgraph.
      name: (optional) Name to give to the underlying `FuncGraph` of the
        returned object. If no name is provided, the graph's name will be
        `"pruned"`.
      input_signature: (optional) possibly-nested Python data structure
        containing `TensorSpec` objects, with which to populate the returned
        function's `FuncGraph`'s `structured_input_signature` field.

    Returns:
      A new `WrappedFunction` object containing a copy of the portion of this
        object's graph that goes from `feeds` to `fetches`.
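
    Example (a minimal sketch, assuming the wrapped function returns a single
    tensor):

    ```python
    wrapped = tf.compat.v1.wrap_function(
        lambda x: 2. * x + 1., [tf.TensorSpec([], tf.float32)])
    # Keep only the computation from the graph's inputs to its first output.
    pruned = wrapped.prune(wrapped.graph.inputs, wrapped.graph.outputs[0])
    pruned(tf.constant(3.))  # ==> 7.0
    ```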
    """
    # TODO(b/129646028): Add support for CompositeTensors.
    name = name or "pruned"
    flat_feeds = nest.flatten(feeds, expand_composites=True)
    flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds]
    for f in flat_feeds:
      if not isinstance(f, ops.Tensor):
        raise ValueError("Feeds must be tensors.")

    # Ignoring all feeds that are captures allows prune to be called
    # using wrapped_func.inputs even when it uses variables
    internal_captures = {id(c) for c in self.graph.internal_captures}
    flat_feeds = [f for f in flat_feeds if id(f) not in internal_captures]

    operation_fetches = []
    tensor_fetches = []
    tensor_infos = []

    def _fetch_preprocessing_callback(fetch):
      """Extract out lists of ops, tensors, and tensor type info.

      Turns TensorInfos into Tensors in the original `fetches` structure.
      Also extracts ops from `fetches`.

      Args:
        fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or
          string identifying a Tensor or Operation.

      Returns:
        `fetch` converted to a Tensor.
      """
      if isinstance(fetch, ops.Operation):
        operation_fetches.append(fetch)
        return fetch
      elif isinstance(fetch, meta_graph_pb2.TensorInfo):
        tensor_infos.append(fetch)
        decoded = _get_element_from_tensor_info(fetch, self._func_graph)
        if (tensor_util.is_tf_type(decoded) or
            isinstance(decoded, composite_tensor.CompositeTensor)):
          tensor_fetches.append(decoded)
        else:
          operation_fetches.append(decoded)
        return decoded
      elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)):
        tensor_fetches.append(fetch)
        return fetch
      else:
        graph_element = self.graph.as_graph_element(fetch)
        return _fetch_preprocessing_callback(graph_element)

    fetches = nest.map_structure(_fetch_preprocessing_callback, fetches)

    # Expand composite tensors into their component dense Tensors.
    tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)

    for f in (flat_feeds + tensor_fetches + operation_fetches):
      if f.graph is not self._func_graph:
        raise ValueError("Can only prune function whose feeds and fetches "
                         "are from this graph (%s). Input %s is from graph %s" %
                         (self._func_graph, f, f.graph))
    with self._func_graph.as_default():
      pruned_graph = func_graph.FuncGraph(name)
    lift_map = lift_to_graph.lift_to_graph(
        operation_fetches + tensor_fetches,
        pruned_graph,
        sources=flat_feeds + self.graph.internal_captures,
        base_graph=self._func_graph)

    # Note that we add the component tensors of any composite tensors to the
    # returned function's outputs list; the list must contain these component
    # tensors, or the function's sparse outputs won't work properly.
    pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
    pruned_graph.control_outputs.extend(
        [lift_map[operation] for operation in operation_fetches])
    pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
    for external_capture, internal_capture in self.graph.captures:
      pruned_graph.add_capture(external_capture, lift_map[internal_capture])
    for ti in tensor_infos:
      if ti.WhichOneof("encoding") == "name":  # Dense tensors only
        t = pruned_graph.as_graph_element(ti.name)
        if tensor_util.is_tf_type(t):
          t.set_shape(tensor_shape.TensorShape(ti.tensor_shape))
    # pylint: disable=protected-access
    for f in self.graph._functions.values():
      pruned_graph._add_function(f)
    # pylint: enable=protected-access

    pruned_graph.variables = self.graph.variables

    def _structured_output_mapping(fetched):
      """Callback for `nest.map_structure()`."""
      lifted = lift_map[fetched]
      if isinstance(lifted, ops.Operation):
        return None
      return lifted

    # expand_composites=True here causes composite tensors to be expanded
    # into their component dense Tensors, mapped to the new graph, and then
    # reconstituted into their original composite form.
    pruned_graph.structured_outputs = nest.map_structure(
        _structured_output_mapping, fetches, expand_composites=True)
    pruned_graph.structured_input_signature = input_signature
    pruned_fn = WrappedFunction(
        pruned_graph, variable_holder=self._variable_holder)
    pruned_fn._num_positional_args = len(flat_feeds)  # pylint: disable=protected-access
    # TODO(kathywu): Enable keyword arguments if an input signature is specified
    pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds]  # pylint: disable=protected-access
    return pruned_fn


def _filter_returned_ops(fn):
  """Filters out any ops returned by the function.

  Args:
    fn: a function

  Returns:
    A tuple of (
      Wrapped function that returns `None` in place of any ops,
      dict that maps the index in the flat output structure to the returned op
    )
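
  For example (an illustrative sketch), if `fn` returns `(tensor, op)`, the
  wrapped function returns `(tensor, None)` and the returned dict is `{1: op}`.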
  """
  returned_ops = {}

  def wrap_and_filter_returned_ops(*args, **kwargs):
    outputs = fn(*args, **kwargs)
    flat_outputs = nest.flatten(outputs)
    for n in range(len(flat_outputs)):
      output = flat_outputs[n]
      if isinstance(output, ops.Operation):
        returned_ops[n] = output
        flat_outputs[n] = None
    return nest.pack_sequence_as(outputs, flat_outputs)

  return wrap_and_filter_returned_ops, returned_ops


class WrappedGraph(object):
  """Class for wrapping multiple TF 1.X functions in a single graph.

  Maintains a dictionary mapping names to wrapped functions. See
  `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions.

  Functions wrapped using this class have access to variables and collections
  created in other wrapped functions, using the standard TF 1.X API (
  `tf.compat.v1.get_variable` or
  `tf.compat.v1.get_default_graph().get_collection(...)`)

  Outside a function, variables and collections may be accessed using the
  `variables` and `graph` properties.

  Example:

  ```
  def add_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v + x

  def increment_var_v1(x):
    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
    return v.assign_add(x)

  g = WrappedGraph()
  add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)])
  increment_var = g.wrap_function(increment_var_v1,
                                  [tf.TensorSpec([], tf.int32)])

  assert len(g.variables) == 1
  assert g.variables[0].numpy() == 0
  increment_var(tf.constant(5))
  assert g.variables[0].numpy() == 5

  ```
  """

  def __init__(self, variable_holder=None, **kwargs):
    self._variable_holder = (
        variable_holder or VariableHolder(share_variables=True))

    name = kwargs.pop("name", "wrapped_function_graph")
    # Always start with empty collections, unless otherwise specified. Setting
    # `collections=None` will copy the collections from the outer graph.
    collections = kwargs.pop("collections", {})
    self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs)

    self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
    self._functions = {}

  @property
  def functions(self):
    return self._functions

  @property
  def variables(self):
    return self._variable_holder.variables

  def wrap_function(self, fn, signature, name=None):
    """Wraps a TF 1.X function and returns an eager-compatible function.

    All functions wrapped in the same `WrappedGraph` will have access to the
    same graph (`tf.compat.v1.get_default_graph` to get the graph object
    within a function, or `WrappedGraph.graph` to get the graph outside a
    function). Variables created within the function will be added to the
    `variables` list.

    Function inputs: All inputs to the function must be tensors (nested ok),
    with their shapes and dtypes defined in the `signature` argument.

    Function outputs:

      * The 1.X function may return tensors, variables, and ops. The wrapped
        eager-compatible function will always return tensors in the same nested
        structure.
      * Variables are replaced with a tensor containing the latest read values.
      * Returned ops are executed, and replaced with None.
      * The order of op execution and variable reads in the return is
        nondeterministic. For example:

        ```
        def update_var(x):
          v = tf.Variable(0)
          op = tf.compat.v1.assign(v, x).op
          return v, op

        g = WrappedGraph()
        fn = g.wrap_function(update_var, [tf.TensorSpec([], tf.int32)])
        read_value, _ = fn(tf.constant(3))
        print(read_value.numpy())  # could be 0 or 3
        print(g.variables[0].numpy()) # always 3
        ```

    To ensure that ops in the function are executed (e.g. ops added to the
    `tf.GraphKeys.UPDATE_OPS` collection), include them in the function returns.

    Args:
      fn: a 1.X tensorflow function.
      signature: a possibly nested sequence of `TensorSpecs` specifying the
        shapes and dtypes of the arguments.
      name: an optional string name for the function. The function will be saved
        with key `name` in the `functions` dictionary.

    Returns:
      An eager-compatible function.
    """
    return self._wrap_function(fn, signature=signature, name=name)

  def _wrap_function(self,
                     fn,
                     args=None,
                     kwargs=None,
                     signature=None,
                     name=None):
    """Internal wrap function method with extended func_graph arguments."""
    fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
        self._variable_holder.call_with_variable_creator_scope(fn))

    func_graph.func_graph_from_py_func(
        None,  # Name is unused.
        fn_with_filter_and_scope,
        args=args,
        kwargs=kwargs,
        signature=signature,
        add_control_dependencies=False,
        func_graph=self.graph)

    # This code relies on questionable behavior from `func_graph_from_py_func`.
    # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
    # and structured outputs are overwritten. Pretty sure this is a bug,
    # because the structured outputs don't match up with the outputs...
    fn_inputs = self.graph.inputs[:-len(self.graph.captures)]

    # Return filtered ops to the flattened outputs.
    flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
    for index, op in returned_ops.items():
      flat_fn_outputs[index] = op
    fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                       flat_fn_outputs)

    name = name or fn.__name__
    wrapped_function = self._wrapped_function.prune(
        fn_inputs, fn_outputs, name, self.graph.structured_input_signature)
    self._functions[name] = wrapped_function
    return wrapped_function


@tf_export(v1=["wrap_function"])
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function fn into a graph function.

  The python function `fn` will be called once with symbolic arguments specified
  in the `signature`, traced, and turned into a graph function. Any variables
  created by `fn` will be owned by the object returned by `wrap_function`. The
  resulting graph function can be called with tensors which match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Both `tf.compat.v1.wrap_function` and `tf.function` create a callable
  TensorFlow graph. But while `tf.function` runs all stateful operations
  (e.g. `tf.print`) and sequences operations to provide the same semantics as
  eager execution, `wrap_function` is closer to the behavior of `session.run` in
  TensorFlow 1.x. It will not run any operations unless they are required to
  compute the function's outputs, either through a data dependency or a control
  dependency. Nor will it sequence operations.

  Unlike `tf.function`, `wrap_function` will only trace the Python function
  once. As with placeholders in TF 1.x, shapes and dtypes must be provided to
  `wrap_function`'s `signature` argument.

  Since it is only traced once, variables and state may be created inside the
  function and owned by the function wrapper object.

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the wrapped
      function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  holder = VariableHolder(fn)
  func_graph_name = "wrapped_function"
  if name is not None:
    func_graph_name = "wrapped_function_" + name
  return WrappedFunction(
      func_graph.func_graph_from_py_func(
          func_graph_name,
          holder,
          args=None,
          kwargs=None,
          signature=signature,
          add_control_dependencies=False,
          collections={}),
      variable_holder=holder,
      signature=signature)


def function_from_graph_def(graph_def, inputs, outputs):
  """Creates a ConcreteFunction from a GraphDef.

  Args:
    graph_def: A GraphDef to make a function out of.
    inputs: A Tensor name or nested structure of names in `graph_def` which
      should be inputs to the function.
    outputs: A Tensor name or nested structure of names in `graph_def` which
      should be outputs of the function.

  Returns:
    A ConcreteFunction.
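
  Example (a minimal sketch, assuming `graph_def` contains a float input
  tensor named "x:0" and an output tensor named "y:0"):

  ```python
  fn = function_from_graph_def(graph_def, inputs="x:0", outputs="y:0")
  y = fn(tf.constant(2.0))
  ```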
  """

  def _imports_graph_def():
    importer.import_graph_def(graph_def, name="")

  wrapped_import = wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph
  return wrapped_import.prune(
      nest.map_structure(import_graph.as_graph_element, inputs),
      nest.map_structure(import_graph.as_graph_element, outputs))
