1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unidiomatic-typecheck
16"""Defun decorator for defining graph-mode functions."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import functools
24import itertools
25import pprint
26import threading
27import types as types_lib
28import weakref
29
30import numpy as np
31import six
32from six.moves import map
33
34from tensorflow.core.framework import attr_value_pb2
35from tensorflow.core.framework import function_pb2
36from tensorflow.python import pywrap_tfe
37from tensorflow.python.client import pywrap_tf_session
38from tensorflow.python.eager import backprop
39from tensorflow.python.eager import backprop_util
40from tensorflow.python.eager import context
41from tensorflow.python.eager import execute
42from tensorflow.python.eager import forwardprop_util
43from tensorflow.python.eager import monitoring
44from tensorflow.python.eager import tape
45from tensorflow.python.eager.graph_only_ops import graph_placeholder
46from tensorflow.python.framework import c_api_util
47from tensorflow.python.framework import composite_tensor
48from tensorflow.python.framework import constant_op
49from tensorflow.python.framework import device as pydev
50from tensorflow.python.framework import dtypes
51from tensorflow.python.framework import error_interpolation
52from tensorflow.python.framework import errors
53from tensorflow.python.framework import func_graph as func_graph_module
54from tensorflow.python.framework import ops
55from tensorflow.python.framework import tensor_shape
56from tensorflow.python.framework import tensor_spec
57from tensorflow.python.framework import type_spec
58from tensorflow.python.ops import array_ops
59from tensorflow.python.ops import control_flow_ops
60from tensorflow.python.ops import custom_gradient
61from tensorflow.python.ops import default_gradient
62from tensorflow.python.ops import functional_ops
63from tensorflow.python.ops import gradients_util
64from tensorflow.python.ops import resource_variable_ops
65
66from tensorflow.python.platform import tf_logging as logging
67from tensorflow.python.profiler import trace
68from tensorflow.python.saved_model import save_context
69from tensorflow.python.util import _pywrap_utils
70from tensorflow.python.util import compat
71from tensorflow.python.util import function_utils
72from tensorflow.python.util import lazy_loader
73from tensorflow.python.util import memory
74from tensorflow.python.util import nest
75from tensorflow.python.util import object_identity
76from tensorflow.python.util import tf_decorator
77from tensorflow.python.util import tf_inspect
78from tensorflow.python.util.tf_export import tf_export
79
80# Loaded lazily due to a circular dependency (roughly
81# tf.function->autograph->->dataset->tf.function).
82# TODO(b/133251390): Use a regular import.
83ag_ctx = lazy_loader.LazyLoader(
84    "ag_ctx", globals(),
85    "tensorflow.python.autograph.core.ag_ctx")
86np_arrays = lazy_loader.LazyLoader(
87    "np_arrays", globals(),
88    "tensorflow.python.ops.numpy_ops.np_arrays")
89
90
91FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
92BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
93IMPLEMENTS_ATTRIBUTE_NAME = "_implements"
94SHARED_RENDEZVOUS_ATTRIBUTE_NAME = "shared_rendezvous"
95
96_graph_building_time_counter = monitoring.Counter(
97    "/tensorflow/core/tf_function/graph_building_time_usecs",
98    "Time for tf.function to build a graph (us).")
99
100
101def _make_input_signature_hashable(elem):
102  """Rewrite input signature to be hashable.
103
104  We replace nested variables in the input signature with TensorSpec in order to
105  be hashable.
106
107  Args:
108    elem: Input signature element
109
110  Returns:
111    A hashable object for the requested input signature
112  """
113  try:
114    hash(elem)
115  except TypeError:
116    # TODO(slebedev): consider using nest.
117    if isinstance(elem, tuple):
118      return tuple(map(_make_input_signature_hashable, elem))
119
120    # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
121    # all recognized types to be hashable.
122    assert isinstance(elem, weakref.ReferenceType)
123    v = elem()
124
125    if resource_variable_ops.is_resource_variable(v):
126      # We special case variables here to use unique_id as the cache key. This
127      # ensures we have to retrace whenever a different variable is passed in.
128      # This is needed to support cases where the user may use the id of a
129      # variable in the function perhaps as a lookup in a dictionary.
130      #
131      # This choice leads to more retracing when we could have possibly used the
132      # shape and dtype instead. However, we expect the number of variables in a
133      # program to be bounded, and correspondingly the number of retraces.
134      #
135      # Note we also include the class name to avoid collisions with strings.
136      return v.__class__, v._unique_id  # pylint: disable=protected-access
137
138    if _is_ndarray(v):
139      # Numpy arrays are not hashable, but when calling functions we treat them
140      # in the same way as tf.Tensors.
141      if not hasattr(v, "shape") or not hasattr(v, "dtype"):
142        # TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
143        v = _as_ndarray(v)
144      return tensor_spec.TensorSpec(v.shape, v.dtype)
145
146    raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
147                     "or hashable Python objects (or nested structures of "
148                     "these types).\nGot type: %s" % type(v).__name__)
149
150  return elem
151
152
153CacheKey = collections.namedtuple("CacheKey", [
154    "input_signature",
155    "parent_graph",
156    "device_functions",
157    "colocation_stack",
158    "in_cross_replica_context",
159    "variable_policy",
160    "xla_context_id",
161])
162
163
164def _type_spec_for(x):
165  """Returns a TypeSpec for `x`, or `None` if `x` doesn't have a TensorSpec."""
166  if isinstance(x, ops.Tensor):
167    return tensor_spec.TensorSpec.from_tensor(x)
168  elif isinstance(x, type_spec.TypeSpec):
169    return x
170  elif isinstance(x, composite_tensor.CompositeTensor):
171    return x._type_spec  # pylint: disable=protected-access
172  else:
173    return None
174
175
176def _is_type_subset(a, b):
177  """Returns true if TypeSpec `b` is a subset of type `a` (or if a is None.)"""
178  if a is None:
179    return True
180  else:
181    return a.most_specific_compatible_type(b) == a
182
183
184def _shape_relaxed_type_for_composite_tensor(x):
185  """Returns a shape-relaxed TypeSpec for x (if composite) or x (if not)."""
186  if isinstance(x, composite_tensor.CompositeTensor):
187    # pylint: disable=protected-access
188    return x._type_spec._with_tensor_ranks_only()
189  else:
190    return x
191
192
193def common_shape(x, y):
194  """Find a `TensorShape` that is compatible with both `x` and `y`."""
195  if x is None != y is None:
196    raise RuntimeError(
197        "Cannot find a common shape when LHS shape is None but RHS shape "
198        "is not (or vice versa): %s vs. %s" % (x, y))
199  if x is None:
200    return None  # The associated input was not a Tensor, no shape generated.
201  if not isinstance(x, tensor_shape.TensorShape):
202    raise TypeError("Expected x to be a TensorShape but saw %s" % (x,))
203  if not isinstance(y, tensor_shape.TensorShape):
204    raise TypeError("Expected y to be a TensorShape but saw %s" % (y,))
205  if x.rank != y.rank or x.rank is None:
206    return tensor_shape.TensorShape(None)
207  dims = []
208  for dim_x, dim_y in zip(x.dims, y.dims):
209    if (dim_x != dim_y
210        or tensor_shape.dimension_value(dim_x) is None
211        or tensor_shape.dimension_value(dim_y) is None):
212      dims.append(None)
213    else:
214      dims.append(tensor_shape.dimension_value(dim_x))
215  return tensor_shape.TensorShape(dims)
216
217
218def is_same_structure(structure1,
219                      structure2,
220                      check_values=False):
221  """Check two structures for equality, optionally of types and of values."""
222  try:
223    nest.assert_same_structure(structure1, structure2, expand_composites=True)
224  except (ValueError, TypeError):
225    return False
226  if check_values:
227    flattened1 = nest.flatten(structure1, expand_composites=True)
228    flattened2 = nest.flatten(structure2, expand_composites=True)
229    # First check the types to avoid AttributeErrors.
230    if any(type(f1) != type(f2) for f1, f2 in zip(flattened1, flattened2)):
231      return False
232    return flattened1 == flattened2
233  return True
234
235
236def _parse_func_attrs(attributes):
237  """Convert the keyword arguments into function_def attributes.
238
239  Currently only support primitive types: bool, int, float and string.
240
241  Args:
242    attributes: the dictionary of attributes.
243  Returns:
244    A dict of attributes where the key is the name of attribute and the value
245      is the AttrValue proto.
246  Raises:
247    ValueError: If the kwargs contains unallowlisted name or unsupported value
248      types.
249  """
250  attrs = {}
251  for key, value in attributes.items():
252    if isinstance(value, attr_value_pb2.AttrValue):
253      attrs[key] = value
254    # bool type check has to happen before int since bool is a subclass of int.
255    elif isinstance(value, bool):
256      attrs[key] = attr_value_pb2.AttrValue(b=value)
257    elif isinstance(value, int):
258      attrs[key] = attr_value_pb2.AttrValue(i=value)
259    elif isinstance(value, float):
260      attrs[key] = attr_value_pb2.AttrValue(f=value)
261    elif isinstance(value, (str, bytes, six.text_type)):
262      attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
263    else:
264      raise ValueError("Unsupported attribute type for %s with type %s" %
265                       (key, type(value)))
266  return attrs
267
268
269class _InterpolateFunctionError(object):
270  """Context Manager that interpolates the exception from 'top_level_func'."""
271
272  __slots__ = ["_func"]
273
274  def __init__(self, top_level_func):
275    self._func = top_level_func
276
277  def __enter__(self):
278    pass
279
280  def __exit__(self, typ, exc, tb):
281    if not exc or not isinstance(exc, errors.OpError):
282      return False
283    message = compat.as_text(exc.message)
284    _, tags = error_interpolation.parse_message(message)
285    g = None
286    func_stack = []
287    for t in tags:
288      if t.type == "function_node":
289        # TODO(mdan): Tests should cover this.
290        if t.name == compat.as_str(self._func.name):
291          g = self._func.graph
292        elif g:
293          next_func = g._get_function(t.name)  # pylint: disable=protected-access
294          if next_func is not None and isinstance(next_func,
295                                                  _EagerDefinedFunction):
296            g = next_func.graph
297        if g:
298          func_stack.append(g.name)
299        else:
300          func_stack.append("<unknown>")
301    if g:
302      message = error_interpolation.interpolate(message, g)
303      message += "\n\nFunction call stack:\n"
304      message += " -> ".join(func_stack)
305      message += "\n"
306      exc._message = message  # pylint: disable=protected-access
307    return False
308
309
310_function_callbacks = set()
311
312
313def add_function_callback(function_callback):
314  """Add a callback function for Function creation.
315
316  The callback function has the signature:
317
318    `def function_callback(function, name, graph, inputs, outputs):`
319
320  where:
321  - `function`: _EagerDefinedFunction being created before finalizing the graph.
322      Do not modify the function directly but instead modify the graph.
323  - `name`: name of the function.
324  - `graph`: Graph of the function.
325  - `inputs`: `tuple` of tensors used as inputs to the function.
326  - `outputs`: `tuple` of tensors used as outputs from the function.
327
328  The callback is at the top of the `_EagerDefinedFunction` construction, giving
329  callback an opportunity to make the last edits to the graph. Do not make
330  changes to `graph, inputs`, and `outputs` manually, but, instead, set the
331  `graph` as the default then define ops.
332
333  Repeated registration of the same callback function is idempotent.
334  After a callback is added, it can be removed with the
335  `remove_function_callback()` method.
336
337  Args:
338    function_callback: The callback to add.
339  """
340  _function_callbacks.add(function_callback)
341
342
343def remove_function_callback(function_callback):
344  """Remove an already-added function callback.
345
346  See the doc string of `add_function_callback()` for more information.
347
348  Args:
349    function_callback: The callback to remove.
350  """
351  _function_callbacks.remove(function_callback)
352
353
354def clear_function_callbacks():
355  """Clear all function callbacks, if any have been regisered."""
356  _function_callbacks.clear()
357
358
359_FORWARD_PREFIX = "__forward_"
360_BACKWARD_PREFIX = "__backward_"
361_INFERENCE_PREFIX = "__inference_"
362
363
364def _forward_name(n):
365  """The name of a generated forward defun named n."""
366  return "%s%s_%s" % (_FORWARD_PREFIX, n, ops.uid())
367
368
369def _backward_name(n):
370  """The name of a generated backward defun named n."""
371  return "%s%s_%s" % (_BACKWARD_PREFIX, n, ops.uid())
372
373
374def _inference_name(n):
375  """The name of a forward-but-no-gradient defun named n."""
376  return "%s%s_%s" % (_INFERENCE_PREFIX, n, ops.uid())
377
378
379def _enclosing_xla_context():
380  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite()."""
381  graph = ops.get_default_graph()
382  while graph is not None:
383    # pylint: disable=protected-access
384    context_ = graph._get_control_flow_context()
385    # pylint: enable=protected-access
386    while context_ is not None:
387      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
388        return context_
389      context_ = context_.outer_context
390    # This may be a FuncGraph due to defuns or v2 control flow. We need to
391    # find the original graph with the XLAControlFlowContext.
392    graph = getattr(graph, "outer_graph", None)
393  return None
394
395
396class _EagerDefinedFunctionDeleter(object):
397  """Unregister function from eager context."""
398
399  __slots__ = ["name"]
400
401  def __init__(self, name):
402    self.name = name
403
404  def __del__(self):
405    try:
406      context.remove_function(self.name)
407    except TypeError:
408      # Suppress some exceptions, mainly for the case when we're running on
409      # module deletion. Things that can go wrong include the context module
410      # already being unloaded, self._handle._handle_data no longer being
411      # valid, and so on. Printing warnings in these cases is silly
412      # (exceptions raised from __del__ are printed as warnings to stderr).
413      pass  # 'NoneType' object is not callable when the handle has been
414      # partially unloaded.
415    except AttributeError:
416      pass  # 'NoneType' object has no attribute 'eager_mode' when context has
417      # been unloaded. Will catch other module unloads as well.
418
419
420# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
421# so it doesn't have the definition-generating logic and is just a container for
422# an already-defined function.
423class _EagerDefinedFunction(object):
424  """Callable with the interface of `framework.function._DefinedFunction`.
425
426  `_EagerDefinedFunction` encapsulates a function definition and its properties,
427  and it provides a method for calling the encapsulated function. Some Ops
428  take functions as attributes, which have type `func`; an instance of this
429  class may be provided as the value of these `func` attributes.
430  """
431
432  def __init__(self, name, graph, inputs, outputs, attrs):
433    """Initializes an eager defined function.
434
435    Args:
436      name: str, the name for the created function.
437      graph: Graph, the graph containing the operations in the function
438      inputs: the tensors in the graph to be used as inputs to the function
439      outputs: the tensors in the graph which will be outputs from the function
440      attrs: dict mapping names of attributes to their AttrValue values
441    """
442    for function_callback in _function_callbacks:
443      function_callback(self, name, graph, tuple(inputs), tuple(outputs))
444
445    input_ops = set(arg.op for arg in inputs)
446    operations = [op for op in graph.get_operations() if op not in input_ops]
447
448    graph_output_names = graph._output_names  # pylint: disable=protected-access
449    if (graph_output_names is not None and
450        all(ops.tensor_id(t) in graph_output_names for t in outputs)):
451      output_names = [
452          compat.as_bytes(graph_output_names[ops.tensor_id(t)]) for t in outputs
453      ]
454      if len(set(output_names)) != len(output_names):
455        # There are duplicate names for some reason, probably an invalid
456        # signature. Revert to auto-naming.
457        output_names = []
458    else:
459      output_names = []
460    fn = pywrap_tf_session.TF_GraphToFunction_wrapper(
461        graph._c_graph,  # pylint: disable=protected-access
462        compat.as_str(name),
463        False,
464        [o._c_op for o in operations],  # pylint: disable=protected-access
465        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
466        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
467        output_names,
468        [o._c_op for o in graph.control_outputs],  # pylint: disable=protected-access
469        [],  # control_output_names
470        None,
471        compat.as_str(""))
472
473    for name, attr_value in attrs.items():
474      serialized = attr_value.SerializeToString()
475      # TODO(iga): this creates and deletes a new TF_Status for every attr.
476      # It might be worth creating a convenient way to re-use status.
477      pywrap_tf_session.TF_FunctionSetAttrValueProto(fn, compat.as_str(name),
478                                                     serialized)
479
480    # TODO(apassos) avoid creating a FunctionDef (specially to grab the
481    # signature, but also in general it's nice not to depend on it.
482    with c_api_util.tf_buffer() as buffer_:
483      pywrap_tf_session.TF_FunctionToFunctionDef(fn, buffer_)
484      proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
485    function_def = function_pb2.FunctionDef()
486    function_def.ParseFromString(compat.as_bytes(proto_data))
487    self._name = compat.as_bytes(function_def.signature.name)
488    with ops.init_scope():
489      if context.executing_eagerly():
490        context.ensure_initialized()
491        context.add_function(fn)
492        self._function_deleter = _EagerDefinedFunctionDeleter(self.name)
493        self._registered_on_context = True
494    self.definition = function_def
495    self.signature = function_def.signature
496    self._num_outputs = len(self.signature.output_arg)
497    self._output_types = [o.type for o in self.signature.output_arg]
498    self._output_shapes = [o.shape for o in outputs]
499    self._control_captures = graph.control_captures
500    # Shallow copy outputs since ConcreteFunction may mutate it.
501    self._func_graph_outputs = list(outputs)
502    self.grad_func_name = None
503    self.python_grad_func = None
504    self._c_func = c_api_util.ScopedTFFunction(fn)
505    self._grad_func = None
506    self.graph = graph
507    self._stateful_ops = tuple(op for op in operations if op._is_stateful)  # pylint: disable=protected-access
508
509  def add_to_graph(self, g=None):
510    """Add the function to the current context or a graph, if supplied.
511
512    Args:
513      g: the graph to add the function to. If not supplied, the function will
514        be added to the current context.
515    """
516    # pylint: disable=protected-access
517    if not g and context.executing_eagerly():
518      ctx = context.context()
519      if not ctx.has_function(self.name):
520        ctx.add_function_def(self.definition)
521    else:
522      if not g._is_function(self.name):
523        g._add_function(self)
524      for f in self.graph._functions.values():
525        if not g._is_function(f.name):
526          g._add_function(f)
527    # pylint: enable=protected-access
528
529  @property
530  def name(self):
531    return self._name
532
533  @property
534  def stateful_ops(self):
535    return self._stateful_ops
536
537  def call(self, ctx, args, cancellation_manager=None):
538    """Calls this function with `args` as inputs.
539
540    `ConcreteFunction` execution respects device annotations only if the
541    function won't be compiled with xla.
542
543    Args:
544      ctx: a Context object
545      args: a list of arguments to supply this function with.
546      cancellation_manager: a `CancellationManager` object that can be used to
547        cancel function execution.
548
549    Returns:
550      The outputs of the function call.
551
552    Raises:
553      ValueError: if the number of arguments is incorrect.
554    """
555    if len(args) != len(self.signature.input_arg):
556      raise ValueError(
557          "Arguments and signature arguments do not match. "
558          "got: %s, expected: %s " %
559          (len(args), len(list(self.signature.input_arg))))
560
561    function_call_options = ctx.function_call_options
562    if function_call_options.config_proto_serialized is None:
563      config = function_utils.get_disabled_rewriter_config()
564    else:
565      config = function_call_options.config_proto_serialized
566    executor_type = function_call_options.executor_type or ""
567
568    executing_eagerly = ctx.executing_eagerly()
569    attrs = ("executor_type", executor_type, "config_proto", config)
570    if executing_eagerly:
571      with _InterpolateFunctionError(self):
572        if cancellation_manager is None:
573          outputs = execute.execute(
574              str(self.signature.name),
575              num_outputs=self._num_outputs,
576              inputs=args,
577              attrs=attrs,
578              ctx=ctx)
579        else:
580          outputs = execute.execute_with_cancellation(
581              str(self.signature.name),
582              num_outputs=self._num_outputs,
583              inputs=args,
584              attrs=attrs,
585              ctx=ctx,
586              cancellation_manager=cancellation_manager)
587      # Replace empty list with None
588      outputs = outputs or None
589    else:
590      # TODO(akshayka): Either remove this if the FunctionLibraryRuntime
591      # creates `PartitionedCallOp` kernels by default, or remove the previous
592      # branch if a TPU kernel is registered for `PartitionedCall`.
593      with _InterpolateFunctionError(self):
594        with ops.control_dependencies(self._control_captures):
595          # The caller must use record_operation to record this operation in the
596          # eager case, so we enforce the same requirement for the non-eager
597          # case by explicitly pausing recording. We don't have a gradient
598          # registered for PartitionedCall, so recording this operation confuses
599          # forwardprop code (GradientTape manages to ignore it).
600          with tape.stop_recording():
601            outputs = functional_ops.partitioned_call(
602                args=args,
603                f=self,
604                tout=self._output_types,
605                executing_eagerly=executing_eagerly,
606                config=config,
607                executor_type=executor_type)
608
609    for i, func_graph_output in enumerate(self._func_graph_outputs):
610      custom_gradient.copy_handle_data(func_graph_output, outputs[i])
611    if executing_eagerly:
612      return outputs
613    else:
614      # TODO(b/128924522): This additional set_shape should not be
615      # necessary. ShapeRefiner likely needs to inspect handle_data. Remove this
616      # once that's done.
617      for i, shape in enumerate(self._output_shapes):
618        outputs[i].set_shape(shape)
619      return outputs
620
621
622def _create_forward_backward_with_graph(attrs, forward_graph, backwards_graph):
623  """Creates forward and backward functions from the function graphs."""
624  forward_function_name = _forward_name(forward_graph.name)
625  common_attributes = dict(attrs)
626  # NB: forward and backward function need to drop "_implements".
627  # attribute, because their signature contains all the intermediate tensors
628  # that they compute. Thus they don't have a stable signature which can
629  # be directly optimized downstream.
630  # See for more details:
631  # https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions
632  common_attributes.pop(IMPLEMENTS_ATTRIBUTE_NAME, None)
633  backward_function_attr = _parse_func_attrs(
634      {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
635  backward_function_attr.update(common_attributes)
636  backward_function = ConcreteFunction(
637      backwards_graph, attrs=backward_function_attr)
638  forward_function_attr = _parse_func_attrs({
639      BACKWARD_FUNCTION_ATTRIBUTE_NAME:
640      backward_function.name})
641  forward_function_attr.update(common_attributes)
642  forward_function = _EagerDefinedFunction(
643      forward_function_name, forward_graph, forward_graph.inputs,
644      forward_graph.outputs, forward_function_attr)
645  return forward_function, backward_function
646
647
648class _DelayedRewriteGradientFunctions(object):
649  """Caches forward/backward functions with a delayed forward rewrite."""
650
651  def __init__(self, func_graph, attrs, func_graph_deleter):
652    """Construct an inference function and initialize caches."""
653    # A map from the number of forward function outputs with accepted gradients
654    # to forward and backward functions, used to cache non-tape backward
655    # function generation.
656    self._cached_function_pairs = {}
657    self._func_graph = func_graph
658    self._inference_function = _EagerDefinedFunction(
659        _inference_name(self._func_graph.name), self._func_graph,
660        self._func_graph.inputs, self._func_graph.outputs, attrs)
661    self._attrs = attrs
662    self._gradient_name = None
663    # Note that the FuncGraph is mutated later, so we need to inspect it now to
664    # figure out the user-specified outputs of the inference function.
665    self._num_inference_outputs = len(self._func_graph.outputs)
666    self._func_graph_deleter = func_graph_deleter
667
668  def forward_backward(self, num_doutputs=None):
669    """A possibly-cached pair of forward and backward functions."""
670    if num_doutputs is None:
671      num_doutputs = self._num_inference_outputs
672    forward_backward = self._cached_function_pairs.get(num_doutputs)
673    if forward_backward is not None:
674      return forward_backward
675    forward, backward = self._construct_forward_backward(num_doutputs)
676    self._cached_function_pairs[num_doutputs] = (forward, backward)
677    return forward, backward
678
679  def _construct_forward_backward(self, num_doutputs):
680    """Constructs a pair of forward and backward functions.
681
682    Args:
683      num_doutputs: The constructed backprop function will take output gradients
684        for the first `num_doutputs` outputs of the forward function. Defaults
685        to the number of outputs for the inference function, but when
686        higher-order gradients are computed this will increase to include side
687        outputs.
688
689    Returns:
690      A pair of (forward_function, backward_function):
691        forward_function: A re-generated inference function (an
692          _EagerDefinedFunction) to account for new side outputs, if any extra
693          were required when building the backward pass.
694        backward_function: A ConcreteFunction that Takes `num_doutputs`
695          arguments and returns gradients with respect to inputs of the forward
696          function.
697    """
698    trainable_outputs = [
699        output for output in self._func_graph.outputs[:num_doutputs]
700        if backprop_util.IsTrainable(output)]
701
702    signature = []
703    for t in trainable_outputs:
704      signature.append(
705          tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))
706
707    def _backprop_function(*grad_ys):
708      with ops.device(None):
709        return gradients_util._GradientsHelper(  # pylint: disable=protected-access
710            trainable_outputs,
711            self._func_graph.inputs,
712            grad_ys=grad_ys,
713            src_graph=self._func_graph)
714
715    with self._func_graph.as_default():
716      backwards_graph = func_graph_module.FuncGraph(
717          _backward_name(self._func_graph.name))
718      func_graph_module.func_graph_from_py_func(
719          name=backwards_graph.name,
720          python_func=_backprop_function,
721          args=[], kwargs={},
722          signature=signature,
723          func_graph=backwards_graph)
724      backwards_graph_captures = backwards_graph.external_captures
725      captures_from_forward = [
726          c for c in backwards_graph_captures if
727          not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
728
729      existing_outputs = object_identity.ObjectIdentitySet(
730          self._func_graph.outputs)
731      for capture in captures_from_forward:
732        if capture not in existing_outputs:
733          existing_outputs.add(capture)
734          self._func_graph.outputs.append(capture)
735
736      forward_function, backward_function = _create_forward_backward_with_graph(
737          self._attrs, self._func_graph, backwards_graph)
738      return forward_function, backward_function
739
740  def _rewrite_forward_and_call_backward(self, op, *doutputs):
741    """Add outputs to the forward call and feed them to the grad function."""
742    forward_function, backwards_function = self.forward_backward(len(doutputs))
743    if not backwards_function.outputs:
744      return backwards_function.structured_outputs
745    forward_function.add_to_graph(op.graph)
746
747    # pylint: disable=protected-access
748    # Rewrite an inference call op to be a forward call op
749    op._set_func_attr("f", forward_function.name)
750    op._set_type_list_attr("Tout", forward_function._output_types)
751    op._add_outputs(
752        forward_function._output_types[len(op.outputs):],
753        forward_function._output_shapes[len(op.outputs):])
754    for i in range(len(op.outputs)):
755      func_graph_output = forward_function._func_graph_outputs[i]
756      custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
757    # pylint: enable=protected-access
758
759    capture_mapping = dict(
760        zip((ops.tensor_id(t) for t in self._func_graph.outputs), op.outputs))
761    remapped_captures = [
762        capture_mapping.get(ops.tensor_id(capture), capture)
763        for capture in backwards_function.captured_inputs
764    ]
765
766    # Replace Nones with zeros since we're calling a graph function which
767    # expects numeric inputs.
768    cleaned_doutputs = []
769    for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
770      if backprop_util.IsTrainable(placeholder):
771        if isinstance(doutput, ops.IndexedSlices):
772          # Gradient passed to a backward ConcreteFunction must be tf.Tensor,
773          # so we convert tf.IndexedSlices to tf.Tensor.
774          cleaned_doutputs.append(ops.convert_to_tensor(doutput))
775        elif doutput is not None:
776          cleaned_doutputs.append(doutput)
777        else:
778          cleaned_doutputs.append(default_gradient.zeros_like(placeholder))
779
780    # Compute the gradients using the side outputs
781    return backwards_function._call_flat(  # pylint: disable=protected-access
782        cleaned_doutputs, remapped_captures)
783
784  def get_gradient_function(self):
785    """Returns gradient function.
786
787    The gradient rewrites an inference call op to a forward call op, but does
788    not modify a pre-existing forward call op. It then computes the gradient
789    from the output's gradients and the side outputs of the forward op.
790    """
791    return self._rewrite_forward_and_call_backward
792
793  def forward(self, inference_args=None, input_tangents=None):
794    """A forward function with only user-specified outputs.
795
796    The call operation for the returned inference function can be rewritten into
797    a forward function. This only happens if the backward function (from the
798    `backward` method) ends up being used to compute gradients.
799
800    This approach avoids constructing unnecessary graphs, but it only works if
801    we are calling this function when not executing eagerly.
802
803    Args:
804      inference_args: A flat list of Tensors, arguments to the inference
805        function. Unused, but taken for compatibility with
806        _TapeGradientFunctions.
807      input_tangents: A flat list of Tensors, jvps associated with
808        `inference_args`. Unused; if required, tape functions must be used
809        instead.
810
811    Returns:
812      An _EagerDefinedFunction.
813    """
814    del inference_args  # unused
815    if input_tangents:
816      # This class does not support special-cased forwardprop. The arguments are
817      # here for compatibility with _TapeGradientFunctions.
818      raise AssertionError(
819          "Internal error: unexpectedly got forwardprop information in a class "
820          "that does not support forwardprop.")
821    return self._inference_function
822
823  def _backward(self, outputs):
824    """Fetch a backward function for `outputs` from the forward function."""
825    def _backward_function(*args):
826      call_op = outputs[0].op
827      return self._rewrite_forward_and_call_backward(call_op, *args)
828    return _backward_function, outputs
829
830  def record(self, flat_outputs, inference_args, input_tangents):
831    """Record the function call operation.
832
833    _DelayedRewriteGradientFunctions supports only first-order backprop tape
834    gradients (and then only when graph building). It does not work with
835    higher-order tape gradients or forward autodiff, but does work with
836    higher-order symbolic gradients (tf.gradients).
837
838    Args:
839      flat_outputs: The result of running `forward`.
840      inference_args: A flat list of Tensors with inference inputs to the
841        operation.
842      input_tangents: A flat list of Tensors with input tangents consumed by the
843        operation.
844    """
845    backward_function, to_record = self._backward(flat_outputs)
846    tape.record_operation(self._inference_function.signature.name,
847                          to_record, inference_args + input_tangents,
848                          backward_function)
849
850
851# Contains information about a forward function wrapped to compute jvps.
852_ForwardWrapper = collections.namedtuple(
853    "_ForwardWrapper", (
854        # The wrapper Graph.
855        "graph",
856        # A flat list of non-tangent Tensor outputs from the wrapped forward
857        # function.
858        "outputs",
859        # Indices for output tangents, same format as
860        # forwardprop_util.pack_tangents.
861        "output_indices",
862        # A flat list of tangents for `outputs`.
863        "output_tangents"))
864
865
866class _TapeGradientFunctions(object):
867  """Caches forward and backward functions compatible with eager gradients.
868
869  In contrast to the delayed-rewrite approach in
870  `_DelayedRewriteGradientFunctions` which only works with delayed execution,
871  the forward function generated by this class has a fixed set of outputs which
872  may be preserved by a tape in order to compute gradients later.
873
874  This class is abstract; its child classes differ in how many side outputs of
875  the forward function their backward function accepts gradients for, which
876  determines whether higher-order tape gradients are possible.
877  """
878
879  def __init__(self, func_graph, attrs, func_graph_deleter,
880               forwardprop_input_indices, delayed_rewrite_functions,
881               need_gradients_for_jvps):
882    self._func_graph = func_graph
883    self._forward_graph = None
884    self._attrs = attrs
885    self._forward = None
886    self._backward = None
887    self._num_outputs = len(func_graph.outputs)
888    self._func_graph_deleter = func_graph_deleter
889    self._forwardprop_input_indices = forwardprop_input_indices
890    self._forwardprop_output_indices = None
891    self._num_forwardprop_outputs = 0
892    self._num_inference_outputs = len(func_graph.outputs)
893    self._num_trainable_inference_outputs = len(
894        [t for t in func_graph.outputs if backprop_util.IsTrainable(t)])
895    self._delayed_rewrite_functions = delayed_rewrite_functions
896    self._need_gradients_for_jvps = need_gradients_for_jvps
897
898  def _build_functions_for_outputs(
899      self, outputs, inference_args, input_tangents):
900    """Forward+backward functions where the backward function sees `outputs`."""
901    # First figure out which of `outputs` are trainable. We'll accept gradients
902    # for each of these in the backward function.
903    handles_to_variables = self._func_graph.variable_captures
904    trainable_outputs = []
905    trainable_indices = []
906    for index, output in enumerate(outputs):
907
908      if backprop_util.IsTrainable(output):
909        # Swap in the Variable object for resource handles if we can so
910        # sparse gradients work.
911        output = handles_to_variables.get(id(output), output)
912        trainable_outputs.append(output)
913        trainable_indices.append(index)
914
915    backwards_graph = func_graph_module.FuncGraph(
916        _backward_name(self._func_graph.name))
917    with backwards_graph.as_default():
918      gradients_wrt_outputs = []
919      for output in trainable_outputs:
920        gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
921            output)
922        gradient_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
923        custom_gradient.copy_handle_data(output, gradient_placeholder)
924        gradients_wrt_outputs.append(gradient_placeholder)
925      with ops.device(None):
926        gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
927            trainable_outputs,
928            self._func_graph.inputs,
929            grad_ys=gradients_wrt_outputs,
930            src_graph=self._func_graph)
931
932      if input_tangents:
933        # Convert IndexedSlices to dense tensors (as we do elsewhere for
934        # function gradients). Our C++ bindings don't know how to handle them
935        # currently.
936        gradients_wrt_inputs = nest.map_structure(
937            lambda x: ops.convert_to_tensor(x) if x is not None else None,
938            gradients_wrt_inputs)
939      captures_from_forward = [
940          c for c in backwards_graph.external_captures
941          if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph
942      ]
943      existing_outputs = object_identity.ObjectIdentitySet(
944          self._func_graph.outputs)
945      for capture in captures_from_forward:
946        if capture not in existing_outputs:
947          existing_outputs.add(capture)
948          self._func_graph.outputs.append(capture)
949
950    # The ordering of `backwards_graph.inputs` is important: inputs of
951    # `backward_function` correspond to outputs (including
952    # side outputs) of `self._tape_forward_function`.
953    backwards_graph.inputs = (
954        gradients_wrt_outputs + backwards_graph.internal_captures)
955    backwards_graph.outputs.extend(
956        grad
957        for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
958        if grad is not None)
959    backwards_graph.structured_outputs = gradients_wrt_inputs
960
961    forward_function, backward_function = _create_forward_backward_with_graph(
962        self._attrs, self._func_graph, backwards_graph)
963
964    if not input_tangents:
965      # There is no need to special-case forwardprop, so we can return the
966      # forward+backward pair we've created without further wrapping.
967      return (forward_function, self._func_graph, backward_function,
968              # No forwardprop outputs.
969              None, 0)
970    forward_wrapper = self._wrap_forward_function_with_jvps(
971        forward_function, backward_function, inference_args, input_tangents)
972    (wrapped_backwards_graph,
973     forward_wrapper) = self._wrap_backward_function_with_jvp_backprop(
974         backward_function, gradients_wrt_outputs, forward_wrapper)
975    # Now that we've added new captures, we need to make sure forward outputs
976    # are in the same order the backward function expects them to be in:
977    # [inference outputs] + [jvps] + [side outputs] + [captures].
978    forward_wrapper = self._shuffle_forward_outputs(forward_wrapper)
979    (wrapped_forward_function,
980     wrapped_backward_function) = _create_forward_backward_with_graph(
981         self._attrs, forward_wrapper.graph, wrapped_backwards_graph)
982    if (len(inference_args) + len(input_tangents)
983        != len(forward_wrapper.graph.inputs)):
984      raise AssertionError(
985          ("Internal error: the forward graph had {} inputs, but we expected"
986           " {} ({} inference inputs and {} input tangents)")
987          .format(len(len(forward_wrapper.graph.inputs)),
988                  len(inference_args) + len(input_tangents),
989                  len(inference_args), len(input_tangents)))
990    return (wrapped_forward_function, forward_wrapper.graph,
991            wrapped_backward_function, forward_wrapper.output_indices,
992            len(forward_wrapper.output_tangents))
993
994  def _wrap_forward_function_with_jvps(
995      self, forward_function, backward_function,
996      inference_args, input_tangents):
997    """Adds inline JVP computation to a forward function."""
998    forward_wrapper_graph = func_graph_module.FuncGraph(
999        _forward_name(self._func_graph.name))
1000    with forward_wrapper_graph.as_default():
1001      # Tell forward accumulators to free up space for new JVP computations,
1002      # since one may be in the process of computing a JVP (if that computation
1003      # triggered this function building).
1004      #
1005      # We'll make symbolic versions of input JVPs, run the forward function
1006      # under forward accumulators to get symbolic output JVPs, then set those
1007      # as outputs of the new wrapped forward function.
1008      with forwardprop_util.push_forwardprop_state():
1009        forward_captures = {
1010            ops.tensor_id(internal): external
1011            for external, internal in self._func_graph.captures}
1012        for input_index, real_input in enumerate(self._func_graph.inputs):
1013          # This loop is more or less equivalent to running tf.identity on each
1014          # of self._func_graph.inputs. However, doing that also captures jvps
1015          # for resource handles, which confuses the jvp capturing code below
1016          # (since primal inputs are interwoven with jvp inputs).
1017          input_placeholder = array_ops.placeholder(
1018              dtype=real_input.dtype,
1019              shape=real_input.shape)
1020          capture = forward_captures.get(ops.tensor_id(real_input))
1021          if capture is not None:
1022            forward_wrapper_graph.add_capture(capture, input_placeholder)
1023            if capture.dtype == dtypes.resource:
1024              custom_gradient.copy_handle_data(capture, input_placeholder)
1025          else:
1026            forward_wrapper_graph.inputs.append(input_placeholder)
1027        for inp, arg in zip(forward_wrapper_graph.inputs, inference_args):
1028          tape.record_operation(
1029              "captured_value", [inp], [arg],
1030              backward_function=lambda x: [x],
1031              forward_function=lambda x: [x])
1032        num_inference_inputs = len(inference_args)
1033        for tape_indices in self._forwardprop_input_indices:
1034          for input_index, jvp_index in tape_indices:
1035            input_placeholder = forward_wrapper_graph.inputs[input_index]
1036            if len(forward_wrapper_graph.inputs) != jvp_index:
1037              raise AssertionError(
1038                  ("Internal error: expected {} forward graph inputs, but "
1039                   "found {}.")
1040                  .format(jvp_index, len(forward_wrapper_graph.inputs)))
1041            gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
1042                input_placeholder)
1043            jvp_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
1044            external_jvp = input_tangents[jvp_index - num_inference_inputs]
1045            forward_wrapper_graph.add_capture(external_jvp, jvp_placeholder)
1046            tensor_shape.TensorShape(
1047                external_jvp.shape).assert_is_compatible_with(
1048                    jvp_placeholder.shape)
1049            tape.record_operation(
1050                "captured_value",
1051                [jvp_placeholder],
1052                [external_jvp],
1053                backward_function=lambda x: [x],
1054                forward_function=lambda x: [x])
1055        forward_inputs = forward_wrapper_graph.inputs[:num_inference_inputs]
1056        gradient_function = (
1057            self._delayed_rewrite_functions._rewrite_forward_and_call_backward)  # pylint: disable=protected-access
1058        with ops.get_default_graph()._override_gradient_function(  # pylint: disable=protected-access
1059            {"PartitionedCall": gradient_function,
1060             "StatefulPartitionedCall": gradient_function}):
1061          forward_outputs = forward_function.call(context.context(),
1062                                                  forward_inputs)
1063          if isinstance(forward_outputs, ops.Operation):
1064            # _wrapped_backward_function expects a list, but if the function has
1065            # no outputs its call() returns an Operation. We need to undo that
1066            # so we don't cause problems later.
1067            forward_outputs = []
1068        py_backward, _ = self._wrap_backward_function(
1069            self._func_graph, backward_function, forward_outputs)
1070      # We will never request backward tape gradients for this operation
1071      # directly since we're wrapping the call; forwardprop will call the
1072      # backward function (and nested forward accumulators may build
1073      # higher-order gradients), but any watching GradientTapes should ignore
1074      # it.
1075      #
1076      # TODO(allenl): It might be better to explicitly stop backward recording
1077      # so we don't use the second-order tape cases unnecessarily.
1078      tape.record_operation_forwardprop_only(
1079          forward_function.signature.name,
1080          forward_outputs, forward_inputs, py_backward, None)
1081      output_indices, output_tangents = (
1082          pywrap_tfe.TFE_Py_PackJVPs(forward_outputs))
1083      output_tangents = [forward_wrapper_graph.capture(t)
1084                         for t in output_tangents]
1085    return _ForwardWrapper(
1086        graph=forward_wrapper_graph, outputs=forward_outputs,
1087        output_indices=output_indices, output_tangents=output_tangents)
1088
1089  def _wrap_backward_function_with_jvp_backprop(
1090      self, backward_function, gradients_wrt_outputs, forward_wrapper):
1091    """Wraps `backward_function` to include gradients for JVPs."""
1092    wrapped_backwards_graph = func_graph_module.FuncGraph(
1093        _backward_name(self._func_graph.name))
1094    with wrapped_backwards_graph.as_default():
1095      py_backward, recorded_outputs = self._wrap_backward_function(
1096          self._func_graph, backward_function, forward_wrapper.outputs)
1097      trainable_index = 0
1098      forward_doutputs = []
1099      doutput_args = []
1100      for output in recorded_outputs:
1101        if backprop_util.IsTrainable(output):
1102          doutput = gradients_wrt_outputs[trainable_index]
1103          doutput_placeholder = graph_placeholder(doutput.dtype, doutput.shape)
1104          doutput_args.append(doutput_placeholder)
1105          forward_doutputs.append(doutput_placeholder)
1106          trainable_index += 1
1107        else:
1108          doutput_args.append(None)
1109
1110      dinputs = py_backward(*doutput_args)
1111      existing_outputs = object_identity.ObjectIdentitySet(
1112          forward_wrapper.outputs + forward_wrapper.output_tangents)
1113      num_processed_output_tangents = 0
1114      gradients_wrt_output_tangents = []
1115      tangent_doutputs = []
1116      output_tangents = forward_wrapper.output_tangents
1117      output_indices = forward_wrapper.output_indices
1118      if self._need_gradients_for_jvps:
1119        # TODO(allenl): Consider using a throwaway graph to avoid extra gradient
1120        # evaluations; gradients for jvps may have common subgraphs.
1121        while num_processed_output_tangents != len(output_tangents):
1122          for output in output_tangents[num_processed_output_tangents:]:
1123            gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
1124                output)
1125            placeholder = graph_placeholder(gradient_dtype, gradient_shape)
1126            gradients_wrt_output_tangents.append(placeholder)
1127            tangent_doutputs.append(placeholder)
1128          num_processed_output_tangents = len(output_tangents)
1129          with ops.device(None):
1130            gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
1131                output_tangents,
1132                forward_wrapper.graph.inputs,
1133                grad_ys=gradients_wrt_output_tangents,
1134                src_graph=forward_wrapper.graph)
1135          dinputs = [
1136              backprop.aggregate_indexed_slices_gradients((existing, new))
1137              for existing, new in zip(dinputs, gradients_wrt_inputs)
1138              if existing is not None or new is not None]
1139          dinputs.extend(gradients_wrt_inputs[len(dinputs):])
1140          captures_from_forward = [
1141              c for c in wrapped_backwards_graph.external_captures
1142              if (not isinstance(c, ops.EagerTensor)
1143                  and c.graph is forward_wrapper.graph)]
1144          for capture in captures_from_forward:
1145            if capture not in existing_outputs:
1146              existing_outputs.add(capture)
1147              forward_wrapper.outputs.append(capture)
1148          output_indices, output_tangents = (
1149              forwardprop_util.pack_tangents(forward_wrapper.outputs))
1150          output_tangents = [forward_wrapper.graph.capture(t)
1151                             for t in output_tangents]
1152          for t in output_tangents:
1153            existing_outputs.add(t)
1154    wrapped_backwards_graph.inputs = (
1155        forward_doutputs[:self._num_trainable_inference_outputs]
1156        + tangent_doutputs
1157        + forward_doutputs[self._num_trainable_inference_outputs:]
1158        + wrapped_backwards_graph.internal_captures)
1159    wrapped_backwards_graph.structured_outputs = dinputs
1160    wrapped_backwards_graph.outputs = [t for t in dinputs if t is not None]
1161    return (wrapped_backwards_graph,
1162            forward_wrapper._replace(output_indices=output_indices,
1163                                     output_tangents=output_tangents))
1164
1165  def _shuffle_forward_outputs(self, forward_wrapper):
1166    """Reorders function outputs so captures are last."""
1167    def _index_map(original):
1168      if original < self._num_inference_outputs:
1169        return original
1170      if original >= len(forward_wrapper.outputs):
1171        return (original - len(forward_wrapper.outputs)
1172                + self._num_inference_outputs)
1173      return original + len(forward_wrapper.output_tangents)
1174    output_indices = nest.map_structure(
1175        _index_map, forward_wrapper.output_indices)
1176    forward_wrapper.graph.outputs = (
1177        forward_wrapper.outputs[:self._num_inference_outputs]
1178        + forward_wrapper.output_tangents
1179        + forward_wrapper.outputs[self._num_inference_outputs:])
1180    return forward_wrapper._replace(output_indices=output_indices)
1181
1182  def forward(self, inference_args, input_tangents):
1183    """Construct or fetch a forward function with side-outputs.
1184
1185    When graph building without a tape active, symbolic gradients rely on
1186    regenerating the backward function for higher-order gradients (to account
1187    for new side outputs of the rewritten forward function call). Thus there is
1188    no fixed backward function for this case. However, when a tape is active
1189    (eager or graph building), we generate fixed backward and forward functions
1190    at forward function call time.
1191
1192    This difference between the tape and non-tape cases is to avoid building
1193    unneeded backward functions while graph building (where we may or may not
1194    eventually need gradients).
1195
1196    Args:
1197      inference_args: A flat list of Tensors, arguments to the inference
1198        function.
1199      input_tangents: A flat list of Tensors, jvps associated with
1200        `inference_args`.
1201
1202    Returns:
1203      A forward _EagerDefinedFunction.
1204    """
1205    if self._forward is None:
1206      (self._forward, self._forward_graph, self._backward,
1207       self._forwardprop_output_indices, self._num_forwardprop_outputs) = (
1208           self._forward_and_backward_functions(inference_args, input_tangents))
1209    return self._forward
1210
1211  def _wrap_backward_function(self, forward_graph, backward, outputs):
1212    """Create a backward function given `outputs` from the forward function."""
1213    capture_mapping = dict(
1214        zip((ops.tensor_id(t) for t in forward_graph.outputs), outputs))
1215    captured_inputs = backward.captured_inputs
1216    remapped_captures = [
1217        capture_mapping.get(ops.tensor_id(capture), capture)
1218        for capture in captured_inputs
1219    ]
1220    if any(t.graph is forward_graph for t in remapped_captures
1221           if not isinstance(t, ops.EagerTensor)):
1222      raise AssertionError(
1223          "Internal error: failed to map all backward graph captures to the "
1224          "forward graph. Incorrectly mapped: {}".format(
1225              [t for t in remapped_captures
1226               if (not isinstance(t, ops.EagerTensor)
1227                   and t.graph is not forward_graph)]))
1228    # We may need to use zeros_like to get a zero for variant Tensors with
1229    # unconnected gradients. We do that in advance so we don't have to hold on
1230    # to the outputs themselves, which may not be needed otherwise.
1231    variant_zeros_like = {}
1232    backward_function_inputs = (len(backward.inputs) - len(captured_inputs))
1233    recorded_outputs = []
1234    trainable_recorded_outputs = 0
1235    skip_positions = []
1236    if self._num_forwardprop_outputs and not self._need_gradients_for_jvps:
1237      relevant_outputs = (
1238          outputs[:self._num_inference_outputs]
1239          + outputs[self._num_inference_outputs
1240                    + self._num_forwardprop_outputs:])
1241    else:
1242      relevant_outputs = outputs
1243    for output_index, output in enumerate(relevant_outputs):
1244      if trainable_recorded_outputs < backward_function_inputs:
1245        recorded_outputs.append(output)
1246      if backprop_util.IsTrainable(output):
1247        trainable_recorded_outputs += 1
1248      else:
1249        skip_positions.append(output_index)
1250      if output.dtype == dtypes.variant:
1251        variant_zeros_like[output_index] = default_gradient.zeros_like(output)
1252
1253    def _backward_function_wrapper(*args):
1254      """Process output gradients and call the backward function."""
1255      if not backward.outputs:
1256        return backward.structured_outputs
1257
1258      processed_args = []
1259      input_index = 0
1260      for output_index, arg in enumerate(args):
1261        # Convert IndexedSlices to dense tensors. The IndexedSlices optimization
1262        # is only really effective when doing tf.gather(variable) as the
1263        # adjoint functions for most operations are unlikely to preserve the
1264        # sparsity in IndexedSlices.
1265        if isinstance(arg, ops.IndexedSlices):
1266          arg = ops.convert_to_tensor(arg)
1267        if output_index in skip_positions:
1268          continue
1269        if arg is None:
1270          # We're calling a (non-polymorphic) ConcreteFunction, so we need to
1271          # have a Tensor value for each Tensor we thought would be trainable
1272          # based on its dtype, even if it ended up being unconnected.
1273          input_placeholder = backward.inputs[
1274              input_index]
1275          if input_placeholder.dtype == dtypes.variant:
1276            arg = variant_zeros_like[output_index]
1277          else:
1278            arg = array_ops.zeros(
1279                *default_gradient.shape_and_dtype(input_placeholder))
1280        processed_args.append(arg)
1281        input_index += 1
1282        if input_index >= backward_function_inputs:
1283          break
1284      return backward._call_flat(  # pylint: disable=protected-access
1285          processed_args, remapped_captures)
1286
1287    return _backward_function_wrapper, recorded_outputs
1288
1289  def record(self, flat_outputs, inference_args, input_tangents):
1290    """Record the function call operation.
1291
1292    For backprop, indicates the backward function to use and which new Tensors
1293    must be watched. For forwardprop from eager, the function call itself will
1294    have produced tangents which need to be recorded.
1295
1296    Args:
1297      flat_outputs: The result of running `forward`.
1298      inference_args: A flat list of Tensors with inference inputs to the
1299        operation.
1300      input_tangents: A flat list of Tensors with input tangents consumed by the
1301        operation.
1302    """
1303    backward_function, to_record = self._wrap_backward_function(
1304        self._forward_graph, self._backward, flat_outputs)
1305    if self._forwardprop_output_indices:
1306      tape.record_operation_backprop_only(
1307          self._forward.signature.name,
1308          to_record, inference_args,
1309          backward_function)
1310      tape.record_operation_forwardprop_only(
1311          self._forward.signature.name,
1312          flat_outputs, inference_args + input_tangents,
1313          backward_function,
1314          self._forwardprop_output_indices)
1315    else:
1316      tape.record_operation(self._forward.signature.name,
1317                            to_record, inference_args + input_tangents,
1318                            backward_function)
1319
1320
1321class _FirstOrderTapeGradientFunctions(_TapeGradientFunctions):
1322  """Caches tape-friendly functions for first-order gradients."""
1323
1324  def __init__(self, func_graph, attrs, func_graph_deleter,
1325               forwardprop_input_indices, delayed_rewrite_functions,
1326               need_gradients_for_jvps):
1327    super(_FirstOrderTapeGradientFunctions, self).__init__(
1328        func_graph, attrs, func_graph_deleter, forwardprop_input_indices,
1329        delayed_rewrite_functions, need_gradients_for_jvps)
1330    self._func_graph_deleter = func_graph_deleter
1331    self._forwardprop_input_indices = forwardprop_input_indices
1332
1333  def _forward_and_backward_functions(self, inference_args, input_tangents):
1334    """Shortcut for when only first-order gradients are required.
1335
1336    The returned backward function does not accept gradients with respect to
1337    side output of forward_function. This is fine as long as the user can't
1338    possibly request second order tape gradients, as when they've used a single
1339    non-persistent GradientTape. Since we don't need the backward function to
1340    take gradients with respect to side outputs, we can skip some potentially
1341    slow graph building.
1342
1343    Args:
1344      inference_args: A flat list of Tensors, arguments to the inference
1345        function.
1346      input_tangents: A flat list of Tensors, jvps associated with
1347        `inference_args`.
1348
1349    Returns:
1350      A tuple of (forward_function, backward_function):
1351        forward_function: Takes the same inputs as the inference function, but
1352          returns side outputs used by backward_function in addition to the
1353          inference function's outputs.
1354        backward_function: Takes side outputs from forward_function and
1355          gradients with respect to the "real" outputs of forward_function and
1356          returns gradients with respect to the inputs.
1357    """
1358    outputs = self._func_graph.outputs[:self._num_inference_outputs]
1359    return self._build_functions_for_outputs(
1360        outputs, inference_args, input_tangents)
1361
1362
1363class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
1364  """Caches tape-friendly functions for higher-order gradients."""
1365
1366  # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider
1367  # generalizing if so.
1368  def _forward_and_backward_functions(self, inference_args, input_tangents):
1369    """Forward and backward functions suitable for higher-order gradients.
1370
1371    Unlike in `_FirstOrderTapeGradientFunctions`, the backward function built by
1372    this method accepts gradients for all of the outputs of the returned forward
1373    function, including side outputs.
1374
1375    Args:
1376      inference_args: A flat list of Tensors, arguments to the inference
1377        function.
1378      input_tangents: A flat list of Tensors, jvps associated with
1379        `inference_args`.
1380
1381    Returns:
1382      A tuple of (forward_function, backward_function):
1383        forward_function: Takes the same inputs as the inference function, but
1384          returns side outputs used by backward_function in addition to the
1385          inference function's outputs.
1386        backward_function: Takes side outputs from forward_function and
1387          gradients with respect to all of its outputs, real and side. Returns
1388          gradients with respect to the inputs.
1389    """
1390    outputs = []
1391    iteration_count = 0
1392    # First we need to figure out how many side outputs from the forward pass
1393    # will be required. We do this in a temporary graph to avoid actually
1394    # running multiple copies of the backward pass (one per _GradientsHelper
1395    # call).
1396    #
1397    # While computing gradients, the backward function captures Tensors from
1398    # the forward function. We add these as side outputs of the original
1399    # function. However, we then need to accept output gradients with respect
1400    # to these side outputs for higher order gradients to work. Thus we loop
1401    # until the number of outputs of the function stabilizes. Note that this
1402    # is only required for tape gradients, where we need to declare in advance
1403    # all of the forward op's outputs: symbolic gradients with tf.gradients
1404    # instead rely on regenerating backward functions when higher-order
1405    # gradients are requested.
1406    while (len(outputs) < len(self._func_graph.outputs)
1407           # It's possible for gradient generation to add new ops to the forward
1408           # pass. If all of the new outputs are non-trainable, there's no
1409           # reason to continue.
1410           and any(backprop_util.IsTrainable(output)
1411                   for output in self._func_graph.outputs[len(outputs):])):
1412      iteration_count += 1
1413      if iteration_count >= 20 and iteration_count % 5 == 0:
1414        new_op_with_trainable_output = None
1415        num_new_trainable_outputs = 0
1416        for output in self._func_graph.outputs[len(outputs):]:
1417          if backprop_util.IsTrainable(output):
1418            num_new_trainable_outputs += 1
1419            new_op_with_trainable_output = output.op
1420        logging.warning(
1421            ("Determining side outputs for the function '{}' is taking longer "
1422             "than expected ({} iterations, typically this converges in 5 or "
1423             "so). This could indicate that a gradient registration is adding "
1424             "new ops to the forward pass every time gradients are generated. "
1425             "{} new trainable output(s) were added this iteration, one from "
1426             "the following op:\n {}\nThis may indicate a TensorFlow bug, or "
1427             "an issue in a tf.custom_gradient.")
1428            .format(
1429                self._func_graph.name, iteration_count,
1430                num_new_trainable_outputs, new_op_with_trainable_output))
1431      outputs = list(self._func_graph.outputs)
1432      self._build_functions_for_outputs(
1433          outputs, inference_args, input_tangents)
1434
1435    (forward_function, forward_graph,
1436     backward_function, output_indices, num_output_tangents) = (
1437         self._build_functions_for_outputs(
1438             outputs, inference_args, input_tangents))
1439    if (len(self._func_graph.outputs) > len(outputs)
1440        and any(backprop_util.IsTrainable(output)
1441                for output in self._func_graph.outputs[len(outputs):])):
1442      raise AssertionError(
1443          ("Unexpectedly added new outputs to the forward function when "
1444           "building the backward function: {}").format(
1445               self._func_graph.outputs[len(outputs):]))
1446    return (forward_function, forward_graph, backward_function, output_indices,
1447            num_output_tangents)
1448
1449
1450class _ForwardBackwardCall(object):
1451  """Holds the state of a function call between execution and recording."""
1452
1453  __slots__ = [
1454      "_functions", "_inference_args", "_input_tangents", "_tape_watching"
1455  ]
1456
1457  def __init__(self, functions, inference_args, input_tangents, tape_watching):
1458    """Collects information about the function call.
1459
1460    Args:
1461      functions: An object which produces forward and backward functions, either
1462        a _DelayedRewriteGradientFunctions or a _TapeGradientFunctions object.
1463      inference_args: A flat list of Tensors, arguments to the inference
1464        function.
1465      input_tangents: A flat list of Tensors, jvps associated with
1466        `inference_args`.
1467      tape_watching: Boolean, with True indicating that recording is necessary.
1468    """
1469    self._functions = functions
1470    self._inference_args = inference_args
1471    self._input_tangents = input_tangents
1472    self._tape_watching = tape_watching
1473
1474  def forward(self):
1475    """Builds or retrieves a forward function for this call."""
1476    forward_function = self._functions.forward(
1477        self._inference_args, self._input_tangents)
1478    return forward_function, self._inference_args + self._input_tangents
1479
1480  def record(self, flat_outputs):
1481    """Given outputs from the execution of `forward`, records the operation."""
1482    if (self._tape_watching
1483        and not isinstance(flat_outputs, ops.Operation)
1484        and flat_outputs is not None):
1485      # We only record function calls which have outputs, and then only when a
1486      # tape is watching.
1487      self._functions.record(
1488          flat_outputs, self._inference_args, self._input_tangents)
1489
1490
1491# Sentinel value used by with ConcreteFunction's structured signature to
1492# indicate that a non-tensor parameter should use the value that was
1493# specified when the concrete function was created.
1494_BOUND_VALUE = object()
1495
1496
1497class ConcreteFunction(object):
1498  """Callable object encapsulating a function definition and its gradient.
1499
1500  `ConcreteFunction` is a callable that encapsulates a function definition and
1501  is differentiable under `tf.GradientTape` objects.
1502  """
1503
1504  def __init__(self,
1505               func_graph,
1506               attrs=None,
1507               shared_func_graph=True,
1508               function_spec=None):
1509    """Initialize a `ConcreteFunction`.
1510
1511    Args:
1512      func_graph: An instance of FuncGraph: the function body to wrap.
1513      attrs: (optional) dict mapping names of attributes to their AttrValue
1514        values. Attributes in `attrs` will be included in this function's
1515        definition.
1516     shared_func_graph: If False, the ConcreteFunction takes ownership of
1517       `func_graph` and will break reference cycles when it is deleted. This
1518       makes the FuncGraph inoperable.
1519     function_spec: FunctionSpec for the original function.  If not specified,
1520       then this ConcreteFunction may only be called using the flat signature.
1521
1522    Raises:
1523      ValueError: If number of input_placeholders is not equal to the number
1524        of function inputs.
1525    """
1526    # _arg_keywords and _num_positional_args define the flat signature.  They
1527    # are assigned after construction.
1528    self._arg_keywords = None
1529    self._num_positional_args = None
1530
1531    self._func_graph = func_graph
1532    self._captured_inputs = self._func_graph.external_captures
1533    self._captured_closures = self._func_graph.deferred_external_captures
1534
1535    # function_spec defines the structured signature.
1536    self._set_function_spec(function_spec)
1537
1538    if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs:
1539      # The alternative is to silently drop "implements" tag
1540      # but it seems likely it would lead to hard to catch bugs.
1541      # Another alternative is to make func_body to preserve the order
1542      # of arguments if variables are present. Yet another option
1543      # is to automatically replace variables as arguments to functions
1544      # to v.read_value() whenever "implements" tag is present
1545      # Anytime we annotate existing function we probably want to wrap
1546      # it with safe read_value for backward compatibility.
1547      has_resource_vars = any(inp.dtype == dtypes.resource
1548                              for inp in self.inputs)
1549
1550      assert not any(
1551          (has_resource_vars, self._captured_inputs, self._captured_closures)
1552      ), ('Function {name} has "{attr}={value}" attribute and thus can not '
1553          "depend on any tensors outside of its signature or modify variables. "
1554          "\n\nNote: variables are always captured and cause function "
1555          "re-tracing for every variable called.\n"
1556          "  inputs: {inputs}\n  captures: {captured}\n"
1557          "  closures: {closures}.\n\n"
1558          "To pass a variable to such function use  "
1559          "use variable.read_value().".format(
1560              name=func_graph.name,
1561              attr=IMPLEMENTS_ATTRIBUTE_NAME,
1562              value=attrs[IMPLEMENTS_ATTRIBUTE_NAME],
1563              inputs=self.inputs,
1564              captured=self._captured_inputs,
1565              closures=self._captured_closures))
1566    self._output_shapes = tuple(
1567        output.shape for output in self._func_graph.outputs)
1568    self._attrs = _parse_func_attrs(attrs or {})
1569
1570    if shared_func_graph:
1571      self._garbage_collector = None
1572    else:
1573      self._garbage_collector = ConcreteFunctionGarbageCollector(func_graph)
1574
1575    # Pairs of forward and backward functions used for computing gradients.
1576    #
1577    # These each get a reference to the FuncGraph deleter since they use the
1578    # FuncGraph directly.
1579    self._delayed_rewrite_functions = _DelayedRewriteGradientFunctions(
1580        func_graph, self._attrs, self._garbage_collector)
1581    self._first_order_tape_functions = {}
1582    self._higher_order_tape_functions = {}
1583    # Cache the inference function to avoid a (Python) function call when not
1584    # building gradients.
1585    self._inference_function = self._delayed_rewrite_functions.forward()
1586
1587  def _set_function_spec(self, function_spec):
1588    """Enables the structured signature by supplying a function_spec."""
1589    self._function_spec = None
1590    self._pre_initialized_function_spec = function_spec
1591
1592    # Note: when ConcreteFunctions are built by recreate_function() in
1593    # function_deserialization.py, they don't have a structured_input_signature
1594    # yet.  In that case, _initialize_function_spec() gets called by
1595    # _setup_functions_structures() in load.py.
1596    if (function_spec is not None and
1597        self.structured_input_signature is not None):
1598      self._initialize_function_spec()
1599
1600  def _initialize_function_spec(self):
1601    """Updates `self._function_spec` to include varargs and bound variables.
1602
1603    Adds new positional arguments for any varargs (i.e., for args that are
1604    in `structured_input_signature`, but not in the original fullargspec.args).
1605
1606    Replaces `defaults` and `kwonlydefaults` with the `_BOUND_VALUE`, for
1607    all args and kwargs in `structured_input_signature`.
1608
1609    Sets `varkw` and `varargs` to None.
1610    """
1611    if self._pre_initialized_function_spec is None:
1612      return  # e.g., SavedBareConcreteFunction doesn't have function_spec yet.
1613    assert not self._function_spec, "already initialized"
1614    function_spec = self._pre_initialized_function_spec
1615    args = function_spec.fullargspec.args
1616    arg_specs, kwarg_specs = self.structured_input_signature
1617    vararg_indices = range(len(function_spec.arg_names), len(arg_specs))
1618    fullargspec = tf_inspect.FullArgSpec(
1619        args=list(args) + ["<arg{}>".format(i + 1) for i in vararg_indices],
1620        varargs=None,
1621        varkw=None,
1622        defaults=[_BOUND_VALUE] * len(arg_specs),
1623        kwonlyargs=list(sorted(kwarg_specs)),
1624        kwonlydefaults=dict((k, _BOUND_VALUE) for k in kwarg_specs),
1625        annotations=function_spec.fullargspec.annotations)
1626    self._function_spec = FunctionSpec(
1627        fullargspec,
1628        function_spec.is_method,
1629        function_spec.input_signature,
1630        function_spec.is_pure,
1631        name=self._func_graph.name)
1632
1633  @property
1634  def variables(self):
1635    """Sequence of variables for this function."""
1636    return tuple(self._func_graph.variables)
1637
1638  @property
1639  def trainable_variables(self):
1640    """Sequence of trainable variables for this function."""
1641    return tuple(self._func_graph.trainable_variables)
1642
1643  def __call__(self, *args, **kwargs):
1644    """Executes the wrapped function.
1645
1646    ConcreteFunctions have two signatures:
1647
1648    * The signature of the original function wrapped by this ConcreteFunction.
1649    * A flat signature, where each argument accepts a single Tensor.
1650
1651    The original function signature is generally preferred, but the flat input
1652    signature is supported for backward compatibility.
1653
1654    ### Original Function Signature
1655
1656    When calling a ConcreteFunction with the signature of the original function,
1657    each argument must match the type or value that was used when the
1658    ConcreteFunction's graph was traced.  In particular:
1659
1660    * Tensor arguments (including CompositeTensors, such as RaggedTensor) must
1661      have matching `TypeSpec`s.
1662    * Non-Tensor arguments (such as booleans or ints) must have equal values.
1663    * Nested arguments (such as lists, tuples, or dictionaries) must have the
1664      same nesting structure; and each nested value must have a matching type
1665      or value.
1666
1667    The default value for any arguments that were traced with non-Tensor values
1668    is the value that was used in the trace.  Arguments that were traced with
1669    tensor arguments do not have a default value (even if the original function
1670    had a default value for that argument).
1671
1672    ### Flat Signature
1673
1674    When calling a ConcreteFunction with the flat signature, the arguments
1675    correspond to the flattened component tensors of the arguments that were
1676    used to construct the ConcreteFunction.  Parameter names are assigned based
1677    on `TensorSpec.name` (when specified) or the original argument names (with
1678    suffixes automatically added for nested arguments or composite tensors with
1679    multiple components).
1680
1681    Args:
1682      *args: Positional arguments to the concrete function.
1683      **kwargs: Keyword arguments to the concrete function.
1684
1685    Returns:
1686      The result of applying the TF function on the given Tensors.
1687
1688    Raises:
1689      AssertionError: If this `ConcreteFunction` was not created through
1690        `get_concrete_function`.
1691      TypeError: If the arguments do not match the function's signature.
1692    """
1693    return self._call_impl(args, kwargs)
1694
1695  def _call_impl(self, args, kwargs, cancellation_manager=None):
1696    """See `__call__` for details."""
1697    with trace.Trace(self._func_graph.name, tf_function_call="concrete"):
1698      # Construct the list of input tensors: check if the structured signature
1699      # applies first; and if not, then use the flat signature.
1700      if self._function_spec is not None:
1701        try:
1702          return self._call_with_structured_signature(args, kwargs,
1703                                                      cancellation_manager)
1704        except TypeError as structured_err:
1705          try:
1706            return self._call_with_flat_signature(args, kwargs,
1707                                                  cancellation_manager)
1708          except TypeError:
1709            raise structured_err
1710
1711      return self._call_with_flat_signature(args, kwargs, cancellation_manager)
1712
1713  def _call_with_flat_signature(self, args, kwargs, cancellation_manager):
1714    """Executes the wrapped function with the flat signature.
1715
1716    Args:
1717      args: Positional arguments to the concrete function.
1718      kwargs: Keyword arguments to the concrete function.
1719      cancellation_manager: A `CancellationManager` that can be used to cancel
1720        function invocation.
1721
1722    Returns:
1723      The result of applying the function on the Tensors/Variables contained in
1724      `args` and `kwargs`.
1725    Raises:
1726      TypeError: if `args` and `kwargs` do not match the flat signature of this
1727        `ConcreteFunction`.
1728    """
1729    if len(args) > self._num_positional_args:
1730      raise TypeError(
1731          "{} takes {} positional arguments but {} were given".format(
1732              self._flat_signature_summary(), self._num_positional_args,
1733              len(args)))
1734    args = list(args)
1735    kwargs = dict(kwargs)
1736    for keyword in self._arg_keywords[len(args):]:
1737      try:
1738        args.append(kwargs.pop(compat.as_str(keyword)))
1739      except KeyError:
1740        specified_keywords = (
1741            list(self._arg_keywords[:len(args)]) + list(kwargs.keys()))
1742        raise TypeError("{} missing required arguments: {}".format(
1743            self._flat_signature_summary(), ", ".join(
1744                sorted(set(self._arg_keywords) - set(specified_keywords)))))
1745    if kwargs:
1746      positional_arg_keywords = set(self._arg_keywords[:len(args)])
1747      for unused_key in kwargs:
1748        if unused_key in positional_arg_keywords:
1749          raise TypeError("{} got two values for argument '{}'".format(
1750              self._flat_signature_summary(), unused_key))
1751      raise TypeError("{} got unexpected keyword arguments: {}.".format(
1752          self._flat_signature_summary(), ", ".join(sorted(kwargs))))
1753
1754    for i, arg in enumerate(args):
1755      if not isinstance(
1756          arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
1757        raise TypeError("{}: expected argument #{}(zero-based) to be a Tensor; "
1758                        "got {} ({})".format(self._flat_signature_summary(), i,
1759                                             type(arg).__name__, str(arg)))
1760    return self._call_flat(args, self.captured_inputs, cancellation_manager)
1761
1762  def _call_with_structured_signature(self, args, kwargs, cancellation_manager):
1763    """Executes the wrapped function with the structured signature.
1764
1765    Args:
1766      args: Positional arguments to the concrete function.
1767      kwargs: Keyword arguments to the concrete function.
1768      cancellation_manager: A `CancellationManager` that can be used to cancel
1769        function invocation.
1770
1771    Returns:
1772      The result of applying the function on the Tensors/Variables contained in
1773      `args` and `kwargs`.
1774    Raises:
1775      TypeError: if `args` and `kwargs` do not match the structured signature
1776        of this `ConcreteFunction`.
1777    """
1778    args, kwargs, _, filtered_flat_args = \
1779        self._function_spec.canonicalize_function_inputs(*args, **kwargs)
1780    self._structured_signature_check_missing_args(args, kwargs)
1781    self._structured_signature_check_unexpected_args(args, kwargs)
1782    self._structured_signature_check_arg_types(args, kwargs)
1783    return self._call_flat(
1784        filtered_flat_args,
1785        captured_inputs=self.captured_inputs,
1786        cancellation_manager=cancellation_manager)
1787
1788  def _structured_signature_check_missing_args(self, args, kwargs):
1789    """Raises a TypeError if any args are missing."""
1790    arg_specs, kwarg_specs = self.structured_input_signature
1791    missing_arguments = []
1792    for i, (arg, spec) in enumerate(zip(args, arg_specs)):
1793      if arg is _BOUND_VALUE and _contains_type_spec(spec):
1794        missing_arguments.append(self._function_spec.arg_names[i])
1795    for (name, arg) in kwargs.items():
1796      if arg is _BOUND_VALUE and _contains_type_spec(kwarg_specs[name]):
1797        missing_arguments.append(name)
1798    if missing_arguments:
1799      raise TypeError("{} missing required arguments: {}".format(
1800          self._structured_signature_summary(),
1801          ", ".join(sorted(missing_arguments))))
1802
1803  def _structured_signature_check_unexpected_args(self, args, kwargs):
1804    """Raises a TypeError if there are any extra args."""
1805    arg_specs, kwarg_specs = self.structured_input_signature
1806    if len(args) > len(arg_specs):
1807      raise TypeError(
1808          "{} takes {} positional arguments but {} were given".format(
1809              self._structured_signature_summary(),
1810              len(self._function_spec.arg_names), len(args)))
1811    if len(kwargs) > len(kwarg_specs):
1812      extra_args = set(kwargs) - set(kwarg_specs)
1813      raise TypeError("{} got unexpected keyword arguments: {}".format(
1814          self._structured_signature_summary(), ", ".join(extra_args)))
1815
1816  def _structured_signature_check_arg_types(self, args, kwargs):
1817    """Raises a TypeError if any args have the wrong type."""
1818    # Check argument types
1819    arg_specs, kwarg_specs = self.structured_input_signature
1820    for i, (arg, spec) in enumerate(zip(args, arg_specs)):
1821      name = self._function_spec.arg_names[i]
1822      self._structured_signature_check_arg_type(arg, spec, name)
1823    for (name, arg) in kwargs.items():
1824      self._structured_signature_check_arg_type(arg, kwarg_specs[name], name)
1825
1826  def _structured_signature_check_arg_type(self, arg, spec, name):
1827    """Raise TypeError if `arg`'s type doesn't match `spec`."""
1828    if arg is _BOUND_VALUE:
1829      return
1830
1831    # Check the overall nested structure of the argument.
1832    try:
1833      nest.assert_same_structure(arg, spec, expand_composites=True)
1834    except (ValueError, TypeError):
1835      try:
1836        nest.assert_same_structure(arg, spec, expand_composites=False)
1837        expected, got = spec, arg
1838      except (ValueError, TypeError):
1839        expected, got = _structure_summary(spec), _structure_summary(arg)
1840      raise TypeError("{}: argument {} had incorrect type\n"
1841                      "  expected: {}\n       got: {}".format(
1842                          self._structured_signature_summary(), name, expected,
1843                          got))
1844
1845    # Check the type for each leaf in the nested structure.
1846    arg_pieces = nest.flatten(arg, expand_composites=True)
1847    spec_pieces = nest.flatten(spec, expand_composites=True)
1848    for (arg_piece, spec_piece) in zip(arg_pieces, spec_pieces):
1849      if isinstance(spec_piece, tensor_spec.DenseSpec):
1850        # TODO(edloper): Consider calling convert_to_tensor on non-tensor
1851        # values here.  That would match the behavior of
1852        # _call_concrete_function() in function_deserialization.py.  If
1853        # we do, then we need to change the nest assert_same_structure and
1854        # flatten calls above to use shallow variants.
1855        tensor_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
1856        if not isinstance(arg_piece, tensor_types):
1857          raise TypeError(
1858              "{} expected a Tensor in {}, but got {} value {}".format(
1859                  self._structured_signature_summary(), name,
1860                  type(arg_piece).__name__, arg_piece))
1861      elif arg_piece is not _BOUND_VALUE and arg_piece != spec_piece:
1862        raise TypeError("ConcreteFunction {} was constructed with {} value "
1863                        "{} in {}, but was called with {} value {}".format(
1864                            self._structured_signature_summary(),
1865                            type(spec_piece).__name__, spec_piece, name,
1866                            type(arg_piece).__name__, arg_piece))
1867
1868  def _call_flat(self, args, captured_inputs, cancellation_manager=None):
1869    """Executes the wrapped function.
1870
1871    Args:
1872      args: a list of Tensors or Variables. Arguments from the Python function
1873        should be filtered before calling this method: objects aside from
1874        Tensors, CompositeTensors, and Variables are ignored. Any
1875        CompositeTensors should be expanded before calling this method.
1876      captured_inputs: the captured inputs that are also part of the input args
1877        to the actual execution. By default, it should be self._captured_inputs.
1878      cancellation_manager: (Optional.) A `CancellationManager` that can be
1879        used to cancel function invocation.
1880
1881    Returns:
1882      The result of applying the TF function to `args`.
1883
1884    Raises:
1885      ValueError: If `args` contains anything other than Tensors or Variables.
1886    """
1887    ctx = context.context()
1888    executing_eagerly = ctx.executing_eagerly()
1889
1890    # Copy saveable status of function's graph to current FuncGraph.
1891    default_graph = ops.get_default_graph()
1892    if default_graph.building_function and not self._func_graph.saveable:
1893      default_graph.mark_as_unsaveable(self._func_graph.saving_errors)
1894
1895    if (tape.could_possibly_record() or
1896        hasattr(default_graph, "watch_variable")):
1897      for v in self._func_graph.variables:
1898        resource_variable_ops.variable_accessed(v)
1899
1900    tensor_inputs = []
1901    variables_used = set([])
1902    for i, arg in enumerate(args):
1903      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
1904        # We can pass a variable more than once, and in this case we need to
1905        # pass its handle only once.
1906        if id(arg.handle) in variables_used:
1907          continue
1908        resource_variable_ops.variable_accessed(arg)
1909        tensor_inputs.append(arg.handle)
1910        variables_used.add(id(arg.handle))
1911      elif isinstance(arg, ops.Tensor):
1912        tensor_inputs.append(arg)
1913        if not executing_eagerly:
1914          # If we're graph building, shape inference is on. We check for input
1915          # compatibility up front to avoid hard to debug incompatibilities
1916          # later.
1917          graph_input_shape = tensor_shape.TensorShape(
1918              self._func_graph.inputs[i].shape)
1919          if not graph_input_shape.is_compatible_with(arg.shape):
1920            if self._arg_keywords:
1921              arg_name = "'{}'".format(self._arg_keywords[i])
1922            else:
1923              arg_name = "with index {}".format(i)
1924            raise ValueError(
1925                ("The argument {} (value {}) is not compatible with the shape "
1926                 "this function was traced with. Expected shape {}, but got "
1927                 "shape {}.\n\nIf you called get_concrete_function, you may "
1928                 "need to pass a tf.TensorSpec(..., shape=...) with a less "
1929                 "specific shape, having None on axes which can vary.").format(
1930                     arg_name, arg,
1931                     self._func_graph.inputs[i].shape,
1932                     arg.shape))
1933      else:
1934        raise ValueError("All inputs to `ConcreteFunction`s must be Tensors; "
1935                         "on invocation of %s, the %d-th input (%s) was not a "
1936                         "Tensor." % (self._func_graph.name, i, str(arg)))
1937    args = tensor_inputs + captured_inputs
1938    possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args)
1939    if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE
1940        and executing_eagerly):
1941      # No tape is watching; skip to running the function.
1942      return self._build_call_outputs(self._inference_function.call(
1943          ctx, args, cancellation_manager=cancellation_manager))
1944    forward_backward = self._select_forward_and_backward_functions(
1945        args,
1946        possible_gradient_type,
1947        executing_eagerly)
1948    forward_function, args_with_tangents = forward_backward.forward()
1949    if executing_eagerly:
1950      flat_outputs = forward_function.call(
1951          ctx, args_with_tangents, cancellation_manager=cancellation_manager)
1952    else:
1953      with default_graph._override_gradient_function(  # pylint: disable=protected-access
1954          {"PartitionedCall": self._get_gradient_function(),
1955           "StatefulPartitionedCall": self._get_gradient_function()}):
1956        flat_outputs = forward_function.call(ctx, args_with_tangents)
1957    forward_backward.record(flat_outputs)
1958    return self._build_call_outputs(flat_outputs)
1959
1960  def _experimental_with_cancellation_manager(self, cancellation_manager):
1961    """Returns a callable that invokes a cancellable version of this function.
1962
1963    Args:
1964      cancellation_manager: A `CancellationManager` object that can be used to
1965        cancel function invocation.
1966
1967    Returns:
1968      A callable with the same signature as this concrete function.
1969    """
1970
1971    def cancellable_call(*args, **kwargs):
1972      return self._call_impl(
1973          args, kwargs, cancellation_manager=cancellation_manager)
1974
1975    return cancellable_call
1976
1977  @property
1978  def name(self):
1979    """`ConcreteFunction` name."""
1980    return self._delayed_rewrite_functions.forward().name
1981
1982  @property
1983  def graph(self):
1984    """Returns the graph from which this function was constructed."""
1985    return self._func_graph
1986
1987  @property
1988  def inputs(self):
1989    """Returns tensors in `self.graph` corresponding to arguments."""
1990    return self._func_graph.inputs
1991
1992  @property
1993  def structured_input_signature(self):
1994    """Returns structured signature for this concrete function.
1995
1996    Returns:
1997      A tuple `(args, kwargs)`, where:
1998
1999        * `args` is a tuple that specifies the expected type or value each for
2000          positional argument.
2001        * `kwargs` is a dictionary that specifies the expected type or value
2002          for each keyword-only argument.
2003
2004      The type or value for each argument is specified using one of the
2005      following:
2006
2007        * A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native
2008          value is expected.
2009        * A Python value, such as an integer, indicating that an equal value
2010          is expected.
2011        * A nested structure of `tf.TypeSpec`s and Python values, indicating
2012          that a corresponding nested structure is expected.
2013    """
2014    return self._func_graph.structured_input_signature
2015
2016  @property
2017  def outputs(self):
2018    """Returns tensors in `self.graph` corresponding to returned tensors."""
2019    return self._func_graph.outputs
2020
2021  @property
2022  def structured_outputs(self):
2023    """Returns outputs in `self.graph` as returned by the original function."""
2024    return self._func_graph.structured_outputs
2025
2026  @property
2027  def captured_inputs(self):
2028    """Returns external Tensors captured by this function.
2029
2030    self.__call__(*args) passes `args + self.captured_inputs` to the function.
2031    """
2032    from_closures = nest.flatten([x() for x in self._captured_closures],
2033                                 expand_composites=True)
2034    return self._captured_inputs + from_closures
2035
2036  @property
2037  def function_def(self):
2038    """Returns a `FunctionDef` object representing this function."""
2039    return self._delayed_rewrite_functions.forward().definition
2040
2041  @property
2042  def output_shapes(self):
2043    """The function's output shapes."""
2044    return nest.map_structure(
2045        lambda x: getattr(x, "shape", tensor_shape.TensorShape(None)),
2046        composite_tensor.replace_composites_with_components(
2047            self._func_graph.structured_outputs),
2048        expand_composites=False)
2049
2050  @property
2051  def output_dtypes(self):
2052    # TODO(akshayka): Consider removing this.
2053    return nest.map_structure(
2054        lambda x: x.dtype if x is not None else None,
2055        composite_tensor.replace_composites_with_components(
2056            self._func_graph.structured_outputs),
2057        expand_composites=False)
2058
2059  def add_to_graph(self, g=None):
2060    """Registers the function, adds it to the graph g or default graph.
2061
2062    Args:
2063      g: If specified, registers the function with this graph. Defaults to the
2064        current context (either the default graph or the eager context).
2065    """
2066    # If we are not executing eagerly, adds the function to default graph if no
2067    # graph is specified.
2068    # In case of eager execution, function definition gets added to context
2069    # during construction itself.
2070
2071    if not context.executing_eagerly() and not g:
2072      g = ops.get_default_graph()
2073    self._delayed_rewrite_functions.forward().add_to_graph(g)
2074
2075  def add_gradient_functions_to_graph(self, g=None):
2076    """Add forward/backward functions to graph `g` or the current context."""
2077    if not context.executing_eagerly() and not g:
2078      g = ops.get_default_graph()
2079    self._delayed_rewrite_functions.forward().add_to_graph(g)
2080    forward_function, backward_function = (
2081        self._delayed_rewrite_functions.forward_backward())
2082    forward_function.add_to_graph(g)
2083    backward_function.add_to_graph(g)
2084
2085  def _get_gradient_function(self):
2086    """Returns gradient function. It will be lazily created at first call."""
2087    return self._delayed_rewrite_functions._rewrite_forward_and_call_backward  # pylint: disable=protected-access
2088
2089  def _select_forward_and_backward_functions(
2090      self, args, possible_gradient_type, executing_eagerly):
2091    """Selects forward and backward functions based on the calling context.
2092
2093    The forward function computes the "real" function outputs, `self._outputs`,
2094    and any extra values needed by the corresponding backward function.
2095
2096    Args:
2097      args: A flat list of Tensors with all of the inputs to the forward
2098        function (including user-specified and captured inputs).
2099      possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*.
2100      executing_eagerly: Boolean, the value of context.executing_eagerly().
2101
2102    Returns:
2103      An object with a `forward` method returning a tuple of (forward_function :
2104      _EagerDefinedFunction, augmented_arguments : List), and a corresponding
2105      `record` method which takes outputs from the forward function and records
2106      the operation. forward_function should be called with augmented_arguments.
2107    """
2108    if executing_eagerly:
2109      input_tangents = forwardprop_util.pack_tangents(args)
2110    else:
2111      input_tangents = forwardprop_util.TangentInfo()
2112    need_gradients_for_jvps = tape.should_record_backprop(
2113        input_tangents.tangents)
2114    # Allows re-use of forward and backward function pairs depending on the
2115    # tapes and forward accumulators watching its inputs.
2116    cache_key = (need_gradients_for_jvps, input_tangents.indices)
2117    if (possible_gradient_type
2118        == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER):
2119      if input_tangents.indices or executing_eagerly:
2120        # There is a single non-persistent tape active, so the user can only
2121        # request first-order gradients from a tape. We can spend less time
2122        # graph building since we know this.
2123        #
2124        # We may still end up computing higher-order gradients, but that'd be
2125        # through `tf.gradients`, which can re-write the forward pass and so
2126        # needs no preparation here.
2127        functions = self._first_order_tape_functions.get(cache_key, None)
2128        if functions is None:
2129          functions = _FirstOrderTapeGradientFunctions(
2130              self._func_graph, self._attrs, self._garbage_collector,
2131              forwardprop_input_indices=input_tangents.indices,
2132              delayed_rewrite_functions=self._delayed_rewrite_functions,
2133              need_gradients_for_jvps=need_gradients_for_jvps)
2134          self._first_order_tape_functions[cache_key] = functions
2135        return _ForwardBackwardCall(
2136            functions, args, input_tangents.tangents, tape_watching=True)
2137      else:
2138        # We can avoid computing second-order gradients in some cases by doing a
2139        # delayed rewrite when graph building. Since we know we'll only compute
2140        # first-order tape gradients, the delayed rewrite is safe: we won't need
2141        # to tell the tape about side outputs.
2142        #
2143        # TODO(allenl): This case is really dirty. It would be better if we
2144        # could temporarily pop all of the current tapes to avoid
2145        # accidentally taking second-order gradients.
2146        return _ForwardBackwardCall(
2147            self._delayed_rewrite_functions, args, input_tangents.tangents,
2148            tape_watching=True)
2149    elif (possible_gradient_type
2150          == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER):
2151      # Either there's a persistent tape watching, or there are multiple nested
2152      # tapes. Either way, the user may request higher-order gradients. We'll
2153      # spend a bit more time and make sure higher-order gradients are correct.
2154      functions = self._higher_order_tape_functions.get(
2155          cache_key, None)
2156      if functions is None:
2157        functions = _HigherOrderTapeGradientFunctions(
2158            self._func_graph, self._attrs, self._garbage_collector,
2159            forwardprop_input_indices=input_tangents.indices,
2160            delayed_rewrite_functions=self._delayed_rewrite_functions,
2161            need_gradients_for_jvps=need_gradients_for_jvps)
2162        self._higher_order_tape_functions[cache_key] = functions
2163      return _ForwardBackwardCall(functions, args, input_tangents.tangents,
2164                                  tape_watching=True)
2165    # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no
2166    # tape is recording.
2167    return _ForwardBackwardCall(
2168        self._delayed_rewrite_functions, args, input_tangents.tangents,
2169        tape_watching=False)
2170
2171  def _build_call_outputs(self, result):
2172    """Maps the fdef output list to actual output structure.
2173
2174    Args:
2175      result: Output lists defined by FunctionDef.
2176    Returns:
2177      The actual call output.
2178    """
2179    # TODO(jlchu): call C++ version in function.cc when speed is improved
2180    if self._func_graph.structured_outputs is None:
2181      return result
2182
2183    # Replace outputs with results, skipping over any 'None' values.
2184    outputs_list = nest.flatten(
2185        self._func_graph.structured_outputs, expand_composites=True)
2186    j = 0
2187    for i, o in enumerate(outputs_list):
2188      if o is not None:
2189        custom_gradient.copy_handle_data(self.outputs[j], result[j])
2190        outputs_list[i] = result[j]
2191        j += 1
2192    ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
2193                                outputs_list, expand_composites=True)
2194    return ret
2195
2196  @property
2197  def _as_name_attr_list(self):
2198    """Returns a `NameAttrList` representing this function."""
2199    ret = attr_value_pb2.NameAttrList(name=self.name)
2200    for name, value in self._attrs.items():
2201      ret.attr[name].CopyFrom(value)
2202    return ret
2203
2204  def _structured_signature_summary(self, default_values=False):
2205    """Returns a string summarizing this function's structured signature.
2206
2207    Args:
2208      default_values: If true, then include default values in the signature.
2209
2210    Returns:
2211      A `string`.
2212    """
2213    # Note: we can't just use self._funcion_spec.signature_summary(), because
2214    # that would show "_BOUND_VALUE" as the default value for all arguments.
2215    assert self._function_spec is not None
2216    arg_specs, kwarg_specs = self.structured_input_signature
2217    arg_names = list(self._function_spec.arg_names)
2218
2219    # If an explicit input_signature is provided to @tf.function, then any
2220    # arguments with defaults that are not covered by that explicit signature
2221    # are simply dropped from the signature.
2222    # TODO(b/159639913) Look into whether dropping arguments with default values
2223    # from the signature is the right thing to do.
2224    arg_names = arg_names[:len(arg_specs)]
2225
2226    if default_values:
2227      for i in range(len(arg_names)):
2228        if not _contains_type_spec(arg_specs[i]):
2229          arg_names[i] += "={}".format(arg_specs[i])
2230    if kwarg_specs:
2231      arg_names.append("*")
2232      for name, spec in kwarg_specs.items():
2233        arg_names.append(name)
2234        if default_values and not _contains_type_spec(spec):
2235          arg_names[-1] += "={}".format(spec)
2236    signature = "{}({})".format(self._func_graph.name, ", ".join(arg_names))
2237
2238    return signature
2239
2240  def _flat_signature_summary(self):
2241    """Returns a string summarizing this function's flat signature."""
2242    assert self._arg_keywords is not None
2243    assert self._num_positional_args is not None
2244    arg_names = self._arg_keywords
2245    if self._num_positional_args > len(arg_names):
2246      arg_names.extend(
2247          "<arg{}>".format(i + 1)
2248          for i in range(len(arg_names), self._num_positional_args))
2249    return "{}({})".format(self._func_graph.name, ", ".join(arg_names))
2250
2251  def pretty_printed_signature(self, verbose=True):
2252    """Returns a string summarizing the signature of this concrete function."""
2253    if not verbose:
2254      return self._structured_signature_summary(default_values=True)
2255
2256    def pretty_print_spec(spec):
2257      """Returns a string describing the spec for a single argument."""
2258      if isinstance(spec, tensor_spec.TensorSpec):
2259        return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape)
2260      elif nest.is_sequence(spec):
2261        pieces = nest.flatten(spec, expand_composites=False)
2262        markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))]
2263        structure = nest.pack_sequence_as(spec, markers)
2264        # Ensure dictionaries are sorted by key (for determinism)
2265        result = pprint.pformat(structure, width=10000)
2266        for (marker, piece) in zip(markers, pieces):
2267          result += "\n      {}: {}".format(marker, pretty_print_spec(piece))
2268        return result
2269      else:
2270        return repr(spec)
2271
2272    lines = [self._structured_signature_summary(default_values=True)]
2273    arg_specs, kwarg_specs = self.structured_input_signature
2274    names = list(self._function_spec.arg_names)
2275
2276    # If an explicit input_signature is provided to @tf.function, then any
2277    # arguments with defaults that are not covered by that explicit signature
2278    # are simply dropped from the signature.
2279    # TODO(b/159639913) Look into whether dropping arguments with default values
2280    # from the signature is the right thing to do.
2281    names = names[:len(arg_specs)]
2282
2283    names.extend(sorted(kwarg_specs))
2284    specs = list(arg_specs) + list(kwarg_specs.values())
2285    # note: we can skip bound args, since we already displayed thier bound
2286    # value in the signature summary.
2287    arg_details = []
2288    for (name, spec) in zip(names, specs):
2289      if _contains_type_spec(spec):
2290        arg_details.append("    {}: {}".format(name, pretty_print_spec(spec)))
2291    if arg_details:
2292      lines.append("  Args:")
2293      lines.extend(arg_details)
2294    lines.append("  Returns:")
2295
2296    def spec_from_value(value):
2297      # For loaded function, structured_outputs are already specs.
2298      if isinstance(value, type_spec.TypeSpec):
2299        return value
2300      return type_spec.type_spec_from_value(value)
2301
2302    lines.append("    {}".format(
2303        pretty_print_spec(
2304            nest.map_structure(spec_from_value, self.structured_outputs))))
2305
2306    return "\n".join(lines)
2307
2308  def __repr__(self):
2309    if self._function_spec is not None:
2310      return "<ConcreteFunction {} at 0x{:X}>".format(
2311          self.pretty_printed_signature(verbose=False), id(self))
2312    elif not (self._num_positional_args is None or self._arg_keywords is None):
2313      return "<ConcreteFunction {} at 0x{:X}>".format(
2314          self._flat_signature_summary(), id(self))
2315    else:
2316      return object.__repr__(self)
2317
2318  def __str__(self):
2319    if self._function_spec is not None:
2320      return "ConcreteFunction {}".format(self.pretty_printed_signature())
2321    else:
2322      return self.__repr__()
2323
2324
2325_pywrap_utils.RegisterType("Tensor", ops.Tensor)
2326_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)
2327_pywrap_utils.RegisterType("IndexedSlices", ops.IndexedSlices)
2328
2329
2330def _deterministic_dict_values(dictionary):
2331  return tuple(dictionary[key] for key in sorted(dictionary))
2332
2333
2334class FunctionSpec(object):
2335  """Specification of how to bind arguments to a function."""
2336
2337  @staticmethod
2338  def from_function_and_signature(python_function,
2339                                  input_signature,
2340                                  is_pure=False,
2341                                  experimental_follow_type_hints=False,
2342                                  jit_compile=None):
2343    """Create a FunctionSpec instance given a python function and signature.
2344
2345    Args:
2346      python_function: a function to inspect
2347      input_signature: a signature of the function (None, if variable)
2348      is_pure: if True all input arguments (including variables and constants)
2349      will be converted to tensors and no variable changes allowed.
2350      experimental_follow_type_hints: see `tf.function`
2351      jit_compile: see `tf.function`
2352
2353    Returns:
2354      instance of FunctionSpec
2355    """
2356    fullargspec = tf_inspect.getfullargspec(python_function)
2357    # Treat a wrapped partial function as a special case. For all arguments that
2358    # were overridden with keywords in the partial:
2359    #   - remove the corresponding arguments,
2360    #   - remove the corresponding keywords.
2361    _, unwrapped = tf_decorator.unwrap(python_function)
2362    # TODO(b/131153379): Consider Python3's fullargspec.kwonlyargs and
2363    # fullargspec.kwonlydefaults.
2364    if isinstance(unwrapped, functools.partial):
2365      # Also consider the Python3 case with kwonlydefaults.
2366      if fullargspec.defaults or fullargspec.kwonlydefaults:
2367        new_defaults = fullargspec.defaults
2368        new_args = fullargspec.args
2369        if fullargspec.defaults:
2370          # To be able to canonicalize the function properly, we want to ignore
2371          # default values that are overridden via a partial kwarg. For example:
2372          #
2373          #   def func(a, b, c, d=5, e=7):
2374          #     return a, b, c, d, e
2375          #   p_func = functools.partial(tf.function(func, 10, e=9))
2376          #
2377          # Here we want to drop from the defaults the parameter `e`. If we
2378          # forwarded the call to the partial function with a default for `e`
2379          # we would get an error for passing two values for one parameter.
2380          #
2381          # Note that this has a limitation: we can only override parameters at
2382          # the end of the parameter list.
2383          #
2384          # In this case we want to end up with 3 arguments (b, c, d) and 1
2385          # default value (5). We do this by constructing a mask where 0 stands
2386          # for a value that was overridden by a partial kwarg. The seemingly
2387          # complicated logic below does just that - for arguments (b, c, d, e)
2388          # we would get a mask (1, 1, 1, 0).
2389          old_args = fullargspec.args
2390          old_defaults = fullargspec.defaults
2391
2392          no_default = object()
2393          num_args_without_defaults = len(old_args) - len(old_defaults)
2394          left_padding = tuple([no_default] * num_args_without_defaults)
2395
2396          args_with_defaults = zip(old_args, left_padding + old_defaults)
2397
2398          # Create a mask where 0 stands for args that had a partial kwarg
2399          # defined.
2400          non_keyword_defaults_mask = [
2401              0 if key in unwrapped.keywords else 1 for key in old_args
2402          ]
2403          # Keep only arguments and defaults that were not kwargs of partial.
2404          new_args_with_defaults = list(
2405              itertools.compress(args_with_defaults, non_keyword_defaults_mask))
2406          # Keep all args.
2407          new_args = [arg for arg, _ in new_args_with_defaults]
2408          # Keep only real default values.
2409          new_defaults = [
2410              default for _, default in new_args_with_defaults
2411              if default is not no_default
2412          ]
2413        fullargspec = tf_inspect.FullArgSpec(
2414            args=new_args,
2415            varargs=fullargspec.varargs,
2416            varkw=fullargspec.varkw,
2417            defaults=new_defaults,
2418            kwonlyargs=[],
2419            kwonlydefaults={},
2420            annotations=fullargspec.annotations)
2421
2422      # inspect.ismethod() and inspect.isfunction() both return False on a
2423      # functools.partial-wrapped function. We set it to False to
2424      # maintain consistency with prior versions.
2425      is_method = False
2426
2427    else:
2428      # Instead of using tf_inspect.ismethod() which only checks the
2429      # final unwrapped target, we check if any decorated target along the chain
2430      # is a method.
2431      is_method = tf_inspect.isanytargetmethod(python_function)
2432
2433      # In the following scenario, 'python_function' is a callable object.
2434      # python_function(...) is equal to python_function.__call__(self, ...)
2435      if not is_method and not tf_inspect.isfunction(
2436          python_function) and hasattr(
2437              python_function, "__class__") and hasattr(
2438                  python_function.__class__, "__call__"):
2439        is_method = True
2440
2441    # Get the function's name.  Remove functools.partial wrappers if necessary.
2442    while isinstance(python_function, functools.partial):
2443      python_function = python_function.func
2444    name = getattr(python_function, "__name__", "f")
2445
2446    return FunctionSpec(
2447        fullargspec,
2448        is_method,
2449        input_signature,
2450        is_pure=is_pure,
2451        jit_compile=jit_compile,
2452        experimental_follow_type_hints=experimental_follow_type_hints,
2453        name=name)
2454
2455  def __init__(self,
2456               fullargspec,
2457               is_method,
2458               input_signature,
2459               is_pure=False,
2460               experimental_follow_type_hints=False,
2461               name=None,
2462               jit_compile=None):
2463    """Constructs a FunctionSpec describing a python function.
2464
2465    Args:
2466      fullargspec: `tf_inspect.FullArgSpec` object describing the function.
2467      is_method: True if the function is a method.
2468      input_signature: a signature of the function (None, if variable)
2469      is_pure: if True all input arguments (including variables and constants)
2470        will be converted to tensors and no variable changes allowed.
2471      experimental_follow_type_hints: see `tf.function`.
2472      name: Name of the function
2473      jit_compile: see `tf.function`.
2474    """
2475    self._fullargspec = fullargspec
2476    self._is_method = is_method
2477    self._is_pure = is_pure
2478    self._jit_compile = jit_compile
2479    self._experimental_follow_type_hints = experimental_follow_type_hints
2480
2481    # TODO(edloper): Include name when serializing for SavedModel?
2482    self._name = name or "f"
2483
2484    if self._is_method:
2485      # Remove `self`: default arguments shouldn't be matched to it.
2486      # TODO(b/127938157): Should this error out if there is no arg to
2487      # be removed?
2488      args = fullargspec.args[1:]
2489    else:
2490      args = fullargspec.args
2491
2492    # A cache mapping from argument name to index, for canonicalizing
2493    # arguments that are called in a keyword-like fashion.
2494    self._args_to_indices = {arg: i for i, arg in enumerate(args)}
2495    self._arg_names = args
2496
2497    # A cache mapping from arg index to default value, for canonicalization.
2498    default_values = fullargspec.defaults
2499    offset = len(args) - len(default_values or [])
2500    self._arg_indices_to_default_values = {
2501        offset + index: default
2502        for index, default in enumerate(default_values or [])
2503    }
2504    self._arg_indices_no_default_values = set(range(len(args))) - set(
2505        self._arg_indices_to_default_values)
2506    if input_signature is None:
2507      self._input_signature = None
2508    else:
2509      if set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ()):
2510        raise ValueError("Cannot define a TensorFlow function from a Python "
2511                         "function with keyword-only arguments when "
2512                         "input_signature is provided.")
2513
2514      if not isinstance(input_signature, (tuple, list)):
2515        raise TypeError("input_signature must be either a tuple or a "
2516                        "list, received " + str(type(input_signature)))
2517
2518      self._input_signature = tuple(input_signature)
2519      self._flat_input_signature = tuple(nest.flatten(input_signature,
2520                                                      expand_composites=True))
2521
2522  @property
2523  def fullargspec(self):
2524    return self._fullargspec
2525
2526  @property
2527  def is_method(self):
2528    return self._is_method
2529
2530  @property
2531  def args_to_indices(self):
2532    return self._args_to_indices
2533
2534  @property
2535  def kwargs_to_include(self):
2536    return self._kwargs_to_include
2537
2538  @property
2539  def input_signature(self):
2540    return self._input_signature
2541
2542  @property
2543  def flat_input_signature(self):
2544    return self._flat_input_signature
2545
2546  @property
2547  def is_pure(self):
2548    return self._is_pure
2549
2550  @property
2551  def jit_compile(self):
2552    return self._jit_compile
2553
2554  @property
2555  def arg_names(self):
2556    return self._arg_names
2557
2558  @property
2559  def vararg_name(self):
2560    return self._fullargspec.varargs
2561
2562  @property
2563  def varkw_name(self):
2564    return self._fullargspec.varkw
2565
2566  def signature_summary(self, default_values=False):
2567    """Returns a string summarizing this function's signature.
2568
2569    Args:
2570      default_values: If true, then include default values in the signature.
2571
2572    Returns:
2573      A `string`.
2574    """
2575    args = list(self._arg_names)
2576    if default_values:
2577      for (i, default) in self._arg_indices_to_default_values.items():
2578        args[i] += "={}".format(default)
2579    if self._fullargspec.kwonlyargs:
2580      args.append("*")
2581      for arg_name in self._fullargspec.kwonlyargs:
2582        args.append(arg_name)
2583        if default_values and arg_name in self._fullargspec.kwonlydefaults:
2584          args[-1] += "={}".format(self._fullargspec.kwonlydefaults[arg_name])
2585    return "{}({})".format(self._name, ", ".join(args))
2586
2587  def _to_tensor_or_tensor_spec(self, x):
2588    return (x if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec))
2589            else ops.convert_to_tensor(x))
2590
2591  def _convert_variables_to_tensors(self, args, kwargs):
2592    args = [self._to_tensor_or_tensor_spec(x) for x in args]
2593    kwargs = {kw: self._to_tensor_or_tensor_spec(x)
2594              for kw, x in kwargs.items()}
2595    return tuple(args), kwargs
2596
2597  def _convert_annotated_args_to_tensors(self, args, kwargs):
2598    """Attempts to autobox arguments annotated as tf.Tensor."""
2599    if self.input_signature is not None:
2600      return
2601
2602    args = list(args)
2603    for i, arg in enumerate(args):
2604      # See
2605      # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
2606      if i < len(self._fullargspec.args):
2607        annotation_key = self._fullargspec.args[i]
2608      else:
2609        annotation_key = self._fullargspec.varargs
2610      arg_annotation = self._fullargspec.annotations.get(annotation_key, None)
2611
2612      # TODO(rahulkamat): Change to TensorLike (here ans below)
2613      if arg_annotation == ops.Tensor:
2614        args[i] = self._to_tensor_or_tensor_spec(arg)
2615
2616    for kw, v in kwargs.items():
2617      if kw in self._fullargspec.kwonlyargs or kw in self._fullargspec.args:
2618        annotation_key = kw
2619      else:
2620        annotation_key = self._fullargspec.varkw
2621      kwarg_annotation = self._fullargspec.annotations.get(annotation_key, None)
2622      if kwarg_annotation == ops.Tensor:
2623        kwargs[kw] = self._to_tensor_or_tensor_spec(v)
2624    return tuple(args), kwargs
2625
2626  def canonicalize_function_inputs(self, *args, **kwargs):
2627    """Canonicalizes `args` and `kwargs`.
2628
2629    Canonicalize the inputs to the Python function using a `FunctionSpec`
2630    instance. In particular, we parse the varargs and kwargs that the
2631    original function was called with into a tuple corresponding to the
2632    Python function's positional (named) arguments and a dictionary
2633    corresponding to its kwargs.  Missing default arguments are added.
2634
2635    If this `FunctionSpec` has an input signature, then it is used to convert
2636    arguments to tensors; otherwise, any inputs containing numpy arrays are
2637    converted to tensors.
2638
2639    Additionally, any inputs containing numpy arrays are converted to Tensors.
2640
2641    Args:
2642      *args: The varargs this object was called with.
2643      **kwargs: The keyword args this function was called with.
2644
2645    Returns:
2646      A canonicalized ordering of the inputs, as well as full and filtered
2647      (Tensors and Variables only) versions of their concatenated flattened
2648      representations, represented by a tuple in the form (args, kwargs,
2649      flat_args, filtered_flat_args). Here: `args` is a full list of bound
2650      arguments, and `kwargs` contains only true keyword arguments, as opposed
2651      to named arguments called in a keyword-like fashion.
2652
2653    Raises:
2654      ValueError: If a keyword in `kwargs` cannot be matched with a positional
2655        argument when an input signature is specified, or when the inputs
2656        do not conform to the input signature.
2657    """
2658    if self._is_pure:
2659      args, kwargs = self._convert_variables_to_tensors(args, kwargs)
2660    if self._experimental_follow_type_hints:
2661      args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
2662    # Pre-calculate to reduce overhead
2663    arglen = len(args)
2664    if self._input_signature is not None:
2665      if arglen > len(self._input_signature):
2666        raise TypeError("{} takes {} positional arguments (as specified by the "
2667                        "input_signature) but {} were given".format(
2668                            self.signature_summary(),
2669                            len(self._input_signature), arglen))
2670      for arg in six.iterkeys(kwargs):
2671        index = self._args_to_indices.get(arg, None)
2672        if index is None:
2673          raise TypeError("{} got unexpected keyword argument `{}`".format(
2674              self.signature_summary(), arg))
2675        if index >= len(self._input_signature):
2676          raise TypeError(
2677              "{} got keyword argument `{}` that was not included in "
2678              "input_signature".format(self.signature_summary(), arg))
2679
2680    if not kwargs:
2681      inputs = args
2682      if self._arg_indices_to_default_values:
2683        try:
2684          inputs += tuple(self._arg_indices_to_default_values[i]
2685                          for i in range(arglen, len(self._arg_names)))
2686        except KeyError:
2687          missing_args = [
2688              self._arg_names[i]
2689              for i in range(arglen, len(self._arg_names))
2690              if i not in self._arg_indices_to_default_values
2691          ]
2692          raise TypeError("{} missing required arguments: {}".format(
2693              self.signature_summary(), ", ".join(missing_args)))
2694
2695      if self._fullargspec.kwonlydefaults:
2696        kwargs.update(self._fullargspec.kwonlydefaults)
2697    else:
2698      # Maps from index of arg to its corresponding value, according to `args`
2699      # and `kwargs`; seeded with the default values for the named args that
2700      # aren't in `args`.
2701      arg_indices_to_values = {
2702          index: default for index, default in six.iteritems(
2703              self._arg_indices_to_default_values) if index >= arglen
2704      }
2705      consumed_args = []
2706      missing_arg_indices = self._arg_indices_no_default_values - set(
2707          range(arglen))
2708      for arg, value in six.iteritems(kwargs):
2709        index = self._args_to_indices.get(arg, None)
2710        if index is not None:
2711          if index < arglen:
2712            raise TypeError("{} got two values for argument '{}'".format(
2713                self.signature_summary(), arg))
2714          arg_indices_to_values[index] = value
2715          # These arguments in 'kwargs' might also belong to
2716          # positional arguments
2717          missing_arg_indices.discard(index)
2718          consumed_args.append(arg)
2719      for arg in consumed_args:
2720        # After this loop, `kwargs` will only contain keyword_only arguments,
2721        # and all positional_or_keyword arguments have been moved to `inputs`.
2722        kwargs.pop(arg)
2723      inputs = args + _deterministic_dict_values(arg_indices_to_values)
2724      # Exclude positional args with values
2725      if missing_arg_indices:
2726        missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
2727        if len(missing_args) == 1:
2728          raise TypeError("{} missing 1 required argument: {}".format(
2729              self.signature_summary(), missing_args[0]))
2730        else:
2731          raise TypeError("{} missing required arguments: {}".format(
2732              self.signature_summary(), ", ".join(missing_args)))
2733
2734      if kwargs and self._input_signature is not None:
2735        raise TypeError(
2736            "{} got unexpected keyword arguments: {}\n(Cannot define a "
2737            "TensorFlow function from a Python function with keyword arguments "
2738            "when input_signature is provided.)".format(
2739                self.signature_summary(), ", ".join(kwargs)))
2740
2741      if self._fullargspec.kwonlydefaults:
2742        for (kwarg, default) in self._fullargspec.kwonlydefaults.items():
2743          kwargs.setdefault(kwarg, default)
2744
2745    if self._input_signature is None:
2746      inputs, flat_inputs, filtered_flat_inputs = _convert_numpy_inputs(inputs)
2747      kwargs, flat_kwargs, filtered_flat_kwargs = _convert_numpy_inputs(kwargs)
2748      return (inputs, kwargs, flat_inputs + flat_kwargs,
2749              filtered_flat_inputs + filtered_flat_kwargs)
2750    else:
2751      assert not kwargs
2752      inputs, flat_inputs, filtered_flat_inputs = _convert_inputs_to_signature(
2753          inputs, self._input_signature, self._flat_input_signature)
2754      return inputs, {}, flat_inputs, filtered_flat_inputs
2755
2756
2757def _as_ndarray(value):
2758  """Converts value to an ndarray, assumes _is_ndarray(value)."""
2759  # TODO(tomhennigan) Support __array_interface__ too.
2760  return value.__array__()
2761
2762
2763def _is_ndarray(value):
2764  """Tests whether the given value is an ndarray (and not a TF tensor/var)."""
2765  # TODO(tomhennigan) Support __array_interface__ too.
2766  return hasattr(value, "__array__") and not (
2767      isinstance(value, ops.Tensor)
2768      or isinstance(value, resource_variable_ops.BaseResourceVariable)
2769      or hasattr(value, "_should_act_as_resource_variable")
2770
2771      # For legacy reasons we do not automatically promote Numpy strings.
2772      or isinstance(value, np.str_)
2773      # NumPy dtypes have __array__ as unbound methods.
2774      or isinstance(value, type)
2775      # CompositeTensors should be flattened instead.
2776      or isinstance(value, composite_tensor.CompositeTensor))
2777
2778
2779def _convert_numpy_inputs(inputs):
2780  """Convert numpy array inputs to tensors."""
2781  # We assume that any CompositeTensors have already converted their components
2782  # from numpy arrays to Tensors, so we don't need to expand composites here for
2783  # the numpy array conversion. Instead, we do so because the flattened inputs
2784  # are eventually passed to ConcreteFunction()._call_flat, which requires
2785  # expanded composites.
2786  flat_inputs = nest.flatten(inputs, expand_composites=True)
2787
2788  # Check for NumPy arrays in arguments and convert them to Tensors.
2789  # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
2790  # finding a way to store them directly in the cache key (currently not
2791  # possible since ndarrays are not hashable).
2792  need_packing = False
2793  filtered_flat_inputs = []
2794  for index, value in enumerate(flat_inputs):
2795    if isinstance(value,
2796                  (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
2797      filtered_flat_inputs.append(value)
2798    elif hasattr(value, "__array__") and not (
2799        hasattr(value, "_should_act_as_resource_variable") or
2800        isinstance(value, (np.str_, type, composite_tensor.CompositeTensor))):
2801      # This case is equivalent to _is_ndarray(value) == True
2802      a = _as_ndarray(value)
2803      if not isinstance(a, np.ndarray):
2804        raise TypeError("The output of __array__ must be an np.ndarray "
2805                        "(got {} from {}).".format(type(a), type(value)))
2806      flat_inputs[index] = constant_op.constant(a)
2807      filtered_flat_inputs.append(flat_inputs[index])
2808      need_packing = True
2809  if need_packing:
2810    return (nest.pack_sequence_as(
2811        structure=inputs, flat_sequence=flat_inputs,
2812        expand_composites=True), flat_inputs, filtered_flat_inputs)
2813  else:
2814    return inputs, flat_inputs, filtered_flat_inputs
2815
2816
2817def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
2818  """Convert inputs to pass into a function with an explicit signature."""
2819
2820  def format_error_message(inputs, input_signature):
2821    return ("  inputs: (\n" + "    " + ",\n    ".join(str(i) for i in inputs) +
2822            ")\n" + "  input_signature: (\n" + "    " +
2823            ",\n    ".join(str(i) for i in input_signature) + ")")
2824
2825  try:
2826    flatten_inputs = nest.flatten_up_to(
2827        input_signature,
2828        inputs[:len(input_signature)],
2829        expand_composites=True,
2830        check_types=False)  # lists are convert to tuples for `tf.data`.
2831  except ValueError:
2832    raise ValueError("Structure of Python function inputs does not match "
2833                     "input_signature:\n%s" %
2834                     format_error_message(inputs, input_signature))
2835
2836  need_packing = False
2837  for index, (value, spec) in enumerate(zip(flatten_inputs,
2838                                            flat_input_signature)):
2839    if (isinstance(spec, tensor_spec.TensorSpec) and
2840        not _pywrap_utils.IsTensor(value)):
2841      try:
2842        flatten_inputs[index] = ops.convert_to_tensor(
2843            value, dtype_hint=spec.dtype)
2844        need_packing = True
2845      except ValueError:
2846        raise ValueError("When input_signature is provided, all inputs to "
2847                         "the Python function must be convertible to "
2848                         "tensors:\n%s" %
2849                         format_error_message(inputs, input_signature))
2850
2851  if any(not spec.is_compatible_with(other) for spec, other in zip(
2852      flat_input_signature,
2853      flatten_inputs)):
2854    raise ValueError("Python inputs incompatible with input_signature:\n%s" %
2855                     format_error_message(inputs, input_signature))
2856
2857  if need_packing:
2858    inputs = nest.pack_sequence_as(
2859        structure=input_signature,
2860        flat_sequence=flatten_inputs,
2861        expand_composites=True)
2862
2863  flat_inputs = nest.flatten(inputs, expand_composites=True)
2864
2865  return (inputs, flat_inputs, [
2866      t for t in flat_inputs
2867      if isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
2868  ])
2869
2870
2871class FunctionCache(object):
2872  """A lightweight container for cached functions.
2873  """
2874
2875  __slots__ = [
2876      "missed", "primary", "arg_relaxed_specs", "arg_relaxed",
2877      "_garbage_collectors"
2878  ]
2879
2880  def __init__(self):
2881    # The set of functions that have been missed; entries are CacheKey with
2882    # input_signature `None` (e.g. a "call context key")
2883    self.missed = set()
2884    # The primary cache, mapping a fully shaped CacheKey to a function.
2885    self.primary = collections.OrderedDict()
2886    # A cache key lookup, mapping a CacheKey generated without shape info to a
2887    # flat list of `TypeSpec`s with relaxed shapes (one for each flattened
2888    # argument). Arguments that are not Tensors or `CompositeTensor`s contain a
2889    # `None` for the corresponding relaxed spec.
2890    self.arg_relaxed_specs = collections.OrderedDict()
2891    # The secondary cache, mapping a CacheKey generated without shape info to a
2892    # function.
2893    self.arg_relaxed = collections.OrderedDict()
2894    # All OrderedDicts require manual garbage collection.
2895    self._garbage_collectors = [
2896        _FunctionGarbageCollector(self.primary),
2897        _FunctionGarbageCollector(self.arg_relaxed),
2898        _FunctionGarbageCollector(self.arg_relaxed_specs)]
2899
2900  def all_values(self):
2901    """A list of all `ConcreteFunction` instances held by this cache."""
2902    # We need to simultaneously make sure our returned concrete functions are
2903    # unique *and* make sure they are returned in a deterministic order for
2904    # serialization.
2905    #
2906    # TODO(b/174215821): It's likely that we ultimately would just prefer to
2907    # choose the most specific concrete function shape given a set of
2908    # arguments. If and when that is implemented, this logic can be revisited.
2909    primary_functions = set(self.primary.values())
2910    return list(self.primary.values()) + [
2911        v for v in self.arg_relaxed.values() if v not in primary_functions]
2912
2913
2914class Function(object):
2915  """Wrapper class for the graph functions defined for a Python function.
2916
2917  See the documentation for `defun` for more information on the semantics of
2918  defined functions.
2919
2920  `Function` class is thread-compatible meaning that minimal usage of defuns
2921  (defining and calling) is thread-safe, but if users call other methods or
2922  invoke the base `python_function` themselves, external synchronization is
2923  necessary.
2924  In addition, Function is not reentrant, so recursive functions need to call
2925  the wrapped function, not the wrapper.
2926  """
2927
2928  def __init__(self,
2929               python_function,
2930               name,
2931               input_signature=None,
2932               attributes=None,
2933               autograph=True,
2934               autograph_options=None,
2935               experimental_relax_shapes=False,
2936               capture_by_value=None,
2937               jit_compile=None,
2938               experimental_follow_type_hints=False):
2939    """Initializes a `Function`.
2940
2941    Args:
2942      python_function: the function to be wrapped.
2943      name: the name given to it.
2944      input_signature: a possibly nested sequence of `TensorSpec` objects
2945        specifying the input signature of this function. If `None`, a separate
2946        function is instantiated for each inferred input signature.
2947      attributes: dict, extra keyword arguments that will be added as attribute
2948        of the function.
2949      autograph: whether to use autograph to compile
2950        `python_function`. See https://www.tensorflow.org/guide/autograph for
2951        more information.
2952      autograph_options: Experimental knobs to control behavior
2953        `when autograph=True`. See https://www.tensorflow.org/guide/autograph
2954        for more information.
2955      experimental_relax_shapes: When true, argument shapes may be relaxed to
2956        avoid unnecessary retracing.
2957      capture_by_value: Experimental. Whether to capture resource variables by
2958        value or reference. If None, will inherit from a parent context or
2959        default to False.
2960      jit_compile: Force-compile the function with XLA, cf.
2961        def_function.Function doc on jit_compile.
2962      experimental_follow_type_hints: See the documentation for `tf.function`.
2963
2964    Raises:
2965      ValueError: if `input_signature` is not None and the `python_function`'s
2966        argspec has keyword arguments.
2967    """
2968    self._python_function = python_function
2969    pure_function = attributes and IMPLEMENTS_ATTRIBUTE_NAME in attributes
2970    self._function_spec = FunctionSpec.from_function_and_signature(
2971        python_function,
2972        input_signature,
2973        is_pure=pure_function,
2974        experimental_follow_type_hints=experimental_follow_type_hints)
2975    self._name = name
2976    self._autograph = autograph
2977    self._autograph_options = autograph_options
2978    self._experimental_relax_shapes = experimental_relax_shapes
2979    self._function_cache = FunctionCache()
2980    self._function_attributes = attributes or {}
2981    self._capture_by_value = capture_by_value
2982    self.tracing_count = 0
2983    if self.input_signature is not None:
2984      self._hashable_input_signature = _make_input_signature_hashable(
2985          self.flat_input_signature)
2986
2987    self._lock = threading.Lock()
2988    # _descriptor_cache is a of instance of a class to an instance-specific
2989    # `Function`, used to make sure defun-decorated methods create different
2990    # functions for each instance.
2991    self._descriptor_cache = weakref.WeakKeyDictionary()
2992    self._jit_compile = jit_compile
2993    self._experimental_follow_type_hints = experimental_follow_type_hints
2994
2995  def __call__(self, *args, **kwargs):
2996    """Calls a graph function specialized to the inputs."""
2997    with self._lock:
2998      (graph_function,
2999       filtered_flat_args) = self._maybe_define_function(args, kwargs)
3000    return graph_function._call_flat(
3001        filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
3002
3003  @property
3004  def python_function(self):
3005    """Returns the wrapped Python function."""
3006    return self._python_function  # pylint: disable=protected-access
3007
3008  @property
3009  def function_spec(self):
3010    return self._function_spec
3011
3012  @property
3013  def input_signature(self):
3014    """Returns the input signature."""
3015    return self._function_spec.input_signature
3016
3017  @property
3018  def flat_input_signature(self):
3019    """Returns the flattened input signature."""
3020    return self._function_spec.flat_input_signature
3021
3022  def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
3023    """Returns a concrete function which cleans up its graph function."""
3024    if self.input_signature:
3025      args, kwargs = None, None
3026    with self._lock:
3027      graph_function, _ = self._maybe_define_function(args, kwargs)
3028    return graph_function
3029
3030  def _get_concrete_function_internal(self, *args, **kwargs):
3031    """Bypasses error checking when getting a graph function."""
3032    graph_function = self._get_concrete_function_internal_garbage_collected(
3033        *args, **kwargs)
3034    # We're returning this concrete function to someone, and they may keep a
3035    # reference to the FuncGraph without keeping a reference to the
3036    # ConcreteFunction object. So we won't clean up the reference cycles
3037    # manually and instead will leave them to Python's garbage collector.
3038    graph_function._garbage_collector.release()  # pylint: disable=protected-access
3039    return graph_function
3040
3041  def _get_concrete_function_garbage_collected(self, *args, **kwargs):
3042    """Returns a `ConcreteFunction` specialized to inputs and execution context.
3043
3044    Unlike `get_concrete_function(...)`, the graph will be deleted when the
3045    returned function is deleted.  It's useful to avoid creating a reference
3046    cycle when you know for sure that the graph will be no longer used without
3047    the returned function.
3048
3049    Args:
3050      *args: inputs to specialize on.
3051      **kwargs: inputs to specialize on.
3052    """
3053    if self.input_signature:
3054      if kwargs:
3055        raise ValueError("Cannot define a TensorFlow function from a Python "
3056                         "function with keyword arguments when "
3057                         "input_signature is provided.")
3058      if args:
3059        # If args are provided, they must match the input signature.
3060        if not is_same_structure(self.input_signature, args):
3061          raise ValueError("Structure of Python function inputs does not match "
3062                           "input_signature.")
3063        flat_inputs = nest.flatten(args, expand_composites=True)
3064        if any(not isinstance(arg, (ops.Tensor, tensor_spec.DenseSpec,
3065                                    resource_variable_ops.BaseResourceVariable))
3066               for arg in flat_inputs):
3067          raise ValueError("When input_signature is provided, all inputs to "
3068                           "the Python function must be Tensors, Variables, "
3069                           "tf.TensorSpec or tf.VariableSpec objects.")
3070        if any(not spec.is_compatible_with(other)
3071               for spec, other in zip(self.flat_input_signature, flat_inputs)):
3072          raise ValueError("Python inputs incompatible with input_signature: "
3073                           "inputs (%s), input_signature (%s)" %
3074                           (str(args), str(self.input_signature)))
3075      args, kwargs = None, None
3076    with self._lock:
3077      graph_function, _ = self._maybe_define_function(args, kwargs)
3078      seen_names = set()
3079      captured = object_identity.ObjectIdentitySet(
3080          graph_function.graph.internal_captures)
3081      # pylint: disable=protected-access
3082      graph_function._arg_keywords = []
3083      prefix_counts = {}
3084      # pylint: enable=protected-access
3085      num_positional = 0
3086      for arg in graph_function.graph.inputs:
3087        if arg in captured:
3088          break
3089        num_positional += 1
3090        user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
3091        proposal = user_arg_name
3092        while proposal in seen_names:
3093          index = prefix_counts.get(user_arg_name, 1)
3094          proposal = "{}_{}".format(user_arg_name, index)
3095          prefix_counts[user_arg_name] = index + 1
3096        seen_names.add(proposal)
3097        graph_function._arg_keywords.append(proposal)  # pylint: disable=protected-access
3098      # Anything can be a positional argument, in the same order as .inputs
3099      graph_function._num_positional_args = num_positional  # pylint: disable=protected-access
3100      return graph_function
3101
3102  def get_concrete_function(self, *args, **kwargs):
3103    """Returns a `ConcreteFunction` specialized to inputs and execution context.
3104
3105    Args:
3106      *args: inputs to specialize on. Can be concrete values (e.g. 1)
3107         or `tf.Tensor` or `tf.TensorSpec`.
3108      **kwargs: keyword inputs to specialize on. Concrete values (e.g. 1)
3109         or `tf.Tensor` or `tf.TensorSpec`.
3110    """
3111    graph_function = self._get_concrete_function_garbage_collected(
3112        *args, **kwargs)
3113    graph_function._garbage_collector.release()  # pylint: disable=protected-access
3114    return graph_function
3115
3116  def __get__(self, instance, owner):
3117    """Makes it possible to defun instance methods."""
3118    del owner
3119    # `instance` here is the instance that this `Function` was accessed through
3120    # e.g., for
3121    #
3122    #   class Foo(object):
3123    #
3124    #     @function.defun
3125    #     def bar(self):
3126    #       ...
3127    #
3128    #   foo = Foo()
3129    #   foo.bar()  # `foo.bar` is a `Function` instance
3130    #
3131    # then `instance` will be `foo` (and `owner` will be `Foo`).  We create a
3132    # new instance of `Function` here to allow different instances each
3133    # to create variables once, thereby allowing methods to be decorated with
3134    # defun. Keeps a cache to avoid retracing the function every time the
3135    # descriptor is accessed.
3136    if instance not in self._descriptor_cache:
3137      if instance is None:
3138        return self
3139      # If there is no instance-specific `Function` in the cache, we construct
3140      # an instance-specific `Function` that uses a weak reference to the
3141      # instance (so that the instance will be correctly gc'd).
3142
3143      # And finally add the wrapped function to the description cache
3144      self._descriptor_cache[instance] = class_method_to_instance_method(
3145          self, instance)
3146
3147    # Return the cached `Function` for the instance
3148    return self._descriptor_cache[instance]
3149
3150  def _cache_key(self,
3151                 args,
3152                 kwargs,
3153                 cache_key_context,
3154                 include_tensor_ranks_only=False):
3155    """Computes the cache key given inputs and execution context."""
3156    if self.input_signature is None:
3157      inputs = (args, kwargs) if kwargs else args
3158      input_signature = pywrap_tfe.TFE_Py_EncodeArg(inputs,
3159                                                    include_tensor_ranks_only)
3160      hashable_input_signature = _make_input_signature_hashable(input_signature)
3161    else:
3162      del args, kwargs
3163      assert not include_tensor_ranks_only
3164      hashable_input_signature = self._hashable_input_signature
3165
3166    (parent_graph, device_functions, colocation_stack, in_cross_replica_context,
3167     variable_policy, xla_context_id) = cache_key_context
3168
3169    return CacheKey(hashable_input_signature, parent_graph, device_functions,
3170                    colocation_stack, in_cross_replica_context, variable_policy,
3171                    xla_context_id)
3172
3173  def _cache_key_context(self):
3174    """Returns execution context."""
3175    ctx = context.context()
3176
3177    # Don't need to open an init_scope if the _cache_key call is in eager mode
3178    # already.
3179    executing_eagerly = ctx.executing_eagerly()
3180    parent_graph = None
3181    xla_context_id = 0
3182    if not executing_eagerly:
3183      # We want to force function retracing for each different
3184      # XLAControlFlowContext, so add `xla_context_id` to the cache key.
3185      xla_context = _enclosing_xla_context()
3186      if xla_context is not None and \
3187            xla_context.RequiresUniqueFunctionRetracing():
3188        xla_context_id = id(xla_context)
3189
3190      with ops.init_scope():
3191        # The graph, or whether we're executing eagerly, should be a part of the
3192        # cache key so we don't improperly capture tensors such as variables.
3193        executing_eagerly = ctx.executing_eagerly()
3194        parent_graph = None if executing_eagerly else ops.get_default_graph()
3195
3196    # pylint: disable=protected-access
3197    default_graph = ops.get_default_graph()
3198    # TODO(b/117617952): The current distribution strategy will affect graph
3199    # building (e.g. accessing different variables from different devices) and
3200    # so requires retracing for each device.
3201    strategy_stack = default_graph._distribution_strategy_stack
3202    uses_distribution_strategy = (
3203        strategy_stack and
3204        strategy_stack[-1].strategy.extended._retrace_functions_for_each_device
3205    )
3206    if executing_eagerly:
3207      colocation_stack = ()
3208      if uses_distribution_strategy:
3209        device_functions = (pydev.merge_device(ctx.device_name),)
3210      else:
3211        device_functions = ()
3212    else:
3213      colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
3214      if (uses_distribution_strategy
3215          or func_graph_module.device_stack_has_callable(
3216              default_graph._device_function_stack)):
3217        # Putting the device in the cache key ensures that call-site device
3218        # annotations are respected.
3219        device_functions = tuple(default_graph._device_functions_outer_to_inner)
3220      else:
3221        device_functions = ()
3222
3223    in_cross_replica_context = False
3224    try:
3225      in_cross_replica_context = (strategy_stack[-1].replica_context is None)  # pylint: disable=protected-access
3226    except (AttributeError, IndexError):
3227      pass
3228
3229    if save_context.in_save_context():
3230      variable_policy = (
3231          save_context.get_save_options().experimental_variable_policy)
3232    else:
3233      variable_policy = None
3234
3235    return (parent_graph, device_functions, colocation_stack,
3236            in_cross_replica_context, variable_policy, xla_context_id)
3237
3238  def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
3239    """Create a `ConcreteFunction` from `args` and `kwargs`."""
3240    self.tracing_count += 1
3241
3242    if self.input_signature is None:
3243      arglen = len(args)
3244    else:
3245      arglen = len(self.input_signature)
3246    base_arg_names = self._function_spec.arg_names[:arglen]
3247    num_missing_args = arglen - len(self._function_spec.arg_names)
3248    missing_arg_names = [self._function_spec.vararg_name] * num_missing_args
3249    # Produce a list of missing args of the form ["arg_0", "arg_1", ...],
3250    # where arg is based on the self._function_spec.vararg_name.
3251    missing_arg_names = [
3252        "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
3253    ]
3254    arg_names = base_arg_names + missing_arg_names
3255    graph_function = ConcreteFunction(
3256        func_graph_module.func_graph_from_py_func(
3257            self._name,
3258            self._python_function,
3259            args,
3260            kwargs,
3261            self.input_signature,
3262            autograph=self._autograph,
3263            autograph_options=self._autograph_options,
3264            arg_names=arg_names,
3265            override_flat_arg_shapes=override_flat_arg_shapes,
3266            capture_by_value=self._capture_by_value),
3267        self._function_attributes,
3268        function_spec=self.function_spec,
3269        # Tell the ConcreteFunction to clean up its graph once it goes out of
3270        # scope. This is not the default behavior since it gets used in some
3271        # places (like Keras) where the FuncGraph lives longer than the
3272        # ConcreteFunction.
3273        shared_func_graph=False)
3274    return graph_function
3275
3276  def _define_function_with_shape_relaxation(self, args, kwargs, flat_args,
3277                                             filtered_flat_args,
3278                                             cache_key_context):
3279    """Define a function, relaxing arg shapes to avoid unnecessary retracing."""
3280    flat_no_comp = nest.flatten((args, kwargs), expand_composites=False)
3281
3282    any_composite_args = any(
3283        isinstance(x, composite_tensor.CompositeTensor) for x in flat_no_comp)
3284
3285    # Build a cache key where TensorShapes include only rank information (and
3286    # not information about the size of each dimension).
3287    if not any_composite_args:
3288      rank_only_cache_key = self._cache_key(
3289          args, kwargs, cache_key_context, include_tensor_ranks_only=True)
3290    else:
3291      # For the rank-only cache key, replace any composite tensors with
3292      # shape-relaxed TypeSpecs.
3293      (cache_key_args, cache_key_kwargs) = nest.map_structure(
3294          _shape_relaxed_type_for_composite_tensor, (args, kwargs))
3295      rank_only_cache_key = self._cache_key(
3296          cache_key_args,
3297          cache_key_kwargs,
3298          cache_key_context,
3299          include_tensor_ranks_only=True)
3300
3301    arg_specs = [_type_spec_for(x) for x in flat_no_comp]
3302    relaxed_arg_specs = self._function_cache.arg_relaxed_specs.get(
3303        rank_only_cache_key, None)
3304    relaxed_arg_function = self._function_cache.arg_relaxed.get(
3305        rank_only_cache_key, None)
3306
3307    if (relaxed_arg_function is not None
3308        and all(_is_type_subset(x, y) for (x, y) in
3309                zip(relaxed_arg_specs, arg_specs))):
3310      return relaxed_arg_function, filtered_flat_args
3311
3312    if relaxed_arg_specs is None:
3313      relaxed_arg_specs = arg_specs
3314    else:
3315      if len(arg_specs) != len(relaxed_arg_specs):
3316        raise RuntimeError("Expected arg_specs len to match "
3317                           "relaxed_arg_specs len: %d vs. %d"
3318                           % (len(arg_specs), len(relaxed_arg_specs)))
3319      relaxed_arg_specs = [
3320          x if x is None else x.most_specific_compatible_type(y)
3321          for (x, y) in zip(arg_specs, relaxed_arg_specs)]
3322    self._function_cache.arg_relaxed_specs[rank_only_cache_key] = (
3323        relaxed_arg_specs)
3324    relaxed_arg_shapes = [
3325        x if x is None else x.shape
3326        for x in nest.flatten(relaxed_arg_specs, expand_composites=True)]
3327
3328    if any_composite_args:
3329      # Rebuild composite tensors with the relaxed TypeSpecs.  For example,
3330      # if a tf.data iterator is passed as an argument, then we need to relax
3331      # the TensorShapes in its element_spec.
3332      (relaxed_arg_specs, relaxed_kwarg_specs) = nest.pack_sequence_as(
3333          (args, kwargs), relaxed_arg_specs, expand_composites=False)
3334      (args, kwargs) = nest.pack_sequence_as(
3335          (relaxed_arg_specs, relaxed_kwarg_specs),
3336          flat_args,
3337          expand_composites=True)
3338
3339    graph_function = self._create_graph_function(
3340        args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes)
3341    self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function
3342
3343    return (graph_function, [
3344        t for t in nest.flatten((args, kwargs), expand_composites=True)
3345        if isinstance(t, (ops.Tensor,
3346                          resource_variable_ops.BaseResourceVariable))
3347    ])
3348
3349  def _maybe_define_function(self, args, kwargs):
3350    """Gets a function for these inputs, defining it if necessary.
3351
3352    `args` and `kwargs` can be None if this `Function` was created with an
3353    `input_signature`.
3354
3355    Caller must hold self._lock.
3356
3357    Args:
3358      args: The varargs for the Python function.
3359      kwargs: The keyword args for the Python function.
3360
3361    Returns:
3362      A graph function corresponding to the input signature implied by args and
3363      kwargs, as well as filtered flattened inputs (only Tensors and Variables)
3364      that the object should be called with.
3365
3366    Raises:
3367      ValueError: If inputs are incompatible with the input signature.
3368      TypeError: If the function inputs include non-hashable objects
3369      RuntimeError: If there's an internal bug (inconsistency) in handling
3370        shape relaxation retracing.
3371    """
3372    if self.input_signature is None or args is not None or kwargs is not None:
3373      args, kwargs, flat_args, filtered_flat_args = \
3374          self._function_spec.canonicalize_function_inputs(*args, **kwargs)
3375    else:
3376      flat_args, filtered_flat_args = [None], []
3377
3378    cache_key_context = self._cache_key_context()
3379    cache_key = self._cache_key(args, kwargs, cache_key_context)
3380
3381    try:
3382      hash(cache_key)
3383    except TypeError as e:
3384      raise TypeError(
3385          "Arguments supplied to `defun`-generated functions must be"
3386          " hashable.  Original error: %s" % e)
3387
3388    graph_function = self._function_cache.primary.get(cache_key, None)
3389    if graph_function is not None:
3390      return graph_function, filtered_flat_args
3391
3392    with monitoring.MonitoredTimer(_graph_building_time_counter.get_cell()):
3393      with trace.Trace("tf.function-graph_building"):
3394        logging.vlog(1,
3395                     "Creating new FuncGraph for Python function %r (key: %r)",
3396                     self._python_function, cache_key)
3397        logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]",
3398                     args, kwargs)
3399
3400        # pylint: disable=protected-access
3401        call_context_key = cache_key._replace(input_signature=None)
3402        # pylint: disable=protected-access
3403
3404        ag_status = (
3405            ag_ctx.Status.ENABLED
3406            if self._autograph else ag_ctx.Status.DISABLED)
3407        with ag_ctx.ControlStatusCtx(
3408            status=ag_status, options=self._autograph_options):
3409
3410          # Build a function with shape relaxation retracing if:
3411          # 1. shape relaxation is explicitly enabled
3412          # and 2. there's no provided input signature
3413          # and 3. there's been a cache miss for this calling context
3414          if (self._experimental_relax_shapes and
3415              self.input_signature is None and
3416              call_context_key in self._function_cache.missed):
3417            return self._define_function_with_shape_relaxation(
3418                args, kwargs, flat_args, filtered_flat_args, cache_key_context)
3419
3420          self._function_cache.missed.add(call_context_key)
3421          graph_function = self._create_graph_function(args, kwargs)
3422          self._function_cache.primary[cache_key] = graph_function
3423
3424          return graph_function, filtered_flat_args
3425
3426
3427def register(func, *args, **kwargs):
3428  """Register a specialization of a `Function` into the graph.
3429
3430  This won't actually call the function with the inputs, and only put the
3431  function definition into graph. Register function with different input param
3432  will result into multiple version of functions registered in graph.
3433
3434  Args:
3435    func: the `Function` instance that generated by a @defun
3436    *args: input arguments for the Python function.
3437    **kwargs: input keyword arguments for the Python function.
3438
3439  Returns:
3440    a `ConcreteFunction` object specialized to inputs and execution context.
3441
3442  Raises:
3443    ValueError: When the input function is not a defun wrapped python function.
3444  """
3445  if not isinstance(func, Function):
3446    raise ValueError("Only defun function is allowed to be registered. "
3447                     "Got type: %s" % type(func))
3448  concrete_func = func.get_concrete_function(*args, **kwargs)
3449  concrete_func.add_to_graph()
3450  concrete_func.add_gradient_functions_to_graph()
3451  return concrete_func
3452
3453
3454def validate_signature(signature):
3455  if any(not isinstance(arg, tensor_spec.DenseSpec)
3456         for arg in nest.flatten(signature, expand_composites=True)):
3457    raise TypeError("Invalid input_signature {}; input_signature must be "
3458                    "a possibly nested sequence of TensorSpec objects."
3459                    .format(signature))
3460
3461
3462def defun(func=None,
3463          input_signature=None,
3464          autograph=True,
3465          experimental_autograph_options=None,
3466          experimental_relax_shapes=False):
3467  """Compiles a Python function into a callable TensorFlow graph.
3468
3469  `defun` (short for "define function") compiles a Python function
3470  composed of TensorFlow operations into a callable that executes a `tf.Graph`
3471  containing those operations. The callable produced by `defun` contains only
3472  the subgraph of TensorFlow operations that were executed when the Python
3473  function was called with a particular input signature, defined as a list
3474  of the shapes and dtypes of the Python function's Tensor-valued arguments and
3475  the values of its non-Tensor Python objects.
3476
3477  When eager execution is enabled, the ability to create graphs from Python
3478  functions makes it possible to incrementally trade off debuggability and
3479  interactivity for performance.  Functions compiled with `defun` cannot be
3480  inspected with `pdb`; however, executing a graph
3481  generated by `defun` sometimes takes less time and memory than eagerly
3482  executing the corresponding Python function, since specifying computations as
3483  graphs allows for optimizations like automatic buffer reuse and
3484  parallelization among ops. Note that executing a `defun`-compiled function
3485  incurs a small constant overhead, so eagerly executing sufficiently small
3486  Python functions might take less time than executing their corresponding
3487  `defun`-generated graphs.
3488
3489  For a Python function to be compatible with `defun`, all of its arguments must
3490  be hashable Python objects or lists thereof. The function itself may not
3491  modify the list/map structure of its arguments. Additionally, it must return
3492  zero or more `tf.Tensor` objects. If the Python function returns
3493  a `tf.Variable`, its compiled version will return the value of that variable
3494  as a `tf.Tensor`.
3495
3496  Executing a graph generated by `defun` respects device annotations (i.e.,
3497  all `with tf.device` directives present in a Python function will also be
3498  present in its corresponding graph), but it is not yet possible to execute the
3499  generated graphs across multiple machines.
3500
3501  _Example Usage_
3502
3503  ```python
3504  import tensorflow as tf
3505
3506  tf.compat.v1.enable_eager_execution()
3507
3508  # A simple example.
3509  def f(x, y):
3510    return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
3511
3512  g = tf.contrib.eager.defun(f)
3513
3514  x = tf.constant([[2.0, 3.0]])
3515  y = tf.constant([[3.0, -2.0]])
3516
3517  # `f` and `g` will return the same value, but `g` will be executed as a
3518  # TensorFlow graph.
3519  assert f(x, y).numpy() == g(x, y).numpy()
3520
3521  # `defun` is capable of compiling Python functions that close over Python
3522  # objects, including Tensors and Variables.
3523  @tf.contrib.eager.defun
3524  def h():
3525    return f(x, y)
3526
3527  assert (h().numpy() == f(x, y).numpy()).all()
3528
3529  # `defun` automatically lifts variables out of the graphs it creates,
3530  # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and
3531  # `tf.keras.Model` objects.
3532  class MyModel(tf.keras.Model):
3533
3534    def __init__(self, keep_probability=0.2):
3535      super(MyModel, self).__init__()
3536      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
3537      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
3538      self.keep_probability = keep_probability
3539
3540    @tf.contrib.eager.defun
3541    def call(self, inputs, training=True):
3542      x = self.dense2(self.dense1(inputs))
3543      if training:
3544        return tf.nn.dropout(x, self.keep_probability)
3545      else:
3546        return x
3547
3548  model = MyModel()
3549  model(x, training=True)  # executes a graph, with dropout
3550  model(x, training=False) # executes a graph, without dropout
3551
3552  # `defun`-compiled functions are differentiable.
3553  optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01)
3554  with tf.GradientTape() as tape:
3555    outputs = model(x)
3556  gradient = tape.gradient(outputs, model.trainable_variables)
3557  optimizer.apply_gradients((grad, var) for grad, var in zip(gradient,
3558                            model.trainable_variables))
3559  ```
3560
3561  When using `defun`, there are subtleties regarding inputs, Python control
3562  flow, and variable creation that one should be aware of. For concreteness, let
3563  `f` be a Python function that returns zero or more `tf.Tensor` objects and
3564  let `F = defun(f)`. `F` builds a graph for each unique input signature it
3565  sees, Python control flow is baked into graphs, and operations related to
3566  variable initialization are automatically lifted out of the graphs that `F`
3567  generates and placed in the eager context if executing eagerly or into an
3568  outer graph otherwise.
3569
3570  _Input Signatures_
3571
3572  By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph
3573  for every unique sequence of the shapes and dtypes of Tensor arguments and
3574  the values of Python objects it is invoked with. For example, calling
3575  `F(tf.random.uniform([2])` will execute a different graph than
3576  `F(tf.random.uniform([3])` because the two inputs have different shapes.
3577  The first time that `F(*args, **kwargs)` is called with a particular sequence
3578  of Tensor shapes and dtypes and Python values, it constructs a graph by
3579  tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
3580  input signature inferred from `(*args, **kwargs)` and cached for future reuse.
3581
3582  NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects
3583  before being passed to `f`, and are treated as Tensors for caching. This
3584  allows a function to be called multiple times with NumPy arrays having
3585  different values but the same shape and dtype without re-tracing each time.
3586
3587  `tf.contrib.eager.defun` caches graphs for your convenience, letting you
3588  define TensorFlow functions without explicitly specifying their signatures.
3589  However, this policy is conservative and potentially expensive; for example,
3590  when different invocations of your function have differently-shaped Tensor
3591  inputs, this policy might generate more graph functions than necessary. To
3592  eliminate such costs, `tf.contrib.eager.defun` allows you to supply an
3593  optional `input_signature` argument specifying the shapes and dtypes of the
3594  inputs. In particular, the shapes may be partially unspecified, with `None`s
3595  in the unknown dimensions.  When an input signature is provided,
3596  `tf.contrib.eager.defun` will only instantiate a single graph for the
3597  decorated Python function. The following is an example:
3598
3599  ```python
3600  import tensorflow as tf
3601
3602  # The first `TensorSpec` below describes the shape and dtype of `words`,
3603  # and the second describes the shape and dtype of `another_tensor`. Note that
3604  # the last dimension of the `words` `TensorSpec` is left unspecified.
3605  @tf.contrib.eager.defun(input_signature=[
3606    tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32),
3607    tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32)
3608  ])
3609  def my_sequence_model(words, another_tensor):
3610    ...
3611
3612  # Note how the third dimension of the first input can vary freely.
3613  words = tf.random.uniform(([50, 300, 10])
3614  second_input = tf.random.uniform([300, 100])
3615  my_sequence_model(words, second_input)
3616
3617  words = tf.random.uniform(([50, 300, 20])
3618  my_sequence_model(words, second_input)
3619
3620  # Passing an input with an incompatible shape will raise an error.
3621  words = tf.random.uniform(([50, 100, 20])
3622  my_sequence_model(words, second_input)  # <---- This will raise an error.
3623
3624  ```
3625
3626  Python functions that are compiled with an `input_signature` must only accept
3627  Tensors as arguments and must not take unnamed keyword arguments (**kwargs).
3628
3629  _Tracing_
3630
3631  Be aware that because `F` only logs TensorFlow operations, all the other
3632  Python code that `f` executes will only shape the _construction_ of the graphs
3633  that `F` executes: the Python code won't be executed when the graphs
3634  themselves are executed, though it will be executed every time the Python
3635  function is traced (and a given Python function might be traced multiple
3636  times, once for each input signature it is invoked with). For example, whereas
3637  the Python function
3638
3639  ```python
3640  import tensorflow as tf
3641  import numpy as np
3642
3643  tf.compat.v1.enable_eager_execution()
3644
3645  def add_noise():
3646    return tf.eye(5) + np.random.randn(5, 5)
3647  ```
3648
3649  will return a different output everytime it is invoked, the compiled function
3650  `compiled = tf.contrib.eager.defun(add_noise)` will return the same value
3651  every time it is called, since a particular random offset generated by NumPy
3652  will be inserted into the graph as a TensorFlow constant. The solution is to
3653  replace the call to `np.random.randn` with `tf.random.normal((5, 5))`.
3654
3655  _Python Side-Effects_
3656
3657  A corollary of the previous discussion on tracing is the following: If a
3658  Python function `f` has Python side-effects, then executing `f` multiple times
3659  will not necessarily be semantically equivalent to executing `F =
3660  tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact
3661  that `defun` only captures the subgraph of TensorFlow operations that is
3662  constructed when `f` is called in a graph-building context.
3663
3664  _Python Control Flow_
3665
3666  The structure of many machine learning computations depend upon whether one is
3667  training or validating, and it is common to nest specialized logic under `if
3668  training:` blocks. By mapping each input signature to a unique graph, `defun`
3669  lets users transparently compile such code, as the following code snippet
3670  demonstrates:
3671
3672  ```python
3673  import tensorflow as tf
3674
3675  tf.compat.v1.enable_eager_execution()
3676
3677  @tf.contrib.eager.defun
3678  def lossy_matmul(W, x, training=True):
3679    outputs = tf.matmul(W, x)
3680    if training:
3681      outputs = tf.nn.dropout(outputs, keep_probability=0.2)
3682    return outputs
3683
3684  W = tf.random.normal((3, 5))
3685  x = tf.random.normal((5, 1))
3686
3687  # Executes a graph that applies dropout.
3688  lossy_outputs = lossy_matmul(W, x, training=True)
3689
3690  # Executes a graph that does not apply dropout.
3691  exact_outputs = lossy_matmul(W, x, training=False)
3692  ```
3693
3694  _TensorFlow Control Flow_
3695
3696  When `autograph` is `True`, data-dependent control flow is allowed as well.
3697  Control flow statements that depend on `Tensor` values are staged into
3698  corresponding TensorFlow ops. For example, the following code will work as
3699  expected:
3700
3701  ```python
3702  @tf.contrib.eager.defun
3703  def dynamic_rnn_loop(cell, seq):
3704    state, output = cell.zero_state()
3705    for input in seq:
3706      state, output = cell(input, state)
3707    return output
3708  ```
3709
3710  For more information see `tf.autograph`.
3711
3712  _Variables_
3713
3714  TensorFlow operations related to variable creation and initialization are
3715  automatically lifted out of the graphs generated by `defun`. In practice, this
3716  implies that variable creation and initialization only happen the first time
3717  `F` is called, and that variables are reused every time thereafter. Many
3718  TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the
3719  first time they are called and reuse them thereafter. Automatic variable
3720  lifting makes it possible to compile these APIs without extra effort, at the
3721  cost of introducing a discrepancy between the semantics of executing Python
3722  functions and their corresponding compiled functions. For example:
3723
3724  ```python
3725  import tensorflow as tf
3726
3727  tf.compat.v1.enable_eager_execution()
3728
3729  def fn():
3730    x = tf.Variable(0.0)
3731    x.assign_add(1.0)
3732    return x.read_value()
3733
3734  # `fn` is a Python function, so x is created, initialized, and destroyed upon
3735  # every invocation
3736  assert fn().numpy() == fn().numpy() == 1.0
3737
3738  compiled = tf.contrib.eager.defun(fn)
3739
3740  # Compiling `fn` with `defun` hoists all variables outside of the generated
3741  # graph, so initialization happens exactly once.
3742  assert compiled().numpy() == 1.0
3743  assert compiled().numpy() == 2.0
3744  ```
3745
3746  Finally, because each input signature is bound to a unique graph, if your
3747  Python function constructs `tf.Variable` objects, then each graph constructed
3748  for that Python function will reference a unique set of variables. To
3749  circumvent this problem, we recommend against compiling Python functions that
3750  create `tf.Variable` objects. Instead, Python functions should either
3751  lexically close over `tf.Variable` objects or accept them as arguments,
3752  preferably encapsulated in an object-oriented container. If you must create
3753  variables inside your Python function and you want each graph generated for it
3754  to reference the same set of variables, add logic to your Python function that
3755  ensures that variables are only created the first time it is called and are
3756  reused for every subsequent invocation; note that this is precisely what
3757  `tf.keras.layers.Layer` objects do, so we recommend using them to represent
3758  variable-bearing computations whenever possible.
3759
3760  Args:
3761    func: function to be compiled. If `func` is None, returns a
3762      decorator that can be invoked with a single argument - `func`. The
3763      end result is equivalent to providing all the arguments up front.
3764      In other words, defun(input_signature=...)(func) is equivalent to
3765      defun(func, input_signature=...). The former allows
3766      the following use case:
3767        @tf.contrib.eager.defun(input_signature=...)
3768        def foo(...):
3769          ...
3770
3771    input_signature: A possibly nested sequence of
3772      `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of
3773      the Tensors that will be supplied to this function. If `None`, a separate
3774      function is instantiated for each inferred input signature.  If a
3775      signature is specified, every input to `func` must be a `Tensor`, and
3776      `func` cannot accept `**kwargs`.
3777    autograph: Whether `func` should be compiled before
3778      constructing the graph. See https://www.tensorflow.org/guide/autograph
3779      for more information.
3780    experimental_autograph_options: Experimental knobs (in the form of a tuple
3781      of tensorflow.autograph.Feature values) to control behavior when
3782      autograph=True.
3783    experimental_relax_shapes: When true, argument shapes may be relaxed to
3784      avoid unnecessary retracing.
3785
3786  Returns:
3787     If `func` is not None, returns a callable that will execute the compiled
3788     function (and return zero or more `tf.Tensor` objects).
3789     If `func` is None, returns a decorator that, when invoked with a single
3790     `func` argument, returns a callable equivalent to the case above.
3791
3792  Raises:
3793    TypeError: If `input_signature` is neither `None` nor a sequence of
3794      `tf.contrib.eager.TensorSpec` objects.
3795  """
3796  return defun_with_attributes(
3797      func=func,
3798      input_signature=input_signature,
3799      autograph=autograph,
3800      experimental_autograph_options=experimental_autograph_options,
3801      experimental_relax_shapes=experimental_relax_shapes)
3802
3803
3804@tf_export("__internal__.function.defun_with_attributes", v1=[])
3805def defun_with_attributes(func=None,
3806                          input_signature=None,
3807                          attributes=None,
3808                          autograph=True,
3809                          experimental_autograph_options=None,
3810                          jit_compile=None,
3811                          experimental_relax_shapes=False,
3812                          experimental_follow_type_hints=False):
3813  """Compiles a Python function into a callable TensorFlow graph.
3814
3815  This function supports adding extra function attributes. See detailed
3816  documentation in defun(). Currently this is not exposed in public API since we
3817  don't expect user to directly use attributes, and attribute won't work by
3818  itself. This assumption might change in future.
3819
3820  Args:
3821    func: function to be compiled.
3822    input_signature: same as defun()'s input_signature.
3823    attributes: A dictionary of arguments which will be added to function def as
3824      attributes. Currently only support primitive types as value, and only
3825      allowlisted attribute name is allowed. Unallowlisted attribute name or
3826      unsupported value will result into ValueError. `func_name` is also one of
3827      the allowlisted argument which is a python string, and sets the name for
3828      this `ConcreteFunction` in the graph.
3829    autograph: same as defun()'s autograph.
3830    experimental_autograph_options: same as defun()'s
3831      experimental_autograph_options.
3832    jit_compile: same as defun()'s jit_compile.
3833    experimental_relax_shapes: same as defun()'s experimental_relax_shapes
3834    experimental_follow_type_hints: see `tf.function`.
3835
3836  Returns:
3837    Same as the return value of defun, with attributes added to the function in
3838    graph.
3839  """
3840  if input_signature is not None:
3841    validate_signature(input_signature)
3842
3843  # TODO(apassos): deal with captured global state. Deal with control flow.
3844  def decorated(function):
3845    try:
3846      if attributes:
3847        name = attributes.pop("func_name", function.__name__)
3848      else:
3849        name = function.__name__
3850    except AttributeError:
3851      name = "function"
3852    return tf_decorator.make_decorator(
3853        function,
3854        Function(
3855            function,
3856            name,
3857            input_signature=input_signature,
3858            attributes=attributes,
3859            autograph=autograph,
3860            autograph_options=experimental_autograph_options,
3861            jit_compile=jit_compile,
3862            experimental_relax_shapes=experimental_relax_shapes,
3863            experimental_follow_type_hints=experimental_follow_type_hints))
3864
3865  # This code path is for the `foo = tfe.defun(foo, ...)` use case
3866  if func is not None:
3867    return decorated(func)
3868
3869  # This code path is for the
3870  #
3871  # @tfe.defun(...)
3872  # def foo(...):
3873  #    ...
3874  #
3875  # use case, which is equivalent to `foo = tfe.defun(...)(foo)`
3876  return decorated
3877
3878
3879# When a method is bound to objects of this type, it allows AutoGraph to
3880# recover a weak reference the original method's self pointer, so that it can
3881# execute it consistent with class_method_to_instance_method's
3882# bound_method_wrapper.
3883# TODO(b/119246461): This is not pretty. Use a descriptor instead?
3884class TfMethodTarget(object):
3885  """Binding target for methods replaced by function and defun."""
3886
3887  __slots__ = ("weakrefself_target__", "weakrefself_func__")
3888
3889  def __init__(self, target, original_python_function):
3890    self.weakrefself_target__ = target
3891    self.weakrefself_func__ = weakref.ref(original_python_function)
3892
3893  @property
3894  def target(self):
3895    return self.weakrefself_target__()
3896
3897  @property
3898  def target_class(self):
3899    true_self = self.weakrefself_target__()
3900    if tf_inspect.isclass(true_self):
3901      # Class method
3902      return true_self
3903    else:
3904      return true_self.__class__
3905
3906  def call(self, args, kwargs):
3907    wrapped_fn = self.weakrefself_func__()
3908    if tf_inspect.ismethod(wrapped_fn):
3909      wrapped_fn = six.get_unbound_function(wrapped_fn)
3910    return wrapped_fn(self.weakrefself_target__(), *args, **kwargs)
3911
3912
3913def class_method_to_instance_method(original_function, instance):
3914  """Constructs a new `Function` with `self` bound."""
3915  weak_instance = weakref.ref(instance)
3916
3917  # Note: while we could bind to a weakref proxy instead, that causes the
3918  # bound method to be unhashable.
3919  bound_method = types_lib.MethodType(
3920      original_function.python_function,
3921      TfMethodTarget(weak_instance, original_function.python_function))
3922
3923  # original_function is expected to be of one of the two `Function` types
3924  # (defined either in function.py or def_function.py).
3925  assert hasattr(original_function, "_name")
3926  assert hasattr(original_function, "_autograph")
3927  assert hasattr(original_function, "_function_spec")
3928  assert hasattr(original_function, "python_function")
3929
3930  weak_bound_method_wrapper = None
3931  def bound_method_wrapper(*args, **kwargs):
3932    """Wraps either a dummy MethodType or a converted AutoGraph function."""
3933    # __wrapped__ allows AutoGraph to swap in a converted function.
3934    strong_bound_method_wrapper = weak_bound_method_wrapper()
3935    wrapped_fn = strong_bound_method_wrapper.__wrapped__
3936
3937    if wrapped_fn is strong_bound_method_wrapper.__original_wrapped__:
3938      # If __wrapped__ was not replaced, then call original_function.
3939      # TODO(mdan): For better consistency, use the wrapper's call().
3940      wrapped_fn = original_function.python_function
3941      if tf_inspect.ismethod(wrapped_fn):
3942        wrapped_fn = six.get_unbound_function(wrapped_fn)
3943      return wrapped_fn(weak_instance(), *args, **kwargs)
3944
3945    # If __wrapped__ was replaced, then it is always an unbound function.
3946    # However, the replacer is still responsible for attaching self properly.
3947    # TODO(mdan): Is it possible to do it here instead?
3948    return wrapped_fn(*args, **kwargs)
3949  weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)
3950
3951  # pylint: disable=protected-access
3952  # We make a dummy MethodType object to generate the correct bound method
3953  # signature. The actual call is to a function with a weak reference to
3954  # `instance`.
3955  instance_func = type(original_function)(
3956      tf_decorator.make_decorator(bound_method, bound_method_wrapper),
3957      name=original_function._name,
3958      autograph=original_function._autograph,
3959      input_signature=original_function.input_signature,
3960      experimental_relax_shapes=original_function._experimental_relax_shapes,
3961      jit_compile=original_function._jit_compile)
3962  # pylint: enable=protected-access
3963
3964  # We wrap the the bound method with tf_decorator so inspection works correctly
3965  wrapped_instance_func = tf_decorator.make_decorator(bound_method,
3966                                                      instance_func)
3967  return wrapped_instance_func
3968
3969
3970class _FunctionGarbageCollector(object):
3971  """Cleans up cycles when a defun goes out of scope."""
3972
3973  __slots__ = ["_cache"]
3974
3975  def __init__(self, cache):
3976    self._cache = cache
3977
3978  def __del__(self):
3979    if func_graph_module is None or memory is None:
3980      return
3981    try:
3982      while self._cache:
3983        self._cache.popitem()
3984      memory.dismantle_ordered_dict(self._cache)
3985    except:  # pylint: disable=bare-except
3986      pass
3987
3988
3989class ConcreteFunctionGarbageCollector(object):
3990  """Cleans up reference cycles when a `ConcreteFunction` goes out of scope."""
3991
3992  __slots__ = ["_func_graph"]
3993
3994  def __init__(self, func_graph):
3995    self._func_graph = func_graph
3996
3997  def release(self):
3998    """Call off the FuncGraph deletion."""
3999    self._func_graph = None
4000
4001  def __del__(self):
4002    if func_graph_module is None or memory is None or self._func_graph is None:
4003      return
4004    try:
4005      func_graph_module.dismantle_func_graph(self._func_graph)
4006    except:  # pylint: disable=bare-except
4007      pass
4008
4009
4010class _Marker(object):
4011  """Markers used to pretty-print nested args in function signatures."""
4012
4013  __slots__ = ["_s"]
4014
4015  def __init__(self, s):
4016    self._s = s
4017
4018  def __repr__(self):
4019    return str(self._s)
4020
4021
4022def _structure_summary(structure):
4023  """Displays a summary of the nesting structure of the given value."""
4024
4025  def type_name(x):
4026    if isinstance(x, type_spec.TypeSpec):
4027      return x.value_type.__name__
4028    else:
4029      return type(x).__name__
4030
4031  markers = [_Marker(type_name(v)) for v in nest.flatten(structure)]
4032  return str(nest.pack_sequence_as(structure, markers))
4033
4034
4035def _contains_type_spec(value):
4036  return any(isinstance(x, type_spec.TypeSpec) for x in nest.flatten(value))
4037