1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unidiomatic-typecheck
16"""Prototype decorator for defining graph functions with eager semantics."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23import threading
24import weakref
25import six
26
27from google.protobuf import text_format as _text_format
28from google.protobuf.message import DecodeError
29from tensorflow.core.framework import attr_value_pb2
30from tensorflow.python.distribute.parallel_device import parallel_device
31from tensorflow.python.eager import context
32from tensorflow.python.eager import function as function_lib
33from tensorflow.python.eager import lift_to_graph
34from tensorflow.python.eager import monitoring
35from tensorflow.python.framework import errors
36from tensorflow.python.framework import func_graph as func_graph_module
37from tensorflow.python.framework import ops
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import control_flow_ops
40from tensorflow.python.ops import control_flow_util
41from tensorflow.python.ops import math_ops
42from tensorflow.python.ops import random_ops
43from tensorflow.python.ops import resource_variable_ops
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.profiler import trace
46from tensorflow.python.training.tracking import base as trackable
47from tensorflow.python.util import deprecation
48from tensorflow.python.util import nest
49from tensorflow.python.util import object_identity
50from tensorflow.python.util import tf_decorator
51from tensorflow.python.util.tf_export import tf_export
52
# Tuning knobs for the retracing warning emitted by `_FrequentTracingDetector`.
# Approximate number of most recent calls kept in the retracing history; the
# oldest tracings are evicted once the history exceeds this many calls.
FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10
# Warn when at least this many tracings fall inside the retained history.
FREQUENT_TRACING_WARNING_THRESHOLD = 5
# Maximum number of retracing warnings a single detector will ever log.
FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2


# Incremented in `Function._initialize`; the "jit_compile" label records
# whether that function is compiled with XLA.
_tf_function_counter = monitoring.Counter(
    "/tensorflow/core/tf_function_counter",
    "Counter for the number of tf.functions created when Eager execution is "
    "enabled.",
    # jit_compile is "0" or "1".
    "jit_compile")
64
65
class _FrequentTracingDetector(object):
  """Class keeping track of how many recent calls triggered tracing.

  `_calls_per_tracings` holds one counter per recorded tracing. Each counter is
  the number of calls attributed to that tracing: the tracing call itself plus
  the non-tracing calls that followed it. `_call_count` is always the sum of
  those counters, i.e. the number of calls currently in the history window.
  """

  __slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"]

  def __init__(self):
    self._calls_per_tracings = []
    self._total_warning_count = 0
    self._call_count = 0

  def called_with_tracing(self, function_name, omit_warning):
    """Updates the list of most recent calls' tracing information.

    Warns the user when recent calls caused retracing too often.

    Args:
      function_name: the python function being traced.
      omit_warning: If 'True', this call will not warn the user even if
        retracing happens too often.
    """
    self._call_count += 1
    self._calls_per_tracings.append(1)

    # Evict the oldest tracings for as long as the window would still contain
    # more than FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY calls without them.
    while (self._calls_per_tracings and
           self._call_count - self._calls_per_tracings[0] >
           FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY):
      self._call_count -= self._calls_per_tracings.pop(0)

    if (omit_warning or self._total_warning_count >=
        FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR):
      return
    if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD:
      self._total_warning_count += 1
      logging.warning(
          "{} out of the last {} calls to {} triggered tf.function "
          "retracing. Tracing is expensive and the excessive number of "
          "tracings could be due to (1) creating @tf.function repeatedly in "
          "a loop, (2) passing tensors with different shapes, (3) passing "
          "Python objects instead of tensors. For (1), please define your "
          "@tf.function outside of the loop. For (2), @tf.function has "
          "experimental_relax_shapes=True option that relaxes argument "
          "shapes that can avoid unnecessary retracing. For (3), please "
          "refer to "
          "https://www.tensorflow.org/guide/function#controlling_retracing"
          " and https://www.tensorflow.org/api_docs/python/tf/function for "
          "more details.".format(
              len(self._calls_per_tracings), self._call_count, function_name))

  def called_without_tracing(self):
    """Records a call that did not trigger tracing.

    The call is attributed to the most recent tracing's counter.
    """
    # We don't count tracing when users load a concrete function directly or
    # call get_concrete_function, so the first call can be not a tracing call.
    if not self._calls_per_tracings:
      self._calls_per_tracings = [0]
    self._calls_per_tracings[-1] += 1
    self._call_count += 1
123
124
125class _FrequentTracingDetectorManager(object):
126  """Class for the management of all _FrequentTracingDetector objects."""
127
128  __slots__ = ["_detectors", "_lock"]
129
130  def __init__(self):
131    self._detectors = weakref.WeakKeyDictionary()  # GUARDED_BY(self._lock)
132    self._lock = threading.Lock()
133
134  def _get_detector(self, key):
135    if key not in self._detectors:
136      self._detectors[key] = _FrequentTracingDetector()
137    return self._detectors[key]
138
139  def called_without_tracing(self, key):
140    with self._lock:
141      detector = self._get_detector(key)
142      detector.called_without_tracing()
143
144  def called_with_tracing(self, key, function_name, omit_warning):
145    with self._lock:
146      detector = self._get_detector(key)
147      detector.called_with_tracing(function_name, omit_warning)
148
149
# Module-level manager holding one retracing detector per key. Presumably keyed
# by `Function._key_for_call_stats`; the call sites are outside this view.
_frequent_tracing_detector_manager = _FrequentTracingDetectorManager()
151
152
class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
  """Variable which does not lift its initializer out of function context.

  When created inside a function-building graph whose init-scope context is
  eager, the constructor adds the initializer to the function graph inside a
  `tf.cond(var_is_initialized)` block. The assignment therefore only runs if
  the variable was not already initialized, e.g. when lifting the initializer
  through `add_initializers_to` failed.

  This is meant to be created inside a defun called from (eventually) eager
  mode. That is, non-function-building graphs are not supported. If it is
  constructed outside of any function, it falls back to plain
  `ResourceVariable` construction.
  """

  def __init__(self,
               initial_value=None,
               trainable=None,
               caching_device=None,
               name=None,
               dtype=None,
               constraint=None,
               add_initializers_to=None,
               lifted_initializer_graph=None,
               synchronization=None,
               aggregation=None,
               shape=None,
               **unused_kwargs):
    """Creates a variable.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
        which is the initial value for the Variable. Can also be a callable
        with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, GradientTapes automatically watch uses of this
        Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      add_initializers_to: if not None and not in legacy graph mode, the
        initializer tensor will be added to this map in addition to adding the
        assignment to the function. It is used as a list: `(variable,
        initial_value)` pairs are appended to it.
      lifted_initializer_graph: FuncGraph to try to lift initializers to.
        Accepted but not read by this constructor.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses
        when to synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.
      shape: (optional) The shape of this variable. If None, the shape of
        `initial_value` will be used. When setting this argument to
        `tf.TensorShape(None)` (representing an unspecified shape), the variable
        can be assigned with values of different shapes.
      **unused_kwargs: Despite the name, forwarded to
        `UninitializedVariable.__init__` when created inside a function;
        ignored by the outside-of-function fallback.

    Raises:
      ValueError: If, inside a function, `initial_value` is None or
        `constraint` is neither None nor callable.
    """
    # Record whether the outermost (init-scope) context is a graph rather than
    # eager; this selects the initialization strategy at the end.
    with ops.init_scope():
      self._in_graph_mode = not context.executing_eagerly()
    if not ops.inside_function():
      # If we've been init_scope()d out of the function definition nothing to do
      # here; we can't really do the capturing or conditional logic.
      # Note: only the arguments below are forwarded; `shape`,
      # `synchronization`, `aggregation` and `unused_kwargs` are dropped here.
      resource_variable_ops.ResourceVariable.__init__(
          self, initial_value=initial_value, trainable=trainable,
          caching_device=caching_device, name=name, dtype=dtype,
          constraint=constraint)
      return
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    with ops.name_scope(name, "Variable", []
                        if init_from_fn else [initial_value]) as scope_name:
      with ops.name_scope("Initializer"):
        if init_from_fn:
          initial_value = initial_value()
        # Checkpoint-restore wrappers: remember the restore uid and unwrap the
        # actual initial value.
        if isinstance(initial_value, trackable.CheckpointInitialValue):
          self._maybe_initialize_trackable()
          self._update_uid = initial_value.checkpoint_position.restore_uid
          initial_value = initial_value.wrapped_value

        initial_value = ops.convert_to_tensor(initial_value,
                                              name="initial_value", dtype=dtype)
      assert initial_value is not None

      # Don't use `shape or initial_value.shape` since TensorShape has
      # overridden `__bool__`.
      if shape is None:
        shape = initial_value.shape

    # Use the constructor for UninitializedVariable to start. Outside the name
    # scope so we don't double up the prefix.
    super(UnliftedInitializerVariable, self).__init__(
        trainable=trainable,
        caching_device=caching_device,
        name=name,
        shape=shape,
        dtype=initial_value.dtype,
        constraint=constraint,
        synchronization=synchronization,
        aggregation=aggregation,
        extra_handle_data=initial_value,
        **unused_kwargs)

    with ops.name_scope(scope_name):
      if self._in_graph_mode:
        # Init scope is a (non-function) graph: lift the initializer into that
        # outer graph and build the is-initialized / assign ops there. The
        # lift may not depend on the function's inputs or captures.
        with ops.init_scope():
          outer_graph = ops.get_default_graph()
        func_graph = ops.get_default_graph()
        function_placeholders = (
            func_graph.inputs + func_graph.internal_captures)
        placeholder_ops = set(
            [tensor.op for tensor in function_placeholders])
        lifted_initializer = lift_to_graph.lift_to_graph(
            [initial_value], outer_graph,
            disallowed_placeholders=placeholder_ops)[initial_value]
        with ops.init_scope():
          self._initial_value = lifted_initializer
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                resource_variable_ops.var_is_initialized_op(self._handle))
          # `initial_value` is a Tensor at this point (see the assert above),
          # so this condition is always true.
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = resource_variable_ops.assign_variable_op(
                  self._handle, lifted_initializer, name=n)
      elif context.executing_eagerly():
        # In this case, both current scope and init scope are eager.
        # Assign_variable_op will be executed immediately. So we don't need to
        # add it to "add_initializers_to" to lift it out.
        with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
          resource_variable_ops.assign_variable_op(
              self._handle, initial_value, name=n)
      else:
        # Init scope is eager but current scope is graph. We will lift out this
        # variable by adding it into "add_initializers_to".
        if add_initializers_to is not None:
          add_initializers_to.append((self, initial_value))

        def assign_fn():
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            resource_variable_ops.assign_variable_op(
                self._handle,
                initial_value,
                name=n)
            # Returning values to keep tf.cond happy.
          return ops.convert_to_tensor(1)
        def not_assign_fn():
          return ops.convert_to_tensor(0)
        # Note: this cond is always guaranteed to run because we're inside a
        # defun which will insert automatic control dependencies. It will only
        # execute assign_fn if lifting failed.
        graph = ops.get_default_graph()

        # Capture the handle ahead of time in order to avoid querying the shape
        # of the handle which helps async execution performance
        graph.capture(self._handle, shape=())
        control_flow_ops.cond(
            resource_variable_ops.var_is_initialized_op(self._handle),
            not_assign_fn, assign_fn)
332
333
# Global switch written by `run_functions_eagerly()` and read by
# `functions_run_eagerly()`. When False (the default), `tf.function`s run as
# traced graph functions; when True, they run eagerly.
RUN_FUNCTIONS_EAGERLY = False
335
336
@deprecation.deprecated(
    None,
    "Use `tf.config.run_functions_eagerly` instead of the experimental "
    "version.")
@tf_export("config.experimental_run_functions_eagerly")
def experimental_run_functions_eagerly(run_eagerly):
  """Enables / disables eager execution of `tf.function`s.

  Deprecated alias: calling
  `tf.config.experimental_run_functions_eagerly(True)` has exactly the same
  effect as `tf.config.run_functions_eagerly(True)`. Every `tf.function`
  invocation then runs eagerly instead of as a traced graph function.

  See `tf.config.run_functions_eagerly` for an example.

  Note: This flag has no effect on functions passed into tf.data transformations
  as arguments. tf.data functions are never executed eagerly and are always
  executed as a compiled Tensorflow Graph.

  Args:
    run_eagerly: Boolean. Whether to run functions eagerly.
  """
  # Pure delegation to the stable API (which itself returns None).
  return run_functions_eagerly(run_eagerly=run_eagerly)
359
360
@tf_export("config.run_functions_eagerly")
def run_functions_eagerly(run_eagerly):
  """Enables / disables eager execution of `tf.function`s.

  After `tf.config.run_functions_eagerly(True)`, every `tf.function` call runs
  the wrapped Python code eagerly rather than executing a traced graph
  function. Pass `False` to restore graph execution.

  This is mostly handy while debugging:

  >>> def my_func(a):
  ...  print("Python side effect")
  ...  return a + a
  >>> a_fn = tf.function(my_func)

  >>> # A side effect the first time the function is traced
  >>> a_fn(tf.constant(1))
  Python side effect
  <tf.Tensor: shape=(), dtype=int32, numpy=2>

  >>> # No further side effect, as the traced function is called
  >>> a_fn(tf.constant(2))
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  >>> # Now, switch to eager running
  >>> tf.config.run_functions_eagerly(True)
  >>> # Side effect, as the function is called directly
  >>> a_fn(tf.constant(2))
  Python side effect
  <tf.Tensor: shape=(), dtype=int32, numpy=4>

  >>> # Turn this back off
  >>> tf.config.run_functions_eagerly(False)

  Note: This flag has no effect on functions passed into tf.data transformations
  as arguments. tf.data functions are never executed eagerly and are always
  executed as a compiled Tensorflow Graph.

  Args:
    run_eagerly: Boolean. Whether to run functions eagerly. Any truthy / falsy
      value is accepted and stored as a real `bool`.
  """
  global RUN_FUNCTIONS_EAGERLY
  RUN_FUNCTIONS_EAGERLY = True if run_eagerly else False
404
405
@deprecation.deprecated(
    None,
    "Use tf.config.functions_run_eagerly instead of the experimental version.")
@tf_export("config.experimental_functions_run_eagerly")
def experimental_functions_run_eagerly():
  """Returns the value of the `experimental_run_functions_eagerly` setting.

  Deprecated alias of `tf.config.functions_run_eagerly`.
  """
  current_setting = functions_run_eagerly()
  return current_setting
413
414
@tf_export("config.functions_run_eagerly")
def functions_run_eagerly():
  """Returns the value of the `run_functions_eagerly` setting.

  Returns:
    The module-level flag last set through `tf.config.run_functions_eagerly`
    (False unless it has been changed).
  """
  eager_flag = RUN_FUNCTIONS_EAGERLY
  return eager_flag
419
420
def _evaluate_var_is_initialized(variables):
  """Compute booleans indicating whether each variable is initialized.

  All checks run under `ops.init_scope()`. The fast path stacks every
  `var_is_initialized` result into one tensor and reads it with a single
  `.numpy()` call. If that raises `UnimplementedError`, each result is read
  individually instead.

  Args:
    variables: Variables exposing a `handle` attribute.

  Returns:
    On the fast path, the numpy value of the stacked results; on the fallback
    path, a list holding one numpy value per variable.

  Raises:
    NotImplementedError: If a parallel-device variable is initialized on some
      but not all of its components.
  """
  with ops.init_scope():
    is_initialized = [
        resource_variable_ops.var_is_initialized_op(var.handle)
        for var in variables
    ]
    try:
      # A single stacked fetch reduces the number of RPCs between client and
      # worker in the remote case.
      return array_ops.stack(is_initialized).numpy()
    except errors.UnimplementedError:
      # Some devices do not support implicit copy-off to host; fall back to
      # reading the flags one variable at a time.
      for position, var in enumerate(variables):
        flag = is_initialized[position]
        try:
          is_initialized[position] = flag.numpy()
        except errors.UnimplementedError:
          # Parallel-device variable: reduce the per-replica flags and require
          # that all replicas agree.
          per_replica = parallel_device.unpack(flag)
          with ops.device(None):
            stacked = array_ops.stack(per_replica)
            all_initialized = math_ops.reduce_all(stacked).numpy()
            any_initialized = math_ops.reduce_any(stacked).numpy()
          if all_initialized != any_initialized:
            raise NotImplementedError(
                ("Some but not all components of a parallel variable {} were "
                 "initialized between their creation in a tf.function and "
                 "the function's trace having completed. This is not yet "
                 "supported; consider initializing either all or none of the "
                 "components, or moving initialization out of the function."
                ).format(repr(var)))
          is_initialized[position] = all_initialized
  return is_initialized
458
459
class FunctionDeleter(object):
  """Dismantles a `FuncGraph` when this object is garbage collected.

  Whoever holds the deleter ties the graph's teardown to its own lifetime:
  once the deleter is collected, `dismantle_func_graph` is called on the
  stored graph.
  """

  __slots__ = ["func_graph"]

  def __init__(self, func_graph):
    self.func_graph = func_graph

  def __del__(self):
    """Best-effort teardown; any error is silently ignored."""
    try:
      func_graph_module.dismantle_func_graph(self.func_graph)
    except:  # pylint: disable=bare-except
      # Deliberately bare: teardown failures are common (and noisy) during
      # interpreter shutdown, and there is no caller to report them to.
      pass
473
474
475class Function(object):
476  """Wrapper class for the graph functions defined for a Python function.
477
478  See the documentation for `tf.function` for more information on the semantics
479  of defined functions.
480
481  `Function` is thread-compatible.
482  """
483
484  def __init__(self,
485               python_function,
486               name,
487               input_signature=None,
488               autograph=True,
489               jit_compile=None,
490               experimental_implements=None,
491               experimental_autograph_options=None,
492               experimental_relax_shapes=False,
493               experimental_follow_type_hints=None):
494    """Initializes a `Function`.
495
496    Args:
497      python_function: the function to be wrapped.
498      name: the name given to it.
499      input_signature: See the documentation for `tf.function`.
500      autograph: See the documentation for `tf.function`.
501      jit_compile: See the documentation for `tf.function`.
502      experimental_implements: See the documentation for `tf.function`.
503      experimental_autograph_options: See the documentation for `tf.function`.
504      experimental_relax_shapes: See the documentation for `tf.function`.
505      experimental_follow_type_hints: See the documentation for `tf.function`.
506
507    Raises:
508      ValueError: if `input_signature` is not None and the `python_function`'s
509        argspec has keyword arguments.
510    """
511    self._lock = threading.Lock()
512    self._python_function = python_function
513    self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
514        python_function,
515        input_signature,
516        jit_compile=jit_compile,
517        experimental_follow_type_hints=experimental_follow_type_hints,
518    )
519    self._implements = experimental_implements
520    # If `True`, the function uses the rendezvous of the parent. This is only
521    # needed to support code where raw send/recv operations are inserted and
522    # when functions are run in graph mode where they may not be inlined.
523    self._shared_rendezvous = None
524    self._autograph = autograph
525    self._experimental_autograph_options = experimental_autograph_options
526    self._experimental_relax_shapes = experimental_relax_shapes
527    self._jit_compile = jit_compile
528    if experimental_follow_type_hints is None:
529      experimental_follow_type_hints = False
530    self._experimental_follow_type_hints = experimental_follow_type_hints
531    self._created_variables = None  # GUARDED_BY(self._lock)
532    self._stateful_fn = None  # GUARDED_BY(self._lock)
533    self._stateless_fn = None  # GUARDED_BY(self._lock)
534    self._descriptor_cache = weakref.WeakKeyDictionary()
535    self._name = name
536    self._input_signature = input_signature
537    self._key_for_call_stats = self._get_key_for_call_stats()
538    self._omit_frequent_tracing_warning = False
539    ops._tf_function_api_guage.get_cell().set(True)  # pylint: disable=protected-access
540
541  def __getstate__(self):
542    """Custom pickling, to omit unpickleable objects."""
543    result = self.__dict__.copy()
544    del result["_lock"]
545    del result["_descriptor_cache"]
546    del result["_key_for_call_stats"]
547    return result
548
549  def __setstate__(self, state):
550    """Restore from pickled state."""
551    self.__dict__ = state
552    self._lock = threading.Lock()
553    self._descriptor_cache = weakref.WeakKeyDictionary()
554    self._key_for_call_stats = self._get_key_for_call_stats()
555
556  def _get_key_for_call_stats(self):
557    """Returns key instance to track call stats and retracings.
558
559    The key instance a best-effort to preserve global consistency.
560    """
561    target_function = self._python_function
562    # `__wrapped__` is a conventional Python attribute that a higher-order
563    # function keeps its original function's instance.  We also directly use
564    # this attribute for dealing with a class method.  See
565    # `bound_method_wrapper` in `function.py`.  If we don't use `__wrapped__`,
566    # all class methods will return the same `bound_method_wrapper` instance
567    # from this function.
568    while hasattr(target_function, "__wrapped__"):
569      target_function = target_function.__wrapped__
570
571    if hasattr(target_function, "__func__"):
572      target_function = target_function.__func__
573
574    if hasattr(target_function, "__code__"):
575      return target_function.__code__
576
577    return self._python_function
578
  def _defun_with_scope(self, scope):
    """Creates a defun wrapped inside a variable creator scope.

    Args:
      scope: A variable creator function. It is registered with priority 50
        around every invocation of the wrapped Python function.

    Returns:
      The result of `self._defun` applied to a `tf_decorator`-wrapped version
      of `self._python_function`.
    """

    # Replaced below by a weak reference to `wrapped_fn`; `wrapped_fn` reads it
    # through its closure so that it does not hold a strong reference to itself.
    weak_wrapped_fn = None
    # Read once here; the closure uses this captured value rather than
    # re-reading `self._jit_compile`.
    compile_with_xla = self._jit_compile

    def wrapped_fn(*args, **kwds):
      """Wraps `self._python_function` in a variable creator scope."""
      # We register a variable creator with reduced priority. If an outer
      # variable creator is just modifying keyword arguments to the variable
      # constructor, this will work harmoniously. Since the `scope` registered
      # here actually creates the variable, it taking priority would otherwise
      # ignore the outer creator.
      #
      # If an outer variable creator calls the variable constructor manually,
      # for example creating a MirroredVariable, then they won't call our
      # creator. This means we won't be able to trace the initialization graph,
      # and so variable initializers can't depend on function arguments. This is
      # better than the alternative, tracing the initialization graph but giving
      # the user a variable type they didn't want.
      default_graph = ops.get_default_graph()
      with default_graph._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
        # __wrapped__ allows AutoGraph to swap in a converted function. We give
        # the function a weak reference to itself to avoid a reference cycle.
        # Only enter an XLA control-flow context if no enclosing graph is
        # already in one.
        if compile_with_xla and \
            not control_flow_util.GraphOrParentsInXlaContext(default_graph):
          xla_context = control_flow_ops.XLAControlFlowContext()
          # NOTE(review): Enter() sits inside the `try`, so Exit() also runs if
          # Enter() itself raises.
          try:
            xla_context.Enter()
            out = weak_wrapped_fn().__wrapped__(*args, **kwds)
          finally:
            xla_context.Exit()
        else:
          out = weak_wrapped_fn().__wrapped__(*args, **kwds)
        return out

    weak_wrapped_fn = weakref.ref(wrapped_fn)

    return self._defun(tf_decorator.make_decorator(
        self._python_function,
        wrapped_fn))
620
621  def _create_implements_attribute(self):
622    """Creates the attribute value corresponding to IMPLEMENTS_ATTRIBUTE_NAME."""
623    attributes = {}
624    if isinstance(self._implements, str):
625      # First check if the IMPLEMENTS_ATTRIBUTE_NAME is specified as a
626      # NameAttrList. This is used when apart from the function name being
627      # implemented, a list of attributes is also being specified.
628      # The attributes are specified as key-value pairs in the NameAttrList
629      # of the corresponding AttrValue. The function name will be in the
630      # 'name' field of the NameAttrList. Else, it is just a string
631      # corresponding to the function name.
632      try:
633        implements_attr = six.ensure_text(self._implements, "utf-8")
634        attr_value = attr_value_pb2.AttrValue()
635        nameattrlist = attr_value_pb2.NameAttrList()
636        _text_format.Merge(implements_attr, nameattrlist)
637        attr_value.func.CopyFrom(nameattrlist)
638        attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = attr_value
639      except (_text_format.ParseError, DecodeError):
640        attributes[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = self._implements
641    return attributes
642
643  def _defun(self, fn):
644    """Returns a defun generated from the input function."""
645    attributes = {}
646
647    if self._implements is not None:
648      attributes = self._create_implements_attribute()
649
650    share = self._shared_rendezvous
651    if share is not None:
652      attributes[function_lib.SHARED_RENDEZVOUS_ATTRIBUTE_NAME] = share
653
654    if self._jit_compile is not None:
655      attributes.update(_XlaMustCompile=bool(self._jit_compile))
656      if self._jit_compile:
657        attributes.update(_noinline=True)
658    if not attributes:
659      attributes = None
660    return function_lib.defun_with_attributes(
661        fn,
662        input_signature=self.input_signature,
663        attributes=attributes,
664        autograph=self._autograph,
665        jit_compile=self._jit_compile,
666        experimental_autograph_options=self._experimental_autograph_options,
667        experimental_follow_type_hints=self._experimental_follow_type_hints,
668        experimental_relax_shapes=self._experimental_relax_shapes)
669
  def _initialize(self, args, kwds, add_initializers_to=None):
    """Initializes, on the first call.

    Creates two defuns, one that will allow creation of variables
    and one that won't.

    Additionally runs a trace for the defun that allows creation
    of variables, and records this function in `_tf_function_counter` when
    applicable.

    Args:
      args: Arguments to the underlying python callable.
      kwds: Keyword arguments to the python callable.
      add_initializers_to: Where to collect variable initializers, if not None.
        Passed to every `UnliftedInitializerVariable` created while tracing.
    """

    # Weak references to every variable created during the stateful trace.
    created_variables = []
    lifted_initializer_graph = func_graph_module.FuncGraph("initializer")

    # Note: this `**kwds` is the variable constructor's keyword arguments and
    # shadows the outer `kwds` inside this function.
    def variable_capturing_scope(unused_next_creator, **kwds):
      """Creates UnliftedInitializerVariables and saves references to them."""
      v = UnliftedInitializerVariable(
          add_initializers_to=add_initializers_to,
          lifted_initializer_graph=lifted_initializer_graph, **kwds)
      created_variables.append(weakref.ref(v))
      return v

    self._created_variables = created_variables
    self._stateful_fn = self._defun_with_scope(variable_capturing_scope)
    self._stateful_fn._name = self._name  # pylint: disable=protected-access
    self._lifted_initializer_graph = lifted_initializer_graph
    # Dismantles the initializer graph when this deleter is collected.
    self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    # Force the definition of the function for these arguments
    self._concrete_stateful_fn = (
        self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
            *args, **kwds))

    compiled = bool(self._jit_compile and
                    not control_flow_util.GraphOrParentsInXlaContext(
                        ops.get_default_graph()))
    # For nested functions, increment the counter only when a function with
    # jit_compile=True is called within a function with jit_compile=False. We
    # count this special case to correctly record that both jit_compile=True and
    # jit_compile=False is being used for parts of the outer function.
    if ops.executing_eagerly_outside_functions() and (
        context.executing_eagerly() or compiled):
      # Labels must be strings in Python, so we convert 'compiled' to a string
      _tf_function_counter.get_cell(str(int(compiled))).increase_by(1)

    # Used for every call after the first: any attempt to create a variable
    # raises.
    def invalid_creator_scope(*unused_args, **unused_kwds):
      """Disables variable creation."""
      raise ValueError(
          "tf.function-decorated function tried to create "
          "variables on non-first call.")

    self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
    self._stateless_fn._name = self._name  # pylint: disable=protected-access
726
727  def _clone(self, python_function):
728    """Clone the function with different python function."""
729    f = Function(
730        python_function=(self._python_function
731                         if python_function is None else python_function),
732        name=self._name,
733        input_signature=self._input_signature,
734        autograph=self._autograph,
735        jit_compile=self._jit_compile,
736        experimental_implements=self._implements,
737        experimental_autograph_options=self._experimental_autograph_options,
738        experimental_relax_shapes=self._experimental_relax_shapes,
739        experimental_follow_type_hints=self._experimental_follow_type_hints)
740
741    if self._shared_rendezvous:
742      f._shared_rendezvous = self._shared_rendezvous  # pylint: disable=protected-access
743
744    return f
745
746  def _decorate(self, decorator):
747    """Allows the captured Python function to be decorated in place.
748
749    This method is only safe to call when the Function has not been called by a
750    user. It makes sense to use this method to push a decorator into the
751    function rather than wrapping the function in the decorator.
752
753    We use this in tf.Module to allow user annotated `tf.functions` to remain as
754    `Function` objects but still automatically enter the Module name_scope
755    when they are evaluated like all other methods.
756
757    Args:
758      decorator: A callable accepting a single argument which is the function
759        to decorate and returning a callable result.
760
761    Raises:
762      ValueError: If the function has been called a ValueError is raised.
763    """
764    if self._stateful_fn is not None or self._stateless_fn is not None:
765      raise ValueError(
766          "Functions cannot be decorated after they have been traced.")
767
768    self._python_function = decorator(self._python_function)
769    self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
770        self._python_function, self.input_signature)
771
772  # TODO: Remove this private method after updating all its uses
773  # A good moment to do this could be when the experimental label is removed
774  def _get_tracing_count(self):
775    return self.experimental_get_tracing_count()
776
777  def experimental_get_tracing_count(self):
778    """Returns the number of times the function has been traced.
779
780    For more information on when a function is traced and when it is
781    traced multiple times see https://www.tensorflow.org/guide/function.
782    Example:
783
784    >>> @tf.function
785    ... def double(a):
786    ...   return a + a
787    >>> double(tf.constant(1))
788    >>> double(tf.constant(2))
789    >>> double.experimental_get_tracing_count()
790    1
791    >>> double(tf.constant("a"))
792    >>> double.experimental_get_tracing_count()
793    2
794
795
796    The first time experimental_get_tracing_count is called
797    it returns 1, as the function is traced the first
798    time it is called, and the second time the same graph is used
799    since we're calling it with a parameter of the same type.
800
801    The second time experimental_get_tracing_count is called
802    it returns 2, as we called double with a
803    different argument type, and so it was traced again.
804
805    """
806    result = self._stateless_fn.tracing_count if self._stateless_fn else 0
807    result += self._stateful_fn.tracing_count if self._stateful_fn else 0
808    return result
809
810  @property
811  def _run_functions_eagerly(self):
812    return RUN_FUNCTIONS_EAGERLY
813
814  def __call__(self, *args, **kwds):
815    """Calls the graph function and warn too frequent tracings."""
816    if self._run_functions_eagerly:
817      with trace.Trace(self._name, tf_function_call="eager"):
818        return self._python_function(*args, **kwds)
819
820    tracing_count = self.experimental_get_tracing_count()
821    with trace.Trace(self._name) as tm:
822      result = self._call(*args, **kwds)
823      compiler = "xla" if self._jit_compile else "nonXla"
824      new_tracing_count = self.experimental_get_tracing_count()
825      without_tracing = (tracing_count == new_tracing_count)
826      execution_mode = "notTraced" if without_tracing else "traced"
827      tm.set_metadata(tf_function_call=execution_mode + "-" + compiler,
828                      tracing_count=new_tracing_count)
829
830    if context.executing_eagerly():
831      if without_tracing:
832        _frequent_tracing_detector_manager.called_without_tracing(
833            self._key_for_call_stats)
834      else:
835        _frequent_tracing_detector_manager.called_with_tracing(
836            self._key_for_call_stats, self._python_function,
837            self._omit_frequent_tracing_warning)
838
839    return result
840
  def _call(self, *args, **kwds):
    """Calls the graph function.

    On the first call this traces the variable-creating ("stateful") function
    via `_initialize` while holding `self._lock`. If that trace created
    variables, it then tries to initialize them by lifting their initializers
    out of the traced graph, and falls back to a cond-based initialization
    when lifting is not possible. Later calls pick the stateless or stateful
    function depending on whether the first call created variables.

    Args:
      *args: Positional arguments for the wrapped Python function.
      **kwds: Keyword arguments for the wrapped Python function.

    Returns:
      The result of running the selected traced function on the inputs.

    Raises:
      ValueError: If variables are created on a call other than the first
        one. The cond-based fallback may also raise ValueError when a
        variable created by the function has been garbage-collected.
    """
    # The lock is held only while choosing a path and, on the very first call,
    # while `_initialize` runs; every path releases it before running a traced
    # function so that calls can proceed in parallel.
    self._lock.acquire()
    if self._created_variables:
      # Release the lock early so that multiple threads can perform the call
      # in parallel.
      self._lock.release()
      # In this case we have created variables on the first call, so we run the
      # defunned version which is guaranteed to never create variables.
      return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    elif self._stateful_fn is not None:
      # Release the lock early so that multiple threads can perform the call
      # in parallel.
      self._lock.release()
      # In this case we have not created variables on the first call. So we can
      # run the first trace but we should fail if variables are created.
      results = self._stateful_fn(*args, **kwds)
      # `_created_variables` is the list the stateful function's variable
      # creator scope appends to, so a non-empty list here means this call
      # created variables.
      if self._created_variables:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return results

    try:
      # This is the first call of __call__, so we have to initialize.
      initializers = []
      self._initialize(args, kwds, add_initializers_to=initializers)
    finally:
      # At this point we know that the initialization is complete (or less
      # interestingly an exception was raised) so we no longer need a lock.
      self._lock.release()

    if self._created_variables:
      try:
        # Attempt to initialize variables eagerly and without conds by lifting
        # out initialization graphs. This is the only initialization strategy
        # compatible with XLA at the moment.
        self._initialize_uninitialized_variables(initializers)
      except lift_to_graph.UnliftableError:
        pass  # Fall through to cond-based initialization.
      else:
        # Lifting succeeded, so variables are initialized and we can run the
        # stateless function.
        return self._stateless_fn(*args, **kwds)
    else:
      # Canonicalize and flatten the inputs into the form expected by the
      # concrete function traced in `_initialize`.
      _, _, _, filtered_flat_args = \
          self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
              *args, **kwds)
      # If we did not create any variables the trace we have is good enough.
      return self._concrete_stateful_fn._call_flat(
          filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access

    def fn_with_cond(inner_args, inner_kwds, inner_filtered_flat_args):
      """Conditionally runs initialization if it's needed.

      Builds a `cond` on whether every created variable is initialized. If
      they all are, the stateless function is called. Otherwise the concrete
      stateful function is called, which recomputes the initializers.

      Raises:
        ValueError: If a variable created by the function has been
          garbage-collected (its weak reference is dead).
      """
      condition = True
      for wr in self._created_variables:
        variable = wr()
        if variable is None:
          raise ValueError(
              "A tf.Variable created inside your tf.function has been"
              " garbage-collected. Your code needs to keep Python references"
              " to variables created inside `tf.function`s.\n"
              "\n"
              "A common way to raise this error is to create and return a"
              " variable only referenced inside your function:\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  v = tf.Variable(1.0)\n"
              "  return v\n"
              "\n"
              "v = f()  # Crashes with this error message!\n"
              "\n"
              "The reason this crashes is that @tf.function annotated"
              " function returns a **`tf.Tensor`** with the **value** of the"
              " variable when the function is called rather than the"
              " variable instance itself. As such there is no code holding a"
              " reference to the `v` created inside the function and Python"
              " garbage collects it.\n"
              "\n"
              "The simplest way to fix this issue is to create variables"
              " outside the function and capture them:\n"
              "\n"
              "v = tf.Variable(1.0)\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  return v\n"
              "\n"
              "f()  # <tf.Tensor: numpy=1.>\n"
              "v.assign_add(1.)\n"
              "f()  # <tf.Tensor: numpy=2.>")
        # AND together the "is initialized" flag of every created variable.
        condition = math_ops.logical_and(
            condition, resource_variable_ops.var_is_initialized_op(
                variable.handle))
      # We want to call stateless_fn if possible because it avoids recomputing
      # potentially expensive initializers.
      return control_flow_ops.cond(
          condition,
          lambda: self._stateless_fn(*inner_args, **inner_kwds),
          functools.partial(
              self._concrete_stateful_fn._call_flat,  # pylint: disable=protected-access
              inner_filtered_flat_args,
              captured_inputs=self._concrete_stateful_fn.captured_inputs))

    # We've created variables and are unable to lift the initialization graphs,
    # so we fall back to initializing with conds while running the function.
    canon_args, canon_kwds, _, filtered_flat_args = \
        self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
            *args, **kwds)
    return function_lib.defun(fn_with_cond)(canon_args, canon_kwds,
                                            filtered_flat_args)
952
953  def experimental_get_compiler_ir(self, *args, **kwargs):
954    """Returns compiler IR for the compiled function.
955
956    This API is intended *only* for debugging as there are no guarantees on
957    backwards compatibility of returned IR or the allowed values of `stage`.
958
959    Args:
960      *args: Arguments used for compilation; same arguments as used for calling
961        the function. Need to be eager tensors.
962      **kwargs: Keyword arguments used for compilation.
963
964    Returns:
965      Function callable with the following kwargs:
966        - `stage` at which the compiler IR should be serialized. Allowed values
967          are:
968           - `hlo`: HLO output after conversion from TF
969            (https://www.tensorflow.org/xla/operation_semantics).
970           - `hlo_serialized`: Like stage=`hlo`, but the output is a serialized
971             HLO module proto (a bytes object).
972           - `optimized_hlo`: HLO after compiler optimizations.
973           - `optimized_hlo_serialized`: Like stage=`optimized_hlo`, but the
974             output is a serialized HLO module proto (a bytes object).
975           - `optimized_hlo_dot`: optimized HLO in DOT format suitable for
976             Graphviz.
977        - `device_name` can be either None, in which case the preferred device
978          is used for compilation, or a device name. It can be a full device
979          name, or a partial one, e.g., `/device:CPU:0`.
980
981      For example, for
982
983      ```python
984      @tf.function(jit_compile=True)
985      def f(x):
986        return x + 1
987
988      f.experimental_get_compiler_ir(tf.random.normal([10, 10])(stage='hlo')
989      ```
990
991      the output is:
992
993      ```
994      HloModule a_inference_f_13__.9
995
996      ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
997        %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}
998        %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
999        %constant.3 = f32[] constant(1)
1000        %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3)
1001        %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2,
1002                                     f32[10,10]{1,0} %broadcast.4)
1003        %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5)
1004        %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6)
1005        ROOT %get-tuple-element.8 = f32[10,10]{1,0}
1006          get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0
1007      }
1008      ```
1009
1010    Raises:
1011      ValueError: If an invalid `stage` is selected or if applied to a function
1012        which is not compiled (`jit_compile=True` is not set).
1013      TypeError: When called with input in graph mode.
1014    """
1015    context.ensure_initialized()
1016    if not self._jit_compile:
1017      raise ValueError("Compiler IR can only be returned for functions marked "
1018                       "with 'jit_compile=True'")
1019
1020    concrete_fn = self.get_concrete_function(*args, **kwargs)
1021    fn_name = concrete_fn.name
1022
1023    # pylint: disable=protected-access
1024    _, _, _, filtered_flat_args = \
1025        concrete_fn._function_spec.canonicalize_function_inputs(
1026            *args, **kwargs)
1027
1028    def compiler_ir_generator(stage="hlo", device_name=None):
1029      # TODO(cheshire): This is a hack to get the current "preferred" device,
1030      # there is no current API to get it otherwise.
1031      if device_name is None:
1032        device_name = random_ops.random_normal([]).device
1033      res_bytes = context.context().get_compiler_ir(
1034          device_name=device_name,
1035          stage=stage,
1036          function_name=fn_name,
1037          args=list(filtered_flat_args) + concrete_fn.captured_inputs)
1038      if stage in ("hlo_serialized", "optimized_hlo_serialized"):
1039        return res_bytes
1040      else:
1041        return res_bytes.decode("utf-8")
1042
1043    return compiler_ir_generator
1044
1045  @property
1046  def python_function(self):
1047    """The python function wrapped in this tf.function."""
1048    return self._python_function
1049
1050  @property
1051  def input_signature(self):
1052    return self._function_spec.input_signature
1053
1054  @property
1055  def function_spec(self):
1056    return self._function_spec
1057
1058  def pretty_printed_concrete_signatures(self, verbose=True):
1059    joiner = "\n\n" if verbose else "\n"
1060    return joiner.join([
1061        c.pretty_printed_signature(verbose=verbose)
1062        for c in self._list_all_concrete_functions()
1063    ])
1064
  def _initialize_uninitialized_variables(self, initializers):
    """Make and call a `ConcreteFunction` which initializes variables.

    Only variables reported as not yet initialized are assigned. Their
    initial values are lifted into the graph of a new defunned function. That
    function is traced and called inside `ops.init_scope()`.

    Args:
      initializers: List of `(variable, initial_value)` pairs, as collected
        through `_initialize(..., add_initializers_to=...)`.

    Returns:
      None if `initializers` is empty; otherwise whatever calling the concrete
      initialization function returns.

    Raises:
      Presumably `lift_to_graph.UnliftableError` if an initial value cannot be
      lifted; `_call` catches that error around this method. TODO confirm
      against `lift_to_graph`.
    """

    if not initializers:
      return

    # Query initialization state once, outside the defun, so variables that
    # are already initialized are skipped below.
    var_is_initialized = _evaluate_var_is_initialized(
        [v for v, _ in initializers])

    # Note: using defun here avoids an infinite recursion.
    # Most of the code in this function runs eagerly with init_scope, where
    # autograph is not necessary.
    @function_lib.defun(autograph=False)
    def initialize_variables():
      op_map = object_identity.ObjectIdentityDictionary()

      # First pass: gather the initial values that still need to be applied.
      inits = []
      for (v, init), is_initialized in zip(initializers, var_is_initialized):
        with ops.init_scope():
          if is_initialized:
            continue
        inits.append(init)

      # Lift all pending initial values into this function's graph in a
      # single call; `op_map` maps each original tensor to its lifted copy.
      if inits:
        op_map = lift_to_graph.lift_to_graph(
            inits, ops.get_default_graph(), op_map=op_map)
      # Second pass: assign the lifted value to each uninitialized variable.
      for (v, init), is_initialized in zip(initializers, var_is_initialized):
        with ops.init_scope():
          if is_initialized:
            continue
        v.assign(op_map[init], read_value=False)

    with ops.init_scope():
      return initialize_variables.get_concrete_function()()
1099
1100  def get_initialization_function(self, *args, **kwargs):
1101    """Returns a `ConcreteFunction` which initializes this function's variables.
1102
1103    Requires that this function hasn't been accessed yet through either calling
1104    it or calling get_concrete_function. Fails if we cannot build an initializer
1105    function which does not depend on the concrete values of the inputs to this
1106    function.
1107
1108    Note that running this function will overwrite any values currently assigned
1109    to variables, for example restores from a checkpoint.
1110
1111    Args:
1112      *args: arguments to the underlying python callable.
1113      **kwargs: keyword arguments to the python callable.
1114
1115    Returns:
1116      A `ConcreteFunction` object which initializes the variables of this
1117      function.
1118
1119    Raises:
1120      RuntimeError: if called after the variables have been initialized.
1121    """
1122    with self._lock:
1123      if self._stateful_fn is not None:
1124        raise RuntimeError(
1125            "get_initialization_function cannot be called after the function "
1126            "has been used")
1127      # Here we trace the function, collect the initializers, and attempt to
1128      # extract them and run them eagerly. Fail only if we cannot do so.
1129      initializers = []
1130      self._initialize(args, kwargs, add_initializers_to=initializers)
1131
1132    # Note: using defun here avoids an infinite recursion.
1133    @function_lib.defun
1134    def initialize_variables():
1135      for v, init in initializers:
1136        v.assign(
1137            lift_to_graph.lift_to_graph([init], ops.get_default_graph())[init],
1138            read_value=False)
1139
1140    return initialize_variables.get_concrete_function()
1141
1142  def _list_all_concrete_functions(self):
1143    """Returns all concrete functions."""
1144    if self.input_signature is not None:
1145      self.get_concrete_function()
1146    concrete_functions = []
1147    # pylint: disable=protected-access
1148    if self._stateful_fn:
1149      concrete_functions.extend(
1150          self._stateful_fn._function_cache.all_values())
1151    if self._stateless_fn:
1152      concrete_functions.extend(
1153          self._stateless_fn._function_cache.all_values())
1154    # pylint: enable=protected-access
1155    return concrete_functions
1156
1157  def _list_all_concrete_functions_for_serialization(self):
1158    """Returns all concrete functions for serialization.
1159
1160    Returns:
1161      A list of instances of `ConcreteFunction`.
1162    """
1163    concrete_functions = self._list_all_concrete_functions()
1164    seen_signatures = []
1165    for concrete_function in concrete_functions:
1166      signature = concrete_function.structured_input_signature
1167      flattened = nest.flatten(signature)
1168      if any(
1169          isinstance(arg, func_graph_module.UnknownArgument)
1170          for arg in flattened):
1171        logging.info("Unsupported signature for serialization: %s.", signature)
1172        continue
1173      equal_to_signature = functools.partial(
1174          function_lib.is_same_structure, signature, check_values=True)
1175      if not any(equal_to_signature(s) for s in seen_signatures):
1176        seen_signatures.append(signature)
1177
1178    # Re-create concrete functions for these signatures. Re-creating ensures
1179    # that if the cache key has changed, the function will be traced again.
1180    concrete_functions = []
1181    for args, kwargs in seen_signatures:
1182      concrete_functions.append(self.get_concrete_function(*args, **kwargs))
1183    return concrete_functions
1184
1185  def _get_concrete_function_garbage_collected(self, *args, **kwargs):
1186    """Returns a `ConcreteFunction` specialized to inputs and execution context.
1187
1188    Unlike `get_concrete_function(...)`, the graph will be deleted when the
1189    returned function is deleted.  It's useful to avoid creating a reference
1190    cycle when you know for sure that the graph will be no longer used without
1191    the returned function.
1192
1193    Args:
1194      *args: inputs to specialize on.
1195      **kwargs: inputs to specialize on.
1196
1197    Returns:
1198      A TensorFlow function which takes exactly one `tf.Tensor` per argument.
1199
1200    Raises:
1201      ValueError: if this object has not yet been called on concrete values.
1202    """
1203    with self._lock:
1204      if self._stateful_fn is None:
1205        initializers = []
1206        self._initialize(args, kwargs, add_initializers_to=initializers)
1207        self._initialize_uninitialized_variables(initializers)
1208
1209    if self._created_variables:
1210      # In this case we have created variables on the first call, so we run the
1211      # defunned version which is guaranteed to never create variables.
1212      return self._stateless_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
1213          *args, **kwargs)
1214    elif self._stateful_fn is not None:
1215      # In this case we have not created variables on the first call. So we can
1216      # run the first trace but we should fail if variables are created.
1217      concrete = self._stateful_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
1218          *args, **kwargs)
1219      if self._created_variables:
1220        raise ValueError("Creating variables on a non-first call to a function"
1221                         " decorated with tf.function.")
1222      return concrete
1223
1224  def get_concrete_function(self, *args, **kwargs):
1225    """Returns a `ConcreteFunction` specialized to inputs and execution context.
1226
1227    If this `Function` was created with an `input_signature`, `args` and
1228    `kwargs` may be omitted. With an input signature there is only one
1229    concrete function associated with this `Function`.
1230
1231    If there is no fixed `input_signature` associated with this
1232    `Function`, positional and keyword arguments to `get_concrete_function`
1233    follow the same rules as input signature specification, with `tf.TensorSpec`
1234    objects describing `tf.Tensor`s which will be passed to the concrete
1235    function.
1236
1237    Each `tf.Tensor` argument to the concrete function must have a unique name,
1238    either because it is the only one associated with a named argument of the
1239    Python function or because an explicit `name=` was passed to its
1240    `tf.TensorSpec` object. These names become the argument names for the
1241    concrete function.
1242
1243    Arguments to the concrete function may always be specified as keyword
1244    arguments, naming the Tensor input. Positional arguments may be used instead
1245    when each preceding argument to the Python function is a Tensor.
1246
1247    ```python
1248    @tf.function
1249    def f(x):
1250      return x
1251
1252    f_concrete = f.get_concrete_function(tf.TensorSpec([], tf.float64))
1253    f_concrete(tf.constant(1.))
1254    f_concrete(x=tf.constant(1.))
1255    ```
1256
1257    Nested structures containing Tensors may be specified when retrieving
1258    concrete functions. Structures with multiple Tensors are expanded into
1259    multiple arguments of the concrete function. Since multiple concrete
1260    function arguments are associated with one argument to the original
1261    function, these Tensors must be named explicitly. Tensors in nested
1262    structures may not be passed using positional arguments when calling the
1263    concrete function.
1264
1265    ```python
1266    f_concrete2 = f.get_concrete_function(
1267        (tf.TensorSpec(None, tf.float64, name="first"),
1268         tf.TensorSpec([], tf.float32, name="second")))
1269    # Keyword arguments are required when identifying Tensors in nested
1270    # structures.
1271    f_concrete2(first=tf.constant([1.]), second=tf.constant(0.))
1272    ```
1273
1274    Functions with fixed input signatures have only one concrete function
1275    associated with them, which can be retrieved without specifying any
1276    arguments. As before Tensors must have unique names, either inferred from
1277    the argument names in the original Python function or specified
1278    explicitly.
1279
1280    ```python
1281    @tf.function(input_signature=(tf.TensorSpec(None, tf.float32)))
1282    def f_sig(y):
1283      return y
1284
1285    f_sig_concrete = f.get_concrete_function()
1286    f_sig_concrete(tf.constant(1.))
1287    f_sig_concrete(y=tf.constant(1.))
1288    ```
1289
1290    Args:
1291      *args: inputs to specialize on.
1292      **kwargs: inputs to specialize on.
1293
1294    Returns:
1295      A TensorFlow function which takes exactly one `tf.Tensor` per argument.
1296
1297    Raises:
1298      ValueError: if this object has not yet been called on concrete values.
1299    """
1300    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1301    concrete._garbage_collector.release()  # pylint: disable=protected-access
1302    return concrete
1303
1304  def __get__(self, instance, owner):
1305    """Makes it possible to defun instance methods."""
1306    del owner
1307    # `instance` here is the instance that this `Function` was accessed through
1308    # e.g., for
1309    #
1310    #   class Foo(object):
1311    #
1312    #     @function.defun
1313    #     def bar(self):
1314    #       ...
1315    #
1316    #   foo = Foo()
1317    #   foo.bar()  # `foo.bar` is a `Function` instance
1318    #
1319    # then `instance` will be `foo` (and `owner` will be `Foo`).  We create a
1320    # new instance of `Function` here to allow different instances each
1321    # to create variables once, thereby allowing methods to be decorated with
1322    # tf.function. Keeps a cache to avoid retracing the function every time the
1323    # descriptor is accessed.
1324    if instance not in self._descriptor_cache:
1325      if instance is None:
1326        return self
1327      self._descriptor_cache[instance] = (
1328          function_lib.class_method_to_instance_method(self, instance))
1329    return self._descriptor_cache[instance]
1330
1331
1332@tf_export("function")
1333@deprecation.deprecated_args(None,
1334                             "experimental_compile is deprecated, use "
1335                             "jit_compile instead", "experimental_compile")
1336def function(func=None,
1337             input_signature=None,
1338             autograph=True,
1339             jit_compile=None,
1340             experimental_implements=None,
1341             experimental_autograph_options=None,
1342             experimental_relax_shapes=False,
1343             experimental_compile=None,
1344             experimental_follow_type_hints=None):
1345  """Compiles a function into a callable TensorFlow graph.
1346
1347  `tf.function` constructs a callable that executes a TensorFlow graph
1348  (`tf.Graph`) created by trace-compiling the TensorFlow operations in `func`,
1349  effectively executing `func` as a TensorFlow graph.
1350
1351  Example usage:
1352
1353  >>> @tf.function
1354  ... def f(x, y):
1355  ...   return x ** 2 + y
1356  >>> x = tf.constant([2, 3])
1357  >>> y = tf.constant([3, -2])
1358  >>> f(x, y)
1359  <tf.Tensor: ... numpy=array([7, 7], ...)>
1360
1361  _Features_
1362
1363  `func` may use data-dependent control flow, including `if`, `for`, `while`
1364  `break`, `continue` and `return` statements:
1365
1366  >>> @tf.function
1367  ... def f(x):
1368  ...   if tf.reduce_sum(x) > 0:
1369  ...     return x * x
1370  ...   else:
1371  ...     return -x // 2
1372  >>> f(tf.constant(-2))
1373  <tf.Tensor: ... numpy=1>
1374
1375  `func`'s closure may include `tf.Tensor` and `tf.Variable` objects:
1376
1377  >>> @tf.function
1378  ... def f():
1379  ...   return x ** 2 + y
1380  >>> x = tf.constant([-2, -3])
1381  >>> y = tf.Variable([3, -2])
1382  >>> f()
1383  <tf.Tensor: ... numpy=array([7, 7], ...)>
1384
1385  `func` may also use ops with side effects, such as `tf.print`, `tf.Variable`
1386  and others:
1387
1388  >>> v = tf.Variable(1)
1389  >>> @tf.function
1390  ... def f(x):
1391  ...   for i in tf.range(x):
1392  ...     v.assign_add(i)
1393  >>> f(3)
1394  >>> v
1395  <tf.Variable ... numpy=4>
1396
1397  Important: Any Python side-effects (appending to a list, printing with
1398  `print`, etc) will only happen once, when `func` is traced. To have
1399  side-effects executed into your `tf.function` they need to be written
1400  as TF ops:
1401
1402  >>> l = []
1403  >>> @tf.function
1404  ... def f(x):
1405  ...   for i in x:
1406  ...     l.append(i + 1)    # Caution! Will only happen once when tracing
1407  >>> f(tf.constant([1, 2, 3]))
1408  >>> l
1409  [<tf.Tensor ...>]
1410
1411  Instead, use TensorFlow collections like `tf.TensorArray`:
1412
1413  >>> @tf.function
1414  ... def f(x):
1415  ...   ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
1416  ...   for i in range(len(x)):
1417  ...     ta = ta.write(i, x[i] + 1)
1418  ...   return ta.stack()
1419  >>> f(tf.constant([1, 2, 3]))
1420  <tf.Tensor: ..., numpy=array([2, 3, 4], ...)>
1421
1422  _`tf.function` is polymorphic_
1423
1424  Internally, `tf.function` can build more than one graph, to support arguments
1425  with different data types or shapes, since TensorFlow can build more
1426  efficient graphs that are specialized on shapes and dtypes. `tf.function`
1427  also treats any pure Python value as opaque objects, and builds a separate
1428  graph for each set of Python arguments that it encounters.
1429
1430  To obtain an individual graph, use the `get_concrete_function` method of
1431  the callable created by `tf.function`. It can be called with the same
1432  arguments as `func` and returns a special `tf.Graph` object:
1433
1434  >>> @tf.function
1435  ... def f(x):
1436  ...   return x + 1
1437  >>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
1438  True
1439
1440  Caution: Passing python scalars or lists as arguments to `tf.function` will
1441  always build a new graph. To avoid this, pass numeric arguments as Tensors
1442  whenever possible:
1443
1444  >>> @tf.function
1445  ... def f(x):
1446  ...   return tf.abs(x)
1447  >>> f1 = f.get_concrete_function(1)
1448  >>> f2 = f.get_concrete_function(2)  # Slow - builds new graph
1449  >>> f1 is f2
1450  False
1451  >>> f1 = f.get_concrete_function(tf.constant(1))
1452  >>> f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
1453  >>> f1 is f2
1454  True
1455
1456  Python numerical arguments should only be used when they take few distinct
1457  values, such as hyperparameters like the number of layers in a neural network.
1458
1459  _Input signatures_
1460
1461  For Tensor arguments, `tf.function` instantiates a separate graph for every
1462  unique set of input shapes and datatypes. The example below creates two
1463  separate graphs, each specialized to a different shape:
1464
1465  >>> @tf.function
1466  ... def f(x):
1467  ...   return x + 1
1468  >>> vector = tf.constant([1.0, 1.0])
1469  >>> matrix = tf.constant([[3.0]])
1470  >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
1471  False
1472
1473  An "input signature" can be optionally provided to `tf.function` to control
1474  the graphs traced. The input signature specifies the shape and type of each
1475  Tensor argument to the function using a `tf.TensorSpec` object. More general
1476  shapes can be used. This is useful to avoid creating multiple graphs when
1477  Tensors have dynamic shapes. It also restricts the shape and datatype of
1478  Tensors that can be used:
1479
1480  >>> @tf.function(
1481  ...     input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
1482  ... def f(x):
1483  ...   return x + 1
1484  >>> vector = tf.constant([1.0, 1.0])
1485  >>> matrix = tf.constant([[3.0]])
1486  >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
1487  True
1488
1489  _Variables may only be created once_
1490
1491  `tf.function` only allows creating new `tf.Variable` objects when it is called
1492  for the first time:
1493
1494  >>> class MyModule(tf.Module):
1495  ...   def __init__(self):
1496  ...     self.v = None
1497  ...
1498  ...   @tf.function
1499  ...   def __call__(self, x):
1500  ...     if self.v is None:
1501  ...       self.v = tf.Variable(tf.ones_like(x))
1502  ...     return self.v * x
1503
  In general, it is recommended to create stateful objects like `tf.Variable`
  outside of `tf.function` and pass them as arguments.
1506
1507  _Using type annotations to improve performance_
1508
  `experimental_follow_type_hints` can be used along with type annotations to
  improve performance by reducing the number of expensive graph retracings.
  For example, an argument annotated with `tf.Tensor` is converted to a Tensor
  even when the input is a non-Tensor value.
1513
1514  >>> @tf.function(experimental_follow_type_hints=True)
1515  ... def f_with_hints(x: tf.Tensor):
1516  ...   print('Tracing')
1517  ...   return x
1518  >>> @tf.function(experimental_follow_type_hints=False)
1519  ... def f_no_hints(x: tf.Tensor):
1520  ...   print('Tracing')
1521  ...   return x
1522  >>> f_no_hints(1)
1523  Tracing
1524  <tf.Tensor: shape=(), dtype=int32, numpy=1>
1525  >>> f_no_hints(2)
1526  Tracing
1527  <tf.Tensor: shape=(), dtype=int32, numpy=2>
1528  >>> f_with_hints(1)
1529  Tracing
1530  <tf.Tensor: shape=(), dtype=int32, numpy=1>
1531  >>> f_with_hints(2)
1532  <tf.Tensor: shape=(), dtype=int32, numpy=2>
1533
1534  Args:
1535    func: the function to be compiled. If `func` is None, `tf.function` returns
1536      a decorator that can be invoked with a single argument - `func`. In other
1537      words, `tf.function(input_signature=...)(func)` is equivalent to
1538      `tf.function(func, input_signature=...)`. The former can be used as
1539      decorator.
1540    input_signature: A possibly nested sequence of `tf.TensorSpec` objects
1541      specifying the shapes and dtypes of the Tensors that will be supplied to
1542      this function. If `None`, a separate function is instantiated for each
1543      inferred input signature.  If input_signature is specified, every input to
1544      `func` must be a `Tensor`, and `func` cannot accept `**kwargs`.
1545    autograph: Whether autograph should be applied on `func` before tracing a
1546      graph. Data-dependent control flow requires `autograph=True`. For more
1547      information, see the [tf.function and AutoGraph guide](
1548      https://www.tensorflow.org/guide/function).
1549    jit_compile: If `True`, compiles the function using
1550      [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations,
1551      such as fusion, and attempts to emit more efficient code. This may
1552      drastically improve the performance. If set to `True`,
1553      the whole function needs to be compilable by XLA, or an
1554      `errors.InvalidArgumentError` is thrown.
1555      If `None` (default), compiles the function with XLA when running on TPU
1556      and goes through the regular function execution path when running on
1557      other devices.
1558      If `False`, executes the function without XLA compilation.  Set this value
1559      to `False` when directly running a multi-device function on TPUs (e.g. two
1560      TPU cores, one TPU core and its host CPU).
1561      Not all functions are compilable, see a list of
1562      [sharp corners](https://tensorflow.org/xla/known_issues).
1563    experimental_implements: If provided, contains a name of a "known" function
1564      this implements. For example "mycompany.my_recurrent_cell".
1565      This is stored as an attribute in inference function,
1566      which can then be detected when processing serialized function.
1567      See [standardizing composite ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md)  # pylint: disable=line-too-long
1568      for details.  For an example of utilizing this attribute see this
      [example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc).
      The code above automatically detects and substitutes a function that
      implements "embedded_matmul" and allows TFLite to substitute its own
      implementations. For instance, a TensorFlow user can use this
      attribute to mark that their function also implements
      `embedded_matmul` (perhaps more efficiently!)
1575      by specifying it using this parameter:
1576      `@tf.function(experimental_implements="embedded_matmul")`
1577      This can either be specified as just the string name of the function or
1578      a NameAttrList corresponding to a list of key-value attributes associated
1579      with the function name. The name of the function will be in the 'name'
1580      field of the NameAttrList. To define a formal TF op for this function
1581      implements, try the experimental [composite TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr)
1582      project.
1583    experimental_autograph_options: Optional tuple of
1584      `tf.autograph.experimental.Feature` values.
    experimental_relax_shapes: When True, `tf.function` may generate fewer
      graphs that are less specialized on input shapes.
    experimental_compile: Deprecated alias for `jit_compile`.
1588    experimental_follow_type_hints: When True, the function may use type
1589      annotations from `func` to optimize the tracing performance. For example,
1590      arguments annotated with `tf.Tensor` will automatically be converted
1591      to a Tensor.
1592
1593  Returns:
1594     If `func` is not None, returns a callable that will execute the compiled
1595     function (and return zero or more `tf.Tensor` objects).
1596     If `func` is None, returns a decorator that, when invoked with a single
1597     `func` argument, returns a callable equivalent to the case above.
1598
1599  Raises:
1600     ValueError when attempting to use jit_compile=True, but XLA support is not
1601     linked.
1602  """
1603  # TODO(mdan): Link to `tf.types` section once published.
1604  if input_signature is not None:
1605    function_lib.validate_signature(input_signature)
1606  if experimental_follow_type_hints is None:
1607    experimental_follow_type_hints = False
1608
1609  def decorated(inner_function):
1610    try:
1611      name = inner_function.__name__
1612    except AttributeError:
1613      name = "function"
1614    return tf_decorator.make_decorator(
1615        inner_function,
1616        decorator_name="tf.function",
1617        decorator_func=Function(
1618            inner_function,
1619            name,
1620            input_signature=input_signature,
1621            autograph=autograph,
1622            experimental_autograph_options=experimental_autograph_options,
1623            experimental_relax_shapes=experimental_relax_shapes,
1624
1625            # TODO(b/171825496): Update once `experimental_compile` is removed
1626            # entirely in favor of 'jit_compile'.
1627            jit_compile=deprecation.deprecated_argument_lookup(
1628                "jit_compile",
1629                jit_compile,
1630                "experimental_compile",
1631                experimental_compile),
1632            experimental_implements=experimental_implements,
1633            experimental_follow_type_hints=experimental_follow_type_hints))
1634
1635  # This code path is for the `foo = tf.function(foo, ...)` use case
1636  if func is not None:
1637    return decorated(func)
1638
1639  # This code path is for the
1640  #
1641  # @tf.function(...)
1642  # def foo(...):
1643  #    ...
1644  #
1645  # use case, which is equivalent to `foo = tf.function(...)(foo)`
1646  return decorated
1647