# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""Python front-end supports for functions.
16
17NOTE: At this time, functions are experimental and subject to change!. Proceed
18with caution.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import collections
26import hashlib
27
28from tensorflow.core.framework import attr_value_pb2
29from tensorflow.core.framework import function_pb2
30from tensorflow.python.client import pywrap_tf_session as c_api
31from tensorflow.python.eager import context
32from tensorflow.python.framework import c_api_util
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import graph_to_function_def
35from tensorflow.python.framework import ops
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import resource_variable_ops
38from tensorflow.python.ops import variable_scope as vs
39from tensorflow.python.util import compat
40from tensorflow.python.util import function_utils
41from tensorflow.python.util import tf_contextlib
42from tensorflow.python.util import tf_inspect
43
44
45class Defun(object):
46  """Decorator used to define TensorFlow functions.
47
48  Use this decorator to make a Python function usable directly as a TensorFlow
49  function.
50
51  The decorated function must add ops to the default graph and return zero or
52  more `Tensor` objects.  Call the decorator with named arguments, one for each
53  argument of the function to decorate, with the expected type of the argument
54  as value.
55
56  For example if the function to decorate accepts two `tf.float32` arguments
57  named `x` and `y`, call the decorator with:
58
59      @Defun(tf.float32, tf.float32)
60      def foo(x, y):
61        ...
62
63  When you call the decorated function, it adds the `call` ops to the
64  default graph. In addition, it adds the definition of the function into the
65  default graph. Because the addition of the function into the graph
66  is deferred, the decorator can be used anywhere in the program.
67
68  Any variables created inside of the function are hoisted into the outer graph.
69  Note that the variables are created in the variable scope that was active
70  during the first call to the function. Subsequent function calls will refer to
71  the same set of variables.
72
73  Definitions of functions in a graph are frozen as soon as the graph is used to
74  create a session. However, new functions and new calls to existing functions
75  may be added to the graph, with the new functions themselves becoming
76  immediately frozen.
77
78  Example, but also see the [How To on functions](link_needed).
79
80  ```python
81  # Defining the function.
82  @tf.Defun(tf.float32, tf.float32)
83  def MyFunc(x, y):
84    return x + y, x - y
85
86  # Building the graph.
87  a = tf.constant([1.0])
88  b = tf.constant([2.0])
89  c, d = MyFunc(a, b, name='mycall')
90  ```
91  """
92
93  def __init__(self, *input_types, **kwargs):
94    """Create a `Defun` decorator.
95
96    Args:
97      *input_types: A list of `tf.DType`
98      **kwargs: Optional keyword arguments, including
99         func_name - (optional).  A python string, the name to use to
100           declare this `Function` in the graph.
101
102         grad_func - (optional).  A function implementing the gradient
103           of the function-to-register.  This is must be a
104           `_DefinedFunction` object. The gradient
105           function must satisfy the criterion defined in
106           function.proto:GradientDef.
107
108         python_grad_func - (optional).  A function implementing the
109           gradient of the function python-side. This function must
110           take the current op and the gradients w.r.t. its outputs,
111           and return the gradients w.r.t. the inputs. That is it must
112           implement the interface expected by `tf.RegisterGradient`).
113           This will be called by tf.gradients to add the gradient ops
114           to the graph. At most one of grad_func and python_grad_func
115           can be specified.
116
117         out_names = (optional). A list of strings, one per output
118           tensor.
119
120         shape_func - (optional). A function taking the op and returning a list
121           of static shapes to set for the function's outputs.
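
         For example, a sketch of `python_grad_func` usage (the callable
         below is illustrative; it receives the op and the gradients w.r.t.
         the op's outputs, and returns the gradients w.r.t. the inputs):

             @Defun(tf.float32,
                    python_grad_func=lambda op, grad: grad * 2.0)
             def Double(x):
               return 2.0 * x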
122    """
123    self._input_types = input_types
124    self._func_name = kwargs.pop("func_name", None)
125    self._grad_func = kwargs.pop("grad_func", None)
126    self._python_grad_func = kwargs.pop("python_grad_func", None)
127    self._out_names = kwargs.pop("out_names", None)
128    self._extra_kwargs = kwargs
129
130  def __call__(self, func):
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError("function %s must be callable" % func)

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError(
          "functions with argument defaults or keyword arguments are not"
          " supported. {} has defaults {} and keywords {}.".format(
              func, argspec.defaults, argspec.keywords))

    # Computes how many arguments 'func' has.
    min_args = len(argspec.args)
    max_args = min_args
    if argspec.varargs:
      max_args = 1000000
    argnames = argspec.args
    if tf_inspect.ismethod(func):
      # 1st argument is the "class" type.
      min_args -= 1
      argnames = argnames[1:]

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of specified input types (%d) is incompatible with "
            "the number of arguments the function accepts (min %d, max %d)." %
            (num, min_args, max_args))
      return _DefinedFunction(
          func,
          argnames,
          self._input_types,
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and input types is an empty list.
    if min_args == 0 and max_args == 0:
      return _DefinedFunction(
          func, [], [],
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
    return _OverloadedFunction(
        func,
        argnames,
        self._func_name,
        self._grad_func,
        self._python_grad_func,
        out_names=self._out_names,
        **self._extra_kwargs)


class _DefinedFunctionDeleter(object):
195  """Unregister function from eager context."""
196
197  __slots__ = ["name"]
198
199  def __init__(self, name):
200    self.name = name
201
202  def __del__(self):
203    try:
204      context.remove_function(self.name)
205    except TypeError:
206      # Suppress some exceptions, mainly for the case when we're running on
207      # module deletion. Things that can go wrong include the context module
208      # already being unloaded, self._handle._handle_data no longer being
209      # valid, and so on. Printing warnings in these cases is silly
210      # (exceptions raised from __del__ are printed as warnings to stderr).
211      pass  # 'NoneType' object is not callable when the handle has been
212      # partially unloaded.
213    except AttributeError:
214      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
215      # been unloaded. Will catch other module unloads as well.
216
217
218class _DefinedFunction(object):
219  """_DefinedFunction encapsulates a function definition and its properties.
220
221  Attributes:
222    name: The function name.
223    definition: The definition of this function. A FunctionDef proto.
224    grad_func_name: If not None, the name of this function's gradient function.
225    python_grad_func: A python callable implementing the gradient of
226      the function python-side.
227  """
228
229  def __init__(self,
230               func,
231               argnames,
232               input_types,
233               func_name=None,
234               grad_func=None,
235               python_grad_func=None,
236               out_names=None,
237               shape_func=None,
238               capture_by_value=False,
239               allowlisted_stateful_ops=None,
240               capture_resource_var_by_value=True,
241               **kwargs):
242    """Creates _DefinedFunction.
243
244    Args:
245      func:  A python callable which constructs a tf function body.
246      argnames: A list of strings for function argument names.
247      input_types: The function's argument types. Can be a tuple, list of
248        tf data types.
249      func_name: The function name. Defaults to None, in which derives from
250        'func'.
251      grad_func: This function's gradient function, if not None. Defaults
252        to None.
253      python_grad_func: A python callable implementing the gradient of
254        the function python-side.
255      out_names: An optional list of strings for the function return value
256        names.
257      shape_func: An optional function mapping an op to a list of static
258        output shapes.
259      capture_by_value: Boolean (defaults to False). If True, captured values
260        will be copied into the function body.
261      allowlisted_stateful_ops: A set of ops that if stateful we ignore and
262        copy into the function body, when `capture_by_value` is True.
263      capture_resource_var_by_value: Boolean (defaults to True). If False,
264        captured resource variable returns the handle instead of value.
265      **kwargs: The keyword arguments. **kwargs is passed to every call
266        site of this function.
267
268    Raises:
269      ValueError: The function definition is invalid.
270
271    """
    self._func = func
    self._input_types = input_types
    self._func_name = func_name
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._shape_func = shape_func
    self._capture_by_value = capture_by_value
    self._allowlisted_stateful_ops = allowlisted_stateful_ops
    if self._allowlisted_stateful_ops is None:
      self._allowlisted_stateful_ops = set()
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._extra_kwargs = kwargs
    # Constructed only when C API is disabled, lazily.
    self._definition = None
    # Constructed only when C API is enabled, lazily.
    self._c_func = None
    self._function_deleter = None
    self._sub_functions = {}  # Constructed with _definition or _c_func
    # pylint: disable=protected-access
    device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
    # pylint: enable=protected-access

    # Get the innermost device if possible.
    self._caller_device = device_funcs[-1] if device_funcs else None

    # Cached OpDef for this function. When C API is enabled, this is
    # the only part of FunctionDef that we cache in Python. When C API
    # is disabled the whole _definition is available and this is simply
    # another reference to _definition.signature.
    self._op_def = None

    assert isinstance(input_types, (list, tuple))
    self._arg_types = input_types
    self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i)
                       for i in range(len(input_types))]

  @property
  def name(self):
311    """Function name."""
312    self._create_definition_if_needed()
313    return self._func_name
314
315  @property
316  def definition(self):
317    """Function definition proto."""
318    self._create_definition_if_needed()
319    if self._c_func:
320      with c_api_util.tf_buffer() as buf:
321        c_api.TF_FunctionToFunctionDef(self._c_func.func, buf)
322        fdef = function_pb2.FunctionDef()
323        proto_data = c_api.TF_GetBuffer(buf)
324        fdef.ParseFromString(compat.as_bytes(proto_data))
325        with ops.init_scope():
326          if context.executing_eagerly():
327            context.add_function(self._c_func.func)
328            self._function_deleter = _DefinedFunctionDeleter(
329                fdef.signature.name)
330      return fdef
331    return self._definition
332
333  @property
334  def _signature(self):
335    self._create_definition_if_needed()
336    return self._op_def
337
338  def set_grad_func(self, grad_func):
339    """Specifies the gradient function of this function."""
340    assert not self._grad_func
341    assert isinstance(grad_func, _DefinedFunction)
342    self._grad_func = grad_func
343
344  @property
345  def grad_func_name(self):
346    """Returns the name of the gradient function."""
347    return self._grad_func.name if self._grad_func else None
348
349  @property
350  def python_grad_func(self):
351    """Python gradient function callable."""
352    return self._python_grad_func
353
354  @property
355  def declared_input_types(self):
356    """Returns the list of data types of explicit declared inputs."""
357    return self._input_types
358
359  @property
360  def captured_inputs(self):
361    """Returns the list of implicitly captured inputs."""
362    self._create_definition_if_needed()
363    return self._extra_inputs
364
365  @property
366  def stateful_ops(self):
367    """Returns the list of stateful ops in function definition.
368
369    Returns:
370      A list of (op.name, op.type) pairs.
371    """
372    self._create_definition_if_needed()
373    return self._stateful_ops
374
375  def _create_definition_if_needed(self):
376    """Creates the function definition if it's not created yet."""
377    with context.graph_mode():
378      self._create_definition_if_needed_impl()
379
380  def _create_definition_if_needed_impl(self):
381    """This is not what you want, see _create_definition_if_needed."""
382    if self._definition is not None or self._c_func is not None:
383      return
384
385    # Copy variable collections (by reference) from the parent graph such that
386    # name based variable sharing (e.g. via tf.make_template) works between the
387    # func graph and parent graph.
388    variable_keys = []
389    variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
390    variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access
391
392    parent_graph = ops.get_default_graph()
393    collections_ref = {
394        key: parent_graph.get_collection_ref(key) for key in variable_keys}
395
396    temp_graph = func_graph_from_py_func(
397        self._func,
398        self._arg_names,
399        self._arg_types,
400        self._func_name,
401        self._capture_by_value,
402        self._caller_device,
403        collections_ref=collections_ref,
404        allowlisted_stateful_ops=self._allowlisted_stateful_ops,
405        capture_resource_var_by_value=self._capture_resource_var_by_value)
406
407    self._extra_inputs = temp_graph.extra_inputs
408    # pylint: disable=protected-access
409    self._sub_functions = temp_graph._functions
410    # pylint: enable=protected-access
411
412    # Extra kwargs are treated as attrs on the function def.
413    if self._func_name:
414      base_func_name = self._func_name
415    else:
416      base_func_name = function_utils.get_func_name(self._func)
417      if self._grad_func:
418        base_func_name += ("_%s" % self._grad_func.name)
419    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)
420
421    if not temp_graph._c_graph:  # pylint: disable=protected-access
422      # Build the FunctionDef
423      self._definition = graph_to_function_def.graph_to_function_def(
424          temp_graph,
425          temp_graph.get_operations(),
426          temp_graph.inputs,
427          temp_graph.outputs,
428          out_names=self._out_names)
429
430      for k in kwargs_attr:
431        self._definition.attr[k].CopyFrom(kwargs_attr[k])
432
433      # Hash the definition and its dependencies.
434      self._hash_str = self._create_hash_str(
435          self._definition.signature.input_arg,
436          self._definition.signature.output_arg, self._definition.node_def)
437
438      # Finally, we decide the function name to use.  If not specified,
439      # make up something which is almost certainly unique (but deterministic).
440      if not self._func_name:
441        self._func_name = "_".join([base_func_name, self._hash_str])
442      self._definition.signature.name = self._func_name
443      if self._func.__doc__:
444        self._definition.signature.description = self._func.__doc__
445
446      self._op_def = self._definition.signature
447    else:  # C API is enabled
448      output_names = ([compat.as_bytes(x) for x in self._out_names]
449                      if self._out_names else [])
450      description = self._func.__doc__ or None
451      # pylint: disable=protected-access
452      c_func = c_api.TF_GraphToFunction_wrapper(
453          temp_graph._c_graph,
454          base_func_name,
455          self._func_name is None,  # append_hash_to_fn_name
456          None,  # opers
457          [t._as_tf_output() for t in temp_graph.inputs],
458          [t._as_tf_output() for t in temp_graph.outputs],
459          output_names,
460          [], # control_outputs
461          [], # control_output_names
462          None,  # opts
463          description)
464      self._c_func = c_api_util.ScopedTFFunction(c_func)
465      # pylint: enable=protected-access
466      self._set_c_attrs(kwargs_attr)
467
468      # Set cached fields: _op_def and _func_name (if not already set)
469      self._op_def = self.definition.signature
470      if self._func_name:
471        assert self._func_name == self._op_def.name
472      else:
473        self._func_name = compat.as_str(self._op_def.name)
474
475    self._stateful_ops = [(op.name, op.type)
476                          for op in temp_graph.get_operations()
477                          if op._is_stateful]  # pylint: disable=protected-access
478
479  def _set_c_attrs(self, attrs):
480    """Sets `attrs` as attributes of self._c_func.
481
482    Requires that self._c_func is not None.
483
484    Args:
485      attrs: a dictionary from attribute name to attribute proto value
486    """
487    for name, attr_value in attrs.items():
488      serialized = attr_value.SerializeToString()
489      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
490      # It might be worth creating a convenient way to re-use the same status.
491      c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name),
492                                         serialized)
493
494  def _create_hash_str(self, input_arg, output_arg, node_def):
495    """Creates an 8-character string unique to this input.
496
497    Args:
498      input_arg: the input_arg field of an OpDef
499                 (e.g. self._definition.signature.input_arg)
500      output_arg: the output_arg field of an OpDef
501                 (e.g. self._definition.signature.output_arg)
502      node_def: the node_def field of a FunctionDef
503                (e.g. self._definition.node_def)
504
505    Returns:
506      The unique string for this input
507    """
508    hasher = hashlib.sha1()
509
510    def update_num(n):
511      hasher.update(compat.as_bytes("%x" % n))
512
513    def update_str(s):
514      update_num(len(s))
515      hasher.update(compat.as_bytes(s))
516
517    def update_strs(slist):
518      update_num(len(slist))
519      for s in slist:
520        update_str(s)
521
522    for adef in input_arg:
523      update_str(adef.SerializeToString())
524
525    for adef in output_arg:
526      update_str(adef.SerializeToString())
527
528    for n in sorted(node_def, key=lambda n: n.name):
529      update_str(n.name)
530      update_str(n.op)
531      update_strs(n.input)
532      update_num(len(n.attr))
533      # NOTE: protobuf map serialization does not guarantee ordering.
534      for k in sorted(n.attr):
535        update_str(k)
536        update_str(n.attr[k].SerializeToString())
537
538    return hasher.hexdigest()[:8]
539
540  def add_to_graph(self, g):
541    """Adds this function into the graph g."""
542    self._create_definition_if_needed()
543
544    # Adds this function into 'g'.
545    # pylint: disable=protected-access
546    if context.executing_eagerly():
547      context.context().add_function_def(self.definition)
548    else:
549      g._add_function(self)
550    # pylint: enable=protected-access
551
552    # Ensures related sub-routines are defined in 'g', too.
553    for f in self._sub_functions.values():
554      f.add_to_graph(g)
555
556    # Adds its gradient function, too.
557    if self._grad_func:
558      self._grad_func.add_to_graph(g)
559
560  def __call__(self, *args, **kwargs):
561    self.add_to_graph(ops.get_default_graph())
562    args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
563    ret, op = _call(self._signature, *args, **kwargs)
564
565    # Set a hidden attr in 'op' so that gradients_impl can refer back
566    # to this _DefinedFunction instance to access python_grad_func.
567    assert isinstance(op, ops.Operation)
568    setattr(op, "__defun", self)
569
570    if self._shape_func is not None:
571      shapes = self._shape_func(op)
572      if len(shapes) != len(op.outputs):
573        raise ValueError("shape_func produced %d shapes for %d outputs" %
574                         (len(shapes), len(op.outputs)))
575      for (t, shape) in zip(op.outputs, shapes):
576        t.set_shape(shape)
577    return ret
578
579
580class _OverloadedFunction(object):
581  """_OverloadedFunction encapsulates an overloaded function.
582
583  _OverloadedFunction maintains a mapping from input types to
584  instantiated _DefinedFunction in self._overload.
585
586  """
587
588  def __init__(self,
589               func,
590               argnames,
591               func_name=None,
592               grad_func=None,
593               python_grad_func=None,
594               out_names=None,
595               **kwargs):
596    """Creates _DefinedFunction.
597
598    Args:
599      func:  A python callable which constructs a tf function body.
600      argnames: A list of strings for function argument names.
601      func_name: The function name. Defaults to None, in which derives from
602        'func'.
603      grad_func: This function's gradient function, if not None. Defaults
604        to None.
605      python_grad_func: A python callable implementing the gradient of
606        the function python-side.
607      out_names: A list of strings for the function return value names.
608      **kwargs: The keyword arguments. **kwargs is passed to every call
609        site of this function.
610
611    Raises:
612      ValueError: The function definition is invalid.
613
614    """
615    self._func = func
616    self._argnames = argnames
617    self._func_name = func_name
618    assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
619    self._grad_func = grad_func
620    self._python_grad_func = python_grad_func
621    self._out_names = out_names
622    self._extra_kwargs = kwargs
623    self._overload = {}
624
625  def instantiate(self, input_types):
626    """Instantiate this function given input argument types.
627
628    Args:
629      input_types: A list of data types for the inputs.
630
631    Returns:
632      _DefinedFunction for the given input types.
633
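
    For example, a sketch (assuming `Foo` was declared via `Defun()` with no
    input types, which makes it overloaded):

        @Defun()
        def Foo(x, y):
          return x + y

        f32 = Foo.instantiate([dtypes.float32, dtypes.float32])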
634    """
635    # Stringify the type list.
636    key = _type_list_to_str(input_types)
637    defined = self._overload.get(key)
638    if not defined:
639      # If not defined yet, define the function given the input types.
640      name = self._func_name
641      if name is not None:
642        name = "_".join([name, key])
643      defined = _DefinedFunction(
644          self._func,
645          self._argnames,
646          input_types,
647          name,
648          None,
649          self._python_grad_func,
650          out_names=self._out_names,
651          **self._extra_kwargs)
652      _ = defined.name  # Fully instantiate the function definition.
653      if self._grad_func:
654        # If _grad_func is given, it is another
655        # _OverloadedFunction. We need to instantiate it with the
656        # right input types.
657        output_types = [
658            dtypes.DType(_.type) for _ in defined._signature.output_arg  # pylint: disable=protected-access
659        ]
660        # pylint: disable=protected-access
661        defined._grad_func = self._grad_func.instantiate(input_types +
662                                                         output_types)
663        # pylint: enable=protected-access
664      self._overload[key] = defined
665    return defined
666
667  def __call__(self, *args, **kwargs):
668    input_types = []
669    args = list(args)
670    for (i, x) in enumerate(args):
671      x = ops.convert_to_tensor(x)
672      if not isinstance(x, ops.Tensor):
673        raise ValueError("Expect a Tensor but get ", x)
674      input_types.append(x.dtype)
675      args[i] = x
676    return self.instantiate(input_types)(*args, **kwargs)
677
678
679class _FuncGraph(ops.Graph):
680  """A helper for constructing a function.
681
682  _FuncGraph overrides ops.Graph's create_op() so that we can keep
683  track of all inputs into every op created inside the function.  If
684  any input is from other graphs, we keep track of it in self.capture
685  and substitute the input with a place holder.
686
687  Each captured input's corresponding place holder is converted into a
688  function argument and the caller passes in the captured tensor.
689  """
690
691  def __init__(self, name, capture_by_value, allowlisted_stateful_ops,
692               capture_resource_var_by_value, *args, **kwargs):
693    super(_FuncGraph, self).__init__(*args, **kwargs)
694    self._capture_by_value = capture_by_value
695    self._allowlisted_stateful_ops = allowlisted_stateful_ops
696    self._capture_resource_var_by_value = capture_resource_var_by_value
697    self._building_function = True
698    self._outer_graph = ops.get_default_graph()
699    self._vscope = vs.get_variable_scope()
700    self._old_custom_getter = self._vscope.custom_getter
701
702    # The name of the function.
703    self.name = name
704    # Placeholder tensors representing the inputs to this function. The tensors
705    # are in this _FuncGraph.
706    self.inputs = []
707    # Tensors that will be returned this function. The tensors are in this
708    # _FuncGraph.
709    self.outputs = []
710    # Maps external tensor -> internal tensor (e.g. input placeholder).
711    self._captured = {}
712    # The external tensors that have been captured as inputs and must be passed
713    # to this function (empty if capturing by value, otherwise these are the
714    # keys of _captured).
715    self.extra_inputs = []
716    # Input placeholders that been added for captured values (empty if capturing
717    # by value).
718    self.extra_args = []
719    # Captured variables.
720    # TODO(skyewm): is this needed?
721    self.extra_vars = []
722
723  # pylint: disable=g-doc-return-or-yield
724
725  @property
  def outer_graph(self):
    """The graph active when this _FuncGraph was created."""
    return self._outer_graph

  @tf_contextlib.contextmanager
  def container(self, container_name):
    """Returns a context manager that specifies the resource container to use.

    Overridden from `tf.Graph` to update both the init_scope container
    and the present inner container. This is necessary to make sure setting
    containers applies correctly both to created variables and to stateful
    ops.

    Args:
      container_name: container name string.

    Returns:
      A context manager for defining resource containers for stateful ops,
        yields the container name.
    """
    original_container = self._container
    # pylint: disable=protected-access
    with ops.init_scope():
      original_init_container = ops.get_default_graph()._container
    try:
      self._container = container_name
      with ops.init_scope():
        ops.get_default_graph()._container = container_name
      yield self._container
    finally:
      self._container = original_container
      with ops.init_scope():
        ops.get_default_graph()._container = original_init_container
    # pylint: enable=protected-access

  # pylint: enable=g-doc-return-or-yield

  def getvar(
      self,
      getter,
      name,
      shape=None,
      dtype=None,
      initializer=None,
      reuse=None,
      trainable=True,
      collections=None,  # pylint: disable=redefined-outer-name
      use_resource=None,
      **kwargs):
    """A custom variable getter."""
    # Here, we switch the default graph to the outer graph and ask the
    # variable scope in which the function is defined to give us the
    # variable. The variable is stashed in extra_vars and returned to
    # the caller.
    #
    # We capture these variables so that the variable definition is
    # hoisted upward to the outermost graph.
    with self._outer_graph.as_default():
      # pylint: disable=protected-access
      var = self._vscope.get_variable(
          vs._get_default_variable_store(),
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          use_resource=use_resource)
      self.extra_vars.append(var)
      if (isinstance(var, resource_variable_ops.BaseResourceVariable) and
          self._capture_resource_var_by_value):
        # For resource-based variables read the variable outside the function
        # and pass in the value. This ensures that the function is pure and
        # differentiable. TODO(apassos) this may have performance problems if
        # the function will only do embedding lookups on the variable.
        return var.value()
      return var

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    for i, x in enumerate(inputs):
      if isinstance(x, ops.EagerTensor) or x.graph is not self:
        inputs[i] = self.capture(x)
    return super(_FuncGraph, self)._create_op_internal(
        op_type,
        inputs,
        dtypes=dtypes,
        input_types=input_types,
        name=name,
        attrs=attrs,
        op_def=op_def,
        compute_device=compute_device)

  def capture(self, tensor, name=None):
    """Adds the given tensor to this graph and returns the captured tensor."""
    if tensor.ref() in self._captured:
      # Captured already.
      return self._captured[tensor.ref()]
    elif self._capture_by_value:
      return self._add_tensor_and_parents(tensor)
    else:
      return self._capture_tensor_as_extra_input(tensor, name)

  @property
  def captures(self):
    """Pairs of external tensors and their captured counterparts."""
    return [(k.deref(), v) for k, v in self._captured.items()]

  def _capture_tensor_as_extra_input(self, tensor, name=None):
    # Substitute with a placeholder.
    self.extra_inputs.append(tensor)
    # Hoist the new input placeholder out of any control flow context
    # we're currently in.
    with ops.control_dependencies(None):
      ph = array_ops.placeholder(
          tensor.dtype, shape=tensor.get_shape(), name=name)
    # pylint: disable=protected-access
    if isinstance(tensor, ops.EagerTensor):
      handle_data = tensor._handle_data
      if handle_data:
        handle_data = handle_data.SerializeToString()
    else:
      handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph,
                                                tensor._as_tf_output())

    if handle_data:
      c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(),
                                  compat.as_bytes(handle_data))
    # pylint: enable=protected-access
    self.inputs.append(ph)
    self._captured[tensor.ref()] = ph
    self.extra_args.append(ph)
    if _is_guaranteed_const(tensor):
      with ops.control_dependencies(None):
        return array_ops.guarantee_const(ph)
    else:
      return ph

  def _add_tensor_and_parents(self, tensor):
    op = self._add_op_and_parents(tensor.op)
    return op.outputs[tensor.value_index]

  def _add_op_and_parents(self, op):
    # pylint: disable=protected-access
    op_def = graph_to_function_def._get_op_def(op)
    if op._is_stateful and op not in self._allowlisted_stateful_ops:
      raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
                       "by value." % (op.name, op.type))
    elif op.type in ("Placeholder", "PlaceholderV2"):
      raise ValueError("Cannot capture a placeholder (name:%s, type:%s) "
                       "by value." % (op.name, op.type))
    # pylint: enable=protected-access

    captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]

    captured_op = self._create_op_internal(
        op.type,
        captured_inputs, [o.dtype for o in op.outputs],
        name=op.name,
        attrs=op.node_def.attr,
        op_def=op_def)

    for t, captured_t in zip(op.outputs, captured_op.outputs):
      self._captured[t.ref()] = captured_t

    return captured_op


def func_graph_from_py_func(func,
                            arg_names,
                            arg_types,
                            name=None,
                            capture_by_value=False,
                            device=None,
                            colocation_stack=None,
                            container=None,
                            collections_ref=None,
                            arg_shapes=None,
                            allowlisted_stateful_ops=None,
                            capture_resource_var_by_value=True):
  """Returns a _FuncGraph generated from `func`.

  Args:
    func: A Python callable which constructs a TF function body. The arguments
      must correspond to `arg_types`. Returns a value or list/tuple of values.
      No returned value can be None.
    arg_names: A sequence of strings for the function argument names.
    arg_types: A sequence of the function's argument types.
    name: The function name. If None, the name is derived from `func`.
    capture_by_value: boolean. If True, captured values will be copied into the
      function body.
    device: device name or function.
    colocation_stack: A colocation stack (list) the _FuncGraph should use.
    container: A container name the _FuncGraph should start with.
    collections_ref: A reference to a collections dict the _FuncGraph should
      use internally.
    arg_shapes: A sequence of the function's argument shapes.
    allowlisted_stateful_ops: A set of ops that are allowed to be captured and
      re-created even though they are stateful.
    capture_resource_var_by_value: Boolean (defaults to True). If False, a
      captured resource variable returns its handle instead of its value.

  Returns:
    A _FuncGraph.

  Raises:
    ValueError: if func returns None.
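
  For example, a minimal sketch (`Body` is an illustrative callable):

      def Body(x):
        return x * 2.0

      fg = func_graph_from_py_func(Body, ["x"], [dtypes.float32])
      # fg.inputs holds the placeholder created for `x`; fg.outputs holds
      # the tensor computing x * 2.0.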
942  """
943  if not name:
944    name = function_utils.get_func_name(func)
945  func_graph = _FuncGraph(name, capture_by_value, allowlisted_stateful_ops,
946                          capture_resource_var_by_value)
947
948  with func_graph.as_default(), ops.device(device):
949    # pylint: disable=protected-access
950    if collections_ref is not None:
951      func_graph._collections = collections_ref
952    if container is not None:
953      func_graph._container = container
954    if colocation_stack is not None:
955      func_graph._colocation_stack = colocation_stack
956    # pylint: enable=protected-access
957
958    if arg_shapes is None:
959      arg_shapes = [None] * len(arg_types)
960
961    # Create placeholders for the function arguments.
962    for (argname, argtype, argshape) in zip(arg_names, arg_types, arg_shapes):
963      argholder = array_ops.placeholder(argtype, shape=argshape, name=argname)
964      func_graph.inputs.append(argholder)
965    # Call func and gather the output tensors.
966    with vs.variable_scope("", custom_getter=func_graph.getvar):
967      outputs = func(*func_graph.inputs)
968
969    # There is no way of distinguishing between a function not returning
970    # anything and a function returning None in Python.
971    # We need to allow the former and ideally want to forbid the latter as
972    # it is most likely user error.
973    # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
974    # allow users to explicitly mark the function as not returning anything.
975    # For now, we allow a single None return and interpret it as a function
976    # with no output.
977    if outputs is None:
978      outputs = []
979    else:
980      # If func only returned one value, make it a tuple.
981      if not isinstance(outputs, (list, tuple)):
982        outputs = (outputs,)
983      if any(_ is None for _ in outputs):
984        raise ValueError("Function %s can not return None." % name)
985    # Ensures each output is a Tensor in the function graph.
986    outputs = [ops.convert_to_tensor(t) for t in outputs]
987    outputs = [func_graph.capture(t) if t.graph is not func_graph else t
988               for t in outputs]
989    func_graph.outputs = outputs
990  return func_graph
991
992
993def _is_guaranteed_const(tensor):
994  """Determines whether `tensor` is guaranteed to be a constant.
995
996  A tensor is guaranteed to be a constant if either it was produced by
997  a `GuaranteeConst` op or if all of its children are guaranteed to be
998  constants.
999
1000  Args:
1001    tensor: The tensor for which to determine const-ness.
1002
1003  Returns:
1004    True if `tensor` is guaranteed to be a constant, False otherwise.
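
  For example, a sketch (assuming `constant_op` is imported; a plain Const
  op carries no guarantee by itself):

      c = constant_op.constant(1.0)
      g = array_ops.guarantee_const(c)
      _is_guaranteed_const(g)  # True: produced by a GuaranteeConst op.
      _is_guaranteed_const(c)  # False: no GuaranteeConst in its ancestry.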
1005  """
1006
1007  if isinstance(tensor, ops.EagerTensor):
1008    return False
1009
1010  class Work(object):
1011
1012    def __init__(self, op, leaving):
1013      self.op = op
1014      self.leaving = leaving
1015
1016  is_guaranteed_const = lambda op: op.node_def.op == "GuaranteeConst"
1017  constants = set([])
1018  def all_inputs_const(op):
1019    # If all inputs of an op are guaranteed constants, then we can infer that
1020    # the op produces a constant as well.
1021    return op.inputs and all(inp.op in constants for inp in op.inputs)
1022
1023  visited = set([])
1024  stack = [Work(tensor.op, leaving=False)]
1025  while stack:
1026    work = stack.pop()
1027    if work.leaving:
1028      if all_inputs_const(work.op):
1029        constants.add(work.op)
1030      continue
1031    visited.add(work.op)
1032    if is_guaranteed_const(work.op):
1033      constants.add(work.op)
1034      continue
1035
1036    # This op will be revisited after all its inputs are checked for const-ness.
1037    stack.append(Work(work.op, leaving=True))
1038    for inp in work.op.inputs:
1039      if inp.op not in visited:
1040        stack.append(Work(inp.op, leaving=False))
1041  return tensor.op in constants
1042
1043
1044def _call(sig, *inputs, **kwargs):
1045  """Adds a node calling a function.
1046
1047  This adds a `call` op to the default graph that calls the function
1048  of signature `sig`, passing the tensors in `inputs` as arguments.
1049  It returns the outputs of the call, which are one or more tensors.
1050
1051  `sig` is OpDefArg.a `_DefinedFunction` object.
1052
1053  You can pass an optional keyword parameter `name=string` to name the
1054  added operation.
1055
1056  You can pass an optional keyword parameter `noinline=True|False` to
1057  instruct the runtime not to inline the function body into the call
1058  site.
1059
1060  Args:
1061    sig: OpDefArg. The signature of the function.
1062    *inputs: arguments to the function.
1063    **kwargs: Optional keyword arguments.  Can only contain 'name' or
1064        'noinline'.
1065
1066  Returns:
1067     A 2-element tuple. First element: a Tensor if the function returns a single
1068     value; a list of Tensors if the function returns multiple value; the
1069     Operation if the function returns no values. Second element: the Operation.
1070
1071  Raises:
1072    ValueError: if the arguments are invalid.
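
  For example, an internal-use sketch (assuming `f` is a `_DefinedFunction`
  already added to the default graph, and `x`, `y` are tensors):

      ret, op = _call(f._signature, x, y, name="mycall", noinline=True)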
1073  """
1074  if len(inputs) != len(sig.input_arg):
1075    raise ValueError("Expected number of arguments: %d, received: %d" % (len(
1076        sig.input_arg), len(inputs)))
1077  name = kwargs.pop("name", None)
1078  g = ops.get_default_graph()
1079  func_name = sig.name
1080  if name is None:
1081    name = func_name
1082  attrs = _parse_kwargs_as_attrs(func_name, **kwargs)
1083  output_types = [dtypes.DType(x.type) for x in sig.output_arg]
1084  op = g._create_op_internal(  # pylint: disable=protected-access
1085      func_name, list(inputs), output_types, name=name, attrs=attrs, op_def=sig)
1086  if op.outputs:
1087    if len(op.outputs) == 1:
1088      ret = op.outputs[0]
1089    else:
1090      ret = tuple(op.outputs)
1091  else:
1092    ret = op
1093  return ret, op
1094
1095
1096def _from_definition(fdef, grad_func=None):
1097  """Creates a _DefinedFunction initialized from a FunctionDef proto.
1098
1099  Args:
1100    fdef: a FunctionDef
1101    grad_func: a _DefinedFunction or None
1102
1103  Returns:
1104    A _DefinedFunction representing fdef
1105  """
1106  # TODO(iga): This method does major surgery on _DefinedFunction.
1107  # Make it a named constructor using @classmethod of _DefinedFunction.
1108
1109  # The Python callable is only needed to create a FunctionDef. Since we have
1110  # the FunctionDef here, we don't need to set _DefinedFunction._func (nor do we
1111  # have access to such a callable here).
1112  func = None
1113  argnames = [arg.name for arg in fdef.signature.input_arg]
1114  input_types = tuple(
1115      dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
1116  func_name = fdef.signature.name
1117  # Note: FunctionDefs do not include python gradient functions, so if the
1118  # original _DefinedFunction included one it will not be reflected here.
1119  python_grad_func = None
1120  out_names = [arg.name for arg in fdef.signature.output_arg]
1121  result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
1122                            python_grad_func, out_names)
1123  # pylint: disable=protected-access
1124  serialized = fdef.SerializeToString()
1125  c_func = c_api.TF_FunctionImportFunctionDef(serialized)
1126  result._c_func = c_api_util.ScopedTFFunction(c_func)
1127  result._extra_inputs = []
1128  result._op_def = fdef.signature
1129  # pylint: enable=protected-access
1130
1131  return result
1132
1133
1134def from_library(lib):
1135  """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.
1136
1137  This method handles assigning the correct gradient functions to each
1138  function.
1139
1140  Args:
1141    lib: a FunctionDefLibrary
1142
1143  Returns:
1144    A list of _DefinedFunctions
1145
1146  Raises:
1147    ValueError: `lib` is invalid
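
  For example, a round-trip sketch (assuming `graph_def` is a GraphDef whose
  `library` field carries function definitions):

      for f in from_library(graph_def.library):
        f.add_to_graph(ops.get_default_graph())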
1148  """
1149  if not lib.function and not lib.gradient:
1150    return []
1151
1152  # function name -> FunctionDef proto
1153  funcs = {fdef.signature.name: fdef for fdef in lib.function}
1154
1155  # Validate that all references function names have function defs
1156  for g in lib.gradient:
1157    if g.function_name not in funcs:
1158      raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" %
1159                       (g.function_name, str(lib)))
1160    if g.gradient_func not in funcs:
1161      raise ValueError("FunctionDefLibrary missing '%s' FunctionDef\n%s" %
1162                       (g.gradient_func, str(lib)))
1163
1164  # function name -> gradient function name
1165  func_to_grad = collections.defaultdict(lambda: None)
1166  # gradient function name -> names of functions having that grad function
1167  grad_to_funcs = collections.defaultdict(list)
1168
1169  for gdef in lib.gradient:
1170    func_to_grad[gdef.function_name] = gdef.gradient_func
1171    grad_to_funcs[gdef.gradient_func].append(gdef.function_name)
1172
1173  # Start with functions without gradients
1174  ready = [
1175      fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
1176  ]
1177  if not ready:
1178    raise ValueError(
1179        "FunctionDefLibrary contains cyclic gradient functions!\n" + str(lib))
1180  # function name -> _DefinedFunction
1181  initialized = {}
1182
1183  while ready:
1184    fdef = ready.pop()
1185    name = fdef.signature.name
1186
1187    grad = initialized.get(func_to_grad[name])
1188    if func_to_grad[name]:
1189      assert grad
1190    defined_func = _from_definition(fdef, grad_func=grad)
1191    initialized[name] = defined_func
1192
1193    ready.extend(funcs[f] for f in grad_to_funcs[name])
1194
1195  return initialized.values()
1196
1197
1198def _get_experimental_kwarg_as_attr(attr_name, value):
1199  """Creates an AttrValue for a python object."""
1200  if isinstance(value, bool):
1201    return attr_value_pb2.AttrValue(b=value)
1202  elif isinstance(value, int):
1203    return attr_value_pb2.AttrValue(i=value)
1204  elif isinstance(value, float):
1205    return attr_value_pb2.AttrValue(f=value)
1206  elif isinstance(value, str):
1207    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
1208  else:
1209    raise ValueError("Unsupported attribute type for %s with type %s" %
1210                     (attr_name, type(value)))
1211
1212
1213def _get_kwarg_as_str_attr(attr_name, value):
1214  """Creates an AttrValue for a python object."""
1215  if isinstance(value, str):
1216    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
1217  else:
1218    raise ValueError("Unsupported attribute type for %s with type %s" %
1219                     (attr_name, type(value)))
1220
1221
1222def _parse_kwargs_as_attrs(func_name, **kwargs):
1223  """Parses **kwargs into a node's attributes."""
1224  attrs = {}
1225
1226  noinline = kwargs.pop("noinline", None)
1227  if noinline is not None:
1228    attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))
1229
1230  # For compatibility with previous behavior, Defun does not perform shape
1231  # inference through its function call operations.
1232  attrs["_disable_call_shape_inference"] = attr_value_pb2.AttrValue(b=True)
1233
1234  compiled = kwargs.pop("compiled", None)
1235  separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None)
1236  if compiled is not None:
1237    attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled))
1238    attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue(
1239        b=bool(separate_compiled_gradients))
1240    # Forward _XlaScope from enclosing context (if set), otherwise create new.
1241    # pylint: disable=protected-access
1242    if "_XlaScope" in ops.get_default_graph()._attr_scope_map:
1243      attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"]
1244    else:
1245      attrs["_XlaScope"] = attr_value_pb2.AttrValue(
1246          s=("function_%s" % func_name).encode())
1247    # pylint: enable=protected-access
1248
1249  kwargs_keys = list(kwargs.keys())
1250  for key in kwargs_keys:
1251    if key.startswith("experimental_"):
1252      attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key])
1253      del kwargs[key]
1254    # Support for https://github.com/tensorflow/community/pull/113/files.
1255    elif key == "_implements" or key == "_reference":
1256      attrs[key] = _get_kwarg_as_str_attr(key, kwargs[key])
1257      del kwargs[key]
1258  if kwargs:
1259    raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
1260  return attrs
1261
1262
1263def get_extra_vars():
1264  """Returns the captured variables by the function.
1265
1266  Returns:
1267    If the default graph is being used to define a function, the
1268    returned list of variables are those created inside the function
1269    body so far. Otherwise, returns an empty list.
1270  """
1271  g = ops.get_default_graph()
1272  if isinstance(g, _FuncGraph):
1273    return g.extra_vars
1274  else:
1275    return []
1276
1277
1278def get_extra_inputs():
1279  """Returns the captured input tensors by the function.
1280
1281  Returns:
1282    If the default graph is being used to define a function, the
1283    returned list of tensors are those accessed inside the function body
1284    but defined outside the function body so far. Otherwise, returns an
1285    empty list.
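
  For example, a sketch (`outer_tensor` is an illustrative tensor defined
  outside the function body):

      @Defun(dtypes.float32)
      def Foo(x):
        y = x + outer_tensor  # `outer_tensor` is captured at this point.
        assert get_extra_inputs() == [outer_tensor]
        return y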
1286  """
1287  g = ops.get_default_graph()
1288  if isinstance(g, _FuncGraph):
1289    return g.extra_inputs
1290  else:
1291    return []
1292
1293
1294def get_extra_args():
1295  """Returns the corresponding function arguments for the captured inputs.
1296
1297  Returns:
1298    If the default graph is being used to define a function, the
1299    returned list of place holders are those used inside the function
1300    body corresponding those returned by get_extra_inputs(). Otherwise,
1301    returns an empty list.
1302  """
1303  g = ops.get_default_graph()
1304  if isinstance(g, _FuncGraph):
1305    return g.extra_args
1306  else:
1307    return []
1308
1309
1310def _type_list_to_str(types):
  if any(_ not in _DTYPE_TO_STR for _ in types):
    raise ValueError("Unsupported dtypes: %s" % types)
  return "".join(_DTYPE_TO_STR[_] for _ in types)


# NOTE: The list needs to be extended when more data types are added.
_DTYPE_TO_STR = {
    dtypes.float16: "f16",
    dtypes.float32: "f32",
    dtypes.float64: "f64",
    dtypes.int32: "i32",
    dtypes.uint8: "i8",
    dtypes.uint16: "u16",
    dtypes.uint32: "u32",
    dtypes.uint64: "u64",
    dtypes.int16: "i16",
    dtypes.int8: "i8",
    dtypes.string: "s",
    dtypes.complex64: "c64",
    dtypes.complex128: "c128",
    dtypes.int64: "i64",
    dtypes.bool: "b",
    dtypes.qint8: "qi8",
    dtypes.quint8: "qu8",
    dtypes.qint16: "qi16",
    dtypes.quint16: "qu16",
    dtypes.qint32: "qi32",
    dtypes.bfloat16: "b16"
}


def function_def_from_tf_function(c_func):
1343  """Converts a SWIG-wrapped TF_Function* to a FunctionDef proto."""
1344  with c_api_util.tf_buffer() as buf:
1345    c_api.TF_FunctionToFunctionDef(c_func, buf)
1346    data = c_api.TF_GetBuffer(buf)
1347  fdef = function_pb2.FunctionDef()
1348  fdef.ParseFromString(compat.as_bytes(data))
1349  return fdef
1350